# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import argparse
import os

import torch

from dataset import factory
from params import train_params
from fedml_service.data_cleaning.cifar10.data_loader import load_partition_data_cifar10
from fedml_service.data_cleaning.cifar100.data_loader import load_partition_data_cifar100
from fedml_service.data_cleaning.skin_dataset.data_loader import load_partition_data_skin_dataset
from fedml_service.data_cleaning.pillbase.data_loader import load_partition_data_pillbase
from fedml_service.model.cv.resnet_gkt.resnet import wide_resnet16_8_gkt, wide_resnet_model_50_2_gkt, resnet110_gkt
from fedml_service.decentralized.fedgkt.GKTTrainer import GKTTrainer
from fedml_service.decentralized.fedgkt.GKTServerTrainer import GKTServerTrainer
from params.train_params import save_hp_to_json
from config import HOME
from tensorboardX import SummaryWriter

# Set the CUDA device to be used for training.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0")


# Initialize the TensorBoard writer used for validation logging.
def initialize_writers(args):
    log_dir = os.path.join(args.model_dir, 'val')  # log directory inside the model directory
    return SummaryWriter(log_dir=log_dir)


# Initialize the dataset and data loaders.
def initialize_dataset(args, data_split_factor):
    # Fetch the per-client training data and sampler based on the input parameters.
    train_data_local_dict, train_sampler = factory.obtain_data_loader(
        args.data,
        split_factor=data_split_factor,
        batch_size=args.batch_size,
        crop_size=args.crop_size,
        dataset=args.dataset,
        split="train",  # training split
        is_decentralized=args.is_decentralized,
        is_autoaugment=args.is_autoaugment,
        randaa=args.randaa,
        is_cutout=args.is_cutout,
        erase_p=args.erase_p,
        num_workers=args.workers,
        is_fed=args.is_fed,
        num_clusters=args.num_clusters,
        cifar10_non_iid=args.cifar10_non_iid,
        cifar100_non_iid=args.cifar100_non_iid,
    )

    # Fetch the global test data.
    test_data_global = factory.obtain_data_loader(
        args.data,
        batch_size=args.eval_batch_size,
        crop_size=args.crop_size,
        dataset=args.dataset,
        split="val",  # validation split
        num_workers=args.workers,
        cifar10_non_iid=args.cifar10_non_iid,
        cifar100_non_iid=args.cifar100_non_iid,
    )

    return train_data_local_dict, test_data_global  # return both the train and test data loaders


# Select the partition-data loader and the client/server models for the chosen dataset.
def setup_models(args):
    if args.dataset == "cifar10":
        return load_partition_data_cifar10, wide_resnet16_8_gkt()  # models for CIFAR-10
    elif args.dataset == "cifar100":
        return load_partition_data_cifar100, resnet110_gkt()  # models for CIFAR-100
    elif args.dataset == "skin_dataset":
        return load_partition_data_skin_dataset, wide_resnet_model_50_2_gkt()  # models for the skin dataset
    elif args.dataset == "pill_base":
        return load_partition_data_pillbase, wide_resnet_model_50_2_gkt()  # models for the pill base dataset
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")  # unsupported dataset


# Initialize one trainer per client for the federated learning setup.
def initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict):
    client_trainers = []
    for i in range(client_number):
        trainer = GKTTrainer(
            client_idx=i,
            train_data_local_dict=train_data_local_dict,
            test_data_local_dict=test_data_local_dict,
            device=device,
            model_client=model_client,
            args=args,
        )
        client_trainers.append(trainer)  # add the client trainer to the list
    return client_trainers
# Main function to initialize and run the federated learning process.
def main(args):
    # Set the model directory based on the home directory and spid.
    args.model_dir = os.path.join(str(HOME), "models/coremodel", str(args.spid))

    # Save hyperparameters unless running in summary or evaluation mode.
    if not args.is_summary and not args.evaluate:
        save_hp_to_json(args)

    # Initialize the TensorBoard writer for logging.
    val_writer = initialize_writers(args)

    # Set the data split factor based on the training mode.
    data_split_factor = args.loop_factor if args.is_diff_data_train else 1

    # Decide whether decentralized learning is needed.
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized

    print(f"INFO: PyTorch: => The number of views of train data is '{data_split_factor}'")

    # Load the dataset and initialize the data loaders.
    train_data_local_dict, test_data_global = initialize_dataset(args, data_split_factor)

    # Set up the models for the clients and the server.
    data_loader, (model_client, model_server) = setup_models(args)

    # Calculate the number of clients.
    client_number = args.num_clusters * args.split_factor

    # Load the partitioned data for federated learning.
    train_data_num, test_data_num, train_data_global, _, _, _, test_data_local_dict, class_num = data_loader(
        args.dataset, args.data, 'homo', 0.5, client_number, args.batch_size
    )
    dataset_info = [train_data_num, test_data_num, train_data_global, test_data_global,
                    train_data_local_dict, test_data_local_dict, class_num]

    print("Server and clients initialized.")

    round_idx = 0  # index of the current server training round

    # Initialize the client trainers and the server trainer.
    client_trainers = initialize_trainers(client_number, device, model_client, args,
                                          train_data_local_dict, test_data_local_dict)
    server_trainer = GKTServerTrainer(client_number, device, model_server, args, val_writer)

    # Run the federated training rounds.
    for current_round in range(args.num_rounds):
        # Each client performs local training and sends its results to the server.
        for client_idx in range(client_number):
            extracted_features, logits, labels, test_features, test_labels = client_trainers[client_idx].train()
            print(f"Client {client_idx} finished training.")
            server_trainer.add_local_trained_result(
                client_idx, extracted_features, logits, labels, test_features, test_labels
            )

        # Once the server has received all clients' results, it trains on the aggregated
        # knowledge and sends the updated global logits back to the clients.
        if server_trainer.check_whether_all_receive():
            print("All clients' results received by server.")
            server_trainer.train(round_idx)
            round_idx += 1

            for client_idx in range(client_number):
                global_logits = server_trainer.get_global_logits(client_idx)
                client_trainers[client_idx].update_large_model_logits(global_logits)
            print("Server sent updated logits back to clients.")


# Entry point of the script.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = train_params.add_parser_params(parser)

    # Federated learning mode must be enabled.
    assert args.is_fed == 1, "Federated learning requires 'args.is_fed' to be set to 1."

    # Create the model directory if it does not exist.
    os.makedirs(args.model_dir, exist_ok=True)

    print(args)  # print the parsed arguments for verification
    main(args)
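
# Example invocation (a minimal sketch, not the authoritative command line: the script
# filename, the flag names, and the values below are assumptions inferred from the
# args attributes used above; the real flags and defaults are defined in
# params/train_params.py and may differ):
#
#   python train_fedgkt.py --dataset cifar10 --data /path/to/cifar10 --is_fed 1 \
#       --num_clusters 20 --split_factor 1 --num_rounds 300 --spid gkt_cifar10_run1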