# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import os
import warnings

import numpy as np
import torch
import torch.cuda.amp as autocast
from torch import nn
from torch.backends import cudnn
from tensorboardX import SummaryWriter
from tqdm import tqdm

from config import *
from params import train_settings
from params.train_settings import save_hyperparams_to_json
from utils import label_smooth, metrics, scheduler, prefetch_loader
from model import net_splitter
from dataset import data_factory

# Set the visible GPU to use for training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Best validation accuracy achieved so far during training
best_accuracy = 0


# Helper function to compute the average of a list
def compute_average(lst):
    return sum(lst) / len(lst)


# Main function to initialize the training process
def main(args):
    if args.gpu_index is not None:
        # Warn if a specific GPU is selected, disabling data parallelism
        warnings.warn("Specific GPU chosen, disabling data parallelism.")

    # Adjust the loop factor based on the training setup
    args.loop_factor = 1 if args.separate_training or args.single_branch else args.split_factor

    # Determine whether decentralized (multi-process) training is required
    args.decentralized_training = args.world_size > 1 or args.multiprocessing_decentralized

    num_gpus = torch.cuda.device_count()
    args.num_gpus = num_gpus

    if args.multiprocessing_decentralized:
        # Spawn one worker process per GPU for decentralized training
        args.world_size = num_gpus * args.world_size
        torch.multiprocessing.spawn(worker_process, nprocs=num_gpus, args=(num_gpus, args))
    else:
        # Otherwise, proceed with single-GPU training
        print(f"INFO:PyTorch: Detected {num_gpus} GPU(s) available.")
        args.gpu_index = 0
        worker_process(args.gpu_index, num_gpus, args)


# Client-side training function for federated learning updates
def client_train_update(args, round_num, client_model, global_model, sched, opt,
                        train_loader, epochs=5, scaler=None):
    client_model.train()
    for epoch in range(epochs):
        # Prefetch data for training
        loader = prefetch_loader.DataPrefetcher(train_loader)
        images, targets = loader.next()
        batch_idx = 0
        opt.zero_grad()
        while images is not None:
            batch_idx += 1
            # Apply per-batch learning-rate scheduling
            sched(opt, batch_idx)

            # Use automatic mixed precision if enabled
            if args.amp_enabled:
                with autocast.autocast():
                    ensemble_out, model_outputs, loss_ce, loss_cot = client_model(
                        images, targets=targets, mode='train', epoch=epoch)
            else:
                ensemble_out, model_outputs, loss_ce, loss_cot = client_model(
                    images, targets=targets, mode='train', epoch=epoch)

            # Compute top-1 accuracy for each branch output
            batch_size = images.size(0)
            for j in range(args.loop_factor):
                top1_acc = metrics.accuracy(model_outputs[j], targets, topk=(1,))

            # Proximal term that keeps the client close to the global model (FedProx)
            prox_term = sum((param - global_param).norm(2)
                            for param, global_param in zip(client_model.parameters(),
                                                           global_model.parameters()))

            # Total loss: cross-entropy plus the co-training loss (loss_cot), scaled for
            # gradient accumulation, plus the FedProx proximal penalty
            total_loss = (loss_ce + loss_cot) / args.accum_steps + (args.mu / 2) * prox_term

            # Backward pass, with mixed-precision scaling if enabled; the optimizer steps
            # every `accum_steps` batches (and on the final batch)
            if args.amp_enabled:
                scaler.scale(total_loss).backward()
                if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)):
                    scaler.step(opt)
                    scaler.update()
                    opt.zero_grad()
            else:
                total_loss.backward()
                if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)):
                    opt.step()
                    opt.zero_grad()

            images, targets = loader.next()
    return total_loss.item()
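
# For reference: the FedProx objective penalizes each client's local loss with a proximal
# term, F_k(w) + (mu / 2) * ||w - w_global||^2. The helper below is a standalone sketch of
# that squared-norm form from the FedProx paper; it is not called by the training loop
# above, which instead sums the unsquared per-tensor L2 norms.
def fedprox_proximal_term(client_model, global_model):
    # Squared L2 distance between the client's and the global model's parameters
    return sum(((p - g) ** 2).sum()
               for p, g in zip(client_model.parameters(), global_model.parameters()))
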
# Function to aggregate model weights from clients on the server (FedAvg)
def server_compute_average_weights(global_model, client_models):
    global_state_dict = global_model.state_dict()
    # Average each parameter/buffer across all clients
    for key in global_state_dict.keys():
        global_state_dict[key] = torch.stack(
            [client_models[i].state_dict()[key].float() for i in range(len(client_models))],
            0).mean(0)
    global_model.load_state_dict(global_state_dict)
    # Update every client with the averaged global model
    for model in client_models:
        model.load_state_dict(global_model.state_dict())


# Function to validate the model on the validation set
def validate_model(val_loader, model, args):
    model.eval()
    acc1_list, acc5_list, loss_ce_list = [], [], []
    # Perform validation without gradient calculation
    with torch.no_grad():
        for images, targets in val_loader:
            if args.gpu_index is not None:
                images = images.cuda(args.gpu_index, non_blocking=True)
                targets = targets.cuda(args.gpu_index, non_blocking=True)

            if args.amp_enabled:
                with autocast.autocast():
                    ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val')
            else:
                ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val')

            # Per-branch accuracy and accuracy of the ensemble output
            for j in range(args.loop_factor):
                acc1, acc5 = metrics.accuracy(model_outputs[j], targets, topk=(1, 5))
            avg_acc1, avg_acc5 = metrics.accuracy(ensemble_out, targets, topk=(1, 5))

            acc1_list.append(avg_acc1)
            acc5_list.append(avg_acc5)
            loss_ce_list.append(loss_ce)
    return compute_average(loss_ce_list), compute_average(acc1_list)
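
# Reference-only sketch (not called elsewhere in this file): canonical FedAvg weights each
# client's parameters by its share of the training data, n_k / n, rather than the plain
# mean used in server_compute_average_weights above. `client_sizes` is a hypothetical list
# of per-client sample counts.
def server_compute_weighted_average_weights(global_model, client_models, client_sizes):
    total = float(sum(client_sizes))
    global_state_dict = global_model.state_dict()
    for key in global_state_dict.keys():
        global_state_dict[key] = torch.stack(
            [model.state_dict()[key].float() * (size / total)
             for model, size in zip(client_models, client_sizes)], 0).sum(0)
    global_model.load_state_dict(global_state_dict)
    # Keep clients synchronized with the updated global model
    for model in client_models:
        model.load_state_dict(global_model.state_dict())
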
# Worker process that handles training on a specific GPU
def worker_process(gpu_index, num_gpus, args):
    global best_accuracy
    args.gpu_index = gpu_index
    args.model_path = os.path.join(HOME, "models", "coremodel", str(args.model_id))

    # Create a summary writer for validation if not using decentralized training,
    # or only on the first process per node when decentralized multiprocessing is used
    if not args.decentralized_training or (args.multiprocessing_decentralized and args.rank % num_gpus == 0):
        val_summary_writer = SummaryWriter(log_dir=os.path.join(args.model_path, 'validation'))

    # Set the loss function based on the label-smoothing option
    criterion = label_smooth.smooth_ce_loss(reduction='mean') if args.use_label_smooth else nn.CrossEntropyLoss()

    # Initialize the global model and the client models
    global_model = net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion)
    client_models = [net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion)
                     for _ in range(args.num_clients)]

    # In summary mode, save the hyperparameters to JSON and exit without training
    if args.save_summary:
        save_hyperparams_to_json(args)
        return

    # Move models to the GPU and synchronize clients with the global weights
    global_model = global_model.cuda()
    for model in client_models:
        model.cuda()
        model.load_state_dict(global_model.state_dict())

    # Create an SGD optimizer for each client
    opt_list = [torch.optim.SGD(client.parameters(), lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=args.use_nesterov)
                for client in client_models]

    # Initialize the gradient scaler if AMP is enabled
    scaler = torch.cuda.amp.GradScaler() if args.amp_enabled else None
    cudnn.benchmark = True

    # Resume training from a checkpoint if specified
    if args.resume_training:
        if os.path.isfile(args.resume_checkpoint):
            checkpoint = torch.load(args.resume_checkpoint,
                                    map_location=f'cuda:{args.gpu_index}' if args.gpu_index is not None else None)
            args.start_round = checkpoint['round']
            best_accuracy = checkpoint['best_acc1']
            global_model.load_state_dict(checkpoint['state_dict'])
            for opt in opt_list:
                opt.load_state_dict(checkpoint['optimizer'])
            if scaler is not None and "scaler" in checkpoint:
                scaler.load_state_dict(checkpoint['scaler'])
            for client_model in client_models:
                client_model.load_state_dict(global_model.state_dict())
        else:
            args.start_round = 0
    else:
        args.start_round = 0

    # Load training and validation data
    train_loader, _ = data_factory.load_data(args.data_dir, args.batch_size, args.split_factor,
                                             dataset_name=args.dataset_name, split="train",
                                             num_workers=args.num_workers,
                                             decentralized=args.decentralized_training)
    val_loader = data_factory.load_data(args.data_dir, args.eval_batch_size, args.split_factor,
                                        dataset_name=args.dataset_name, split="val",
                                        num_workers=args.num_workers)

    # Federated learning rounds
    for round_num in range(args.start_round, args.num_rounds + 1):
        if args.fixed_cluster:
            # Select clients from fixed clusters for this round
            selected_clusters = np.random.permutation(args.num_clusters)[:args.num_clients]
            for i in tqdm(range(args.num_clients)):
                selected_clients = np.arange(start=selected_clusters[i] * args.split_factor,
                                             stop=(selected_clusters[i] + 1) * args.split_factor)
                for client in selected_clients:
                    # `sched` is assumed to be a per-batch LR scheduler from utils.scheduler
                    # (its construction is not shown here); the remaining arguments follow
                    # the client_train_update signature defined above.
                    loss = client_train_update(args, round_num, client_models[client],
                                               global_model, sched, opt_list[client],
                                               train_loader, scaler=scaler)
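
        # A complete round would typically continue by aggregating the client weights and
        # validating the global model; a minimal sketch using the helpers defined above
        # (the logging and best-accuracy bookkeeping details are assumptions):
        #
        #   server_compute_average_weights(global_model, client_models)
        #   val_loss, val_acc1 = validate_model(val_loader, global_model, args)
        #   best_accuracy = max(best_accuracy, float(val_acc1))
        #   val_summary_writer.add_scalar('val/acc1', val_acc1, global_step=round_num)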