# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import torch
import argparse
import torch.nn as nn

# NOTE(review): `setproctitle` is used below but not imported in this file;
# presumably it is provided by this star import -- verify.
from config import *  # Import configuration
from params import train_params  # Import training parameters
from model import coremodel, coremodelsl  # Import models
from utils import (  # Import utility functions
    label_smoothing, norm, metric, lr_scheduler, prefetch, save_hp_to_json, profile, clever_format
)
from dataset import factory  # Import dataset factory

# Specify the GPU to be used. This only takes effect if CUDA has not been
# initialised yet; with a single visible device, torch.cuda.device_count()
# reports at most 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Global variable for tracking the best top-1 accuracy across rounds.
best_acc1 = 0


def average(values):
    """Return the arithmetic mean of ``values``.

    Args:
        values: A non-empty sequence of numbers (or tensors) that supports
            ``sum`` and division by an int.

    Raises:
        ValueError: If ``values`` is empty.
    """
    if len(values) == 0:
        raise ValueError("average() requires at least one value")
    return sum(values) / len(values)


def _average_into(global_model, client_models):
    """Overwrite ``global_model``'s state with the element-wise mean of the clients' states."""
    if not client_models:
        raise ValueError("cannot aggregate an empty list of client models")
    # Snapshot each client's state once instead of rebuilding it for every key.
    client_states = [client.state_dict() for client in client_models]
    global_state = global_model.state_dict()
    for key in list(global_state):
        stacked = torch.stack([state[key].float() for state in client_states], 0)
        # Average in float32, then cast back so non-float entries (if any)
        # keep their original dtype.
        global_state[key] = stacked.mean(0).to(global_state[key].dtype)
    global_model.load_state_dict(global_state)


def merge_models(global_model_main, global_model_proxy, client_main_models, client_proxy_models):
    """Aggregate client weights into the global models by simple mean, then broadcast back.

    Every entry of each global model's ``state_dict`` is replaced by the mean of
    the corresponding entries across the clients. Afterwards every client is
    reloaded from the updated global model so all participants start the next
    round from identical weights.

    Args:
        global_model_main: Global main (client-side) model, updated in place.
        global_model_proxy: Global proxy model, updated in place.
        client_main_models: Non-empty list of client main models with the same architecture.
        client_proxy_models: Non-empty list of client proxy models with the same architecture.

    Raises:
        ValueError: If either client list is empty.
    """
    _average_into(global_model_main, client_main_models)
    _average_into(global_model_proxy, client_proxy_models)

    # Synchronize the client models with the updated global models.
    main_state = global_model_main.state_dict()
    for client in client_main_models:
        client.load_state_dict(main_state)
    proxy_state = global_model_proxy.state_dict()
    for client in client_proxy_models:
        client.load_state_dict(proxy_state)


def _step_and_zero_grad(*optimizers):
    """Apply the accumulated gradients and then clear them, optimizer by optimizer."""
    for optimizer in optimizers:
        optimizer.step()
        optimizer.zero_grad()


def client_update(args, round_idx, main_model, proxy_models, schedulers_main, schedulers_proxy,
                  optimizers_main, optimizers_proxy, train_loader, epochs=5, streams=None):
    """Run local training for one client (main model + proxy model pair).

    The main model's outputs are detached and fed to the proxy model. The
    proxy loss is back-propagated first, and the gradients that reach the
    detached outputs are then pushed through the main model. Gradients are
    accumulated over ``args.iters_to_accumulate`` batches per optimizer step.
    Any trailing partial group at the end of an epoch is also applied.

    Args:
        args: Parsed arguments; reads ``iters_to_accumulate``.
        round_idx: Current communication round, forwarded to the schedulers.
        main_model: A single client main model.
        proxy_models: A single client proxy model.
        schedulers_main / schedulers_proxy: Callables invoked as
            ``scheduler(optimizer, batch_idx, round_idx)`` before every batch.
        optimizers_main / optimizers_proxy: Optimizer instances for the two models.
        train_loader: Loader wrapped by ``prefetch.data_prefetcher``.
        epochs: Number of local epochs.
        streams: Forwarded unchanged to both models' forward passes.

    Returns:
        The last batch's loss (already divided by ``iters_to_accumulate``) as a
        float, or ``nan`` if no batch was processed.
    """
    main_model.train()
    proxy_models.train()

    accum_steps = args.iters_to_accumulate
    total_loss = None

    for epoch in range(epochs):
        # Prefetch data for faster loading.
        prefetcher = prefetch.data_prefetcher(train_loader)
        images, targets = prefetcher.next()
        batch_idx = 0
        has_pending_grads = False

        optimizers_main.zero_grad()
        optimizers_proxy.zero_grad()

        while images is not None:
            # Adjust learning rates for this batch.
            schedulers_main(optimizers_main, batch_idx, round_idx)
            schedulers_proxy(optimizers_proxy, batch_idx, round_idx)

            # Forward pass of the main model.
            outputs, y_a, y_b, lam = main_model(images, target=targets, mode='train', epoch=epoch, streams=streams)
            # Cut the graph here: the proxy sees fresh leaf tensors whose .grad
            # is later used to back-propagate into the main model.
            main_fx = [output.clone().detach().requires_grad_(True) for output in outputs]

            # Forward pass of the proxy model on the main model's outputs.
            _ensemble_output, _proxy_outputs, ce_loss, cot_loss = proxy_models(
                main_fx, y_a, y_b, lam, target=targets, mode='train', epoch=epoch, streams=streams)

            total_loss = (ce_loss + cot_loss) / accum_steps
            total_loss.backward()

            # Push the received gradients through the main model in ONE backward
            # call, so branches that share graph nodes are not traversed twice.
            # Outputs that received no gradient are skipped.
            routed = [(out, fx.grad) for out, fx in zip(outputs, main_fx) if fx.grad is not None]
            if routed:
                torch.autograd.backward([out for out, _ in routed], [grad for _, grad in routed])

            batch_idx += 1
            has_pending_grads = True
            # Step once every `accum_steps` batches (after batch accum_steps, 2*accum_steps, ...).
            if batch_idx % accum_steps == 0:
                _step_and_zero_grad(optimizers_main, optimizers_proxy)
                has_pending_grads = False

            images, targets = prefetcher.next()

        # Apply a trailing partial accumulation group instead of discarding it.
        if has_pending_grads:
            _step_and_zero_grad(optimizers_main, optimizers_proxy)

    return total_loss.item() if total_loss is not None else float("nan")


def validate(val_loader, main_model, proxy_models, args, streams=None):
    """Evaluate the main + proxy model pair on ``val_loader``.

    Per-batch metrics are weighted by batch size, so a short final batch does
    not skew the result.

    Args:
        val_loader: Iterable of ``(images, targets)`` batches.
        main_model: Main model, called with ``mode='val'``.
        proxy_models: Proxy model, called on the main model's outputs with ``mode='val'``.
        args: Parsed arguments; reads ``loop_factor``.
        streams: Unused; kept for interface compatibility.

    Returns:
        ``(avg_ce_loss, avg_acc1, top1_metrics)``.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    main_model.eval()
    proxy_models.eval()

    # NOTE(review): these meters are returned but never updated here -- confirm
    # whether per-branch accuracy is meant to be recorded.
    top1_metrics = [metric.AverageMeter(f"Acc@1_{i}", ":6.2f") for i in range(args.loop_factor)]

    total_samples = 0
    acc1_sum = acc5_sum = loss_sum = 0.0

    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.cuda(), targets.cuda()
            batch_size = targets.size(0)

            outputs = main_model(images, target=targets, mode='val')
            # No graph is built under no_grad, so the outputs can be passed on directly.
            main_fx = list(outputs)
            ensemble_output, _proxy_outputs, ce_loss = proxy_models(main_fx, target=targets, mode='val')

            # Assumes metric.accuracy returns per-batch means (e.g. percentages);
            # weighting by batch size turns them into a dataset-level mean.
            acc1, acc5 = metric.accuracy(ensemble_output, targets, topk=(1, 5))
            acc1_sum += acc1 * batch_size
            acc5_sum += acc5 * batch_size
            loss_sum += ce_loss * batch_size
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("validation loader yielded no samples")

    avg_acc1 = acc1_sum / total_samples
    avg_ce_loss = loss_sum / total_samples
    return avg_ce_loss, avg_acc1, top1_metrics


def main(args):
    """Derive run-wide settings on ``args`` and launch the worker process(es).

    Sets ``args.loop_factor``, ``args.is_decentralized`` and
    ``args.ngpus_per_node``. When ``args.multiprocessing_decentralized`` is
    set, it scales ``args.world_size`` by the GPU count and spawns one worker
    per GPU. Otherwise it runs a single worker on GPU 0 in this process.
    """
    args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized

    ngpus_per_node = torch.cuda.device_count()
    args.ngpus_per_node = ngpus_per_node

    if args.multiprocessing_decentralized:
        args.world_size *= ngpus_per_node
        torch.multiprocessing.spawn(execute_worker_process, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        args.gpu = 0
        execute_worker_process(args.gpu, ngpus_per_node, args)


def _make_sgd_optimizer(model, args):
    """Create an SGD optimizer over ``model``'s parameters.

    NOTE(review): the hyper-parameter attribute names (``lr``, ``momentum``,
    ``weight_decay``) are assumed -- confirm they exist in train_params.
    """
    return torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                           weight_decay=args.weight_decay)


def execute_worker_process(gpu, ngpus_per_node, args):
    """Build models and loaders, then run federated training rounds.

    In each round, every client is trained locally with :func:`client_update`.
    The client weights are then averaged into the global models with
    :func:`merge_models`, which also broadcasts them back to the clients.
    Finally the global models are evaluated with :func:`validate`.

    Args:
        gpu: GPU index for this process (stored on ``args.gpu``).
        ngpus_per_node: Visible GPU count; unused here but required by the spawn signature.
        args: Parsed training arguments.
    """
    global best_acc1
    args.gpu = gpu

    setproctitle.setproctitle(f"{args.proc_name}_EdgeFLite_rank{args.rank}")

    # Loss criterion shared by every model.
    if args.is_label_smoothing:
        criterion = label_smoothing.label_smoothing_CE(reduction='mean')
    else:
        criterion = nn.CrossEntropyLoss()

    def new_main_model():
        return coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda()

    def new_proxy_model():
        return coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda()

    # Global models plus one main/proxy pair per selected client.
    main_model = new_main_model()
    proxy_model = new_proxy_model()
    client_main_models = [new_main_model() for _ in range(args.num_selected)]
    client_proxy_models = [new_proxy_model() for _ in range(args.num_selected)]

    # Start every client from the global weights.
    for client in client_main_models:
        client.load_state_dict(main_model.state_dict())
    for client in client_proxy_models:
        client.load_state_dict(proxy_model.state_dict())

    train_loader = factory.obtain_data_loader(args.data, batch_size=args.batch_size, dataset=args.dataset,
                                              split="train", num_workers=args.workers)
    val_loader = factory.obtain_data_loader(args.data, batch_size=args.eval_batch_size, dataset=args.dataset,
                                            split="val", num_workers=args.workers)

    # One optimizer per client model. load_state_dict copies values in place,
    # so these stay bound to the right parameters across rounds.
    client_main_optimizers = [_make_sgd_optimizer(m, args) for m in client_main_models]
    client_proxy_optimizers = [_make_sgd_optimizer(m, args) for m in client_proxy_models]

    for r in range(args.start_round, args.num_rounds + 1):
        # NOTE(review): every client trains on the same train_loader; confirm
        # whether per-client data partitions are intended.
        for client_main, client_proxy, opt_main, opt_proxy in zip(
                client_main_models, client_proxy_models, client_main_optimizers, client_proxy_optimizers):
            # lr_scheduler.lr_scheduler is passed through unchanged, as before;
            # client_update calls it as scheduler(optimizer, batch_idx, round_idx).
            client_update(args, r, client_main, client_proxy,
                          lr_scheduler.lr_scheduler, lr_scheduler.lr_scheduler,
                          opt_main, opt_proxy, train_loader)

        # Aggregate client weights into the global models (and broadcast back).
        merge_models(main_model, proxy_model, client_main_models, client_proxy_models)

        test_loss, acc, top1 = validate(val_loader, main_model, proxy_model, args)
        best_acc1 = max(acc, best_acc1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Training EdgeFLite")
    args = train_params.add_parser_params(parser)
    main(args)