# Commit: 4ec0a23e73
# Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
# (162 lines, 7.3 KiB, Python)
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import argparse
import os

import torch

from dataset import factory
from params import train_params
from fedml_service.data_cleaning.cifar10.data_loader import load_partition_data_cifar10
from fedml_service.data_cleaning.cifar100.data_loader import load_partition_data_cifar100
from fedml_service.data_cleaning.skin_dataset.data_loader import load_partition_data_skin_dataset
from fedml_service.data_cleaning.pillbase.data_loader import load_partition_data_pillbase
from fedml_service.model.cv.resnet_gkt.resnet import wide_resnet16_8_gkt, wide_resnet_model_50_2_gkt, resnet110_gkt
from fedml_service.decentralized.fedgkt.GKTTrainer import GKTTrainer
from fedml_service.decentralized.fedgkt.GKTServerTrainer import GKTServerTrainer
from params.train_params import save_hp_to_json
from config import HOME
from tensorboardX import SummaryWriter
# Set CUDA device to be used for training.
# The environment variable is set before any CUDA work so that only GPU 0 is
# visible to this process; `device` is then the single target for all models.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0")
# Initialize TensorBoard writers for logging
def initialize_writers(args):
    """Return a TensorBoard SummaryWriter rooted at <model_dir>/val."""
    return SummaryWriter(log_dir=os.path.join(args.model_dir, 'val'))
# Initialize dataset and data loaders
def initialize_dataset(args, data_split_factor):
    """Build the data loaders used by the federated pipeline.

    Returns a tuple ``(train_data_local_dict, test_data_global)``: the
    per-client training loaders and the single global validation loader.
    """
    # Options common to both the train and the validation loader calls.
    common_kwargs = dict(
        crop_size=args.crop_size,
        dataset=args.dataset,
        num_workers=args.workers,
        cifar10_non_iid=args.cifar10_non_iid,
        cifar100_non_iid=args.cifar100_non_iid,
    )

    # Training split: augmentation and federated/decentralized options apply
    # only here. The sampler is returned by the factory but unused below.
    train_data_local_dict, train_sampler = factory.obtain_data_loader(
        args.data,
        split_factor=data_split_factor,
        batch_size=args.batch_size,
        split="train",
        is_decentralized=args.is_decentralized,
        is_autoaugment=args.is_autoaugment,
        randaa=args.randaa,
        is_cutout=args.is_cutout,
        erase_p=args.erase_p,
        is_fed=args.is_fed,
        num_clusters=args.num_clusters,
        **common_kwargs,
    )

    # Validation split: one global loader, evaluated server-side.
    test_data_global = factory.obtain_data_loader(
        args.data,
        batch_size=args.eval_batch_size,
        split="val",
        **common_kwargs,
    )

    return train_data_local_dict, test_data_global
# Setup models based on the dataset
def setup_models(args):
    """Return ``(partition_loader_fn, (model_client, model_server))`` for args.dataset.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported datasets.
    """
    # Dataset registry: partition-data loader plus a lazy model constructor,
    # so only the selected dataset's models are instantiated.
    registry = {
        "cifar10": (load_partition_data_cifar10, wide_resnet16_8_gkt),
        "cifar100": (load_partition_data_cifar100, resnet110_gkt),
        "skin_dataset": (load_partition_data_skin_dataset, wide_resnet_model_50_2_gkt),
        "pill_base": (load_partition_data_pillbase, wide_resnet_model_50_2_gkt),
    }
    if args.dataset not in registry:
        raise ValueError(f"Unsupported dataset: {args.dataset}")  # Raise error for unsupported dataset
    data_loader, build_models = registry[args.dataset]
    return data_loader, build_models()
# Initialize trainers for each client in the federated learning setup
def initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict):
    """Return a list with one GKTTrainer per client index in [0, client_number)."""
    return [
        GKTTrainer(
            client_idx=idx,
            train_data_local_dict=train_data_local_dict,
            test_data_local_dict=test_data_local_dict,
            device=device,
            model_client=model_client,
            args=args,
        )
        for idx in range(client_number)
    ]
# Main function to initialize and run the federated learning process
def main(args):
    """Run the FedGKT pipeline: set up data/models, then alternate local
    client training with server-side aggregation for ``args.num_rounds`` rounds.

    Args:
        args: parsed training parameters (see params.train_params).
    """
    # Models are stored under $HOME/models/coremodel/<spid>
    args.model_dir = os.path.join(str(HOME), "models/coremodel", str(args.spid))

    # Save hyperparameters unless we are only summarizing or evaluating
    if not args.is_summary and not args.evaluate:
        save_hp_to_json(args)

    # TensorBoard writer for server-side validation logging
    val_writer = initialize_writers(args)

    # Each client sees `loop_factor` views of the data only in diff-data training
    data_split_factor = args.loop_factor if args.is_diff_data_train else 1
    # Decentralized mode is implied by multi-node or multiprocess settings
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized

    print(f"INFO: PyTorch: => The number of views of train data is '{data_split_factor}'")

    # Load dataset and initialize data loaders
    train_data_local_dict, test_data_global = initialize_dataset(args, data_split_factor)

    # Partition loader plus the client/server model pair for this dataset
    data_loader, (model_client, model_server) = setup_models(args)
    client_number = args.num_clusters * args.split_factor  # Total number of clients

    # Homogeneous partition ('homo', alpha=0.5) across client_number clients.
    # Only the per-client test loaders are consumed below; the other outputs
    # (data counts, global loaders, class count) were previously collected into
    # an unused `dataset_info` list — dead code, now removed.
    _, _, _, _, _, _, test_data_local_dict, _ = data_loader(
        args.dataset, args.data, 'homo', 0.5, client_number, args.batch_size
    )

    print("Server and clients initialized.")
    round_idx = 0  # Counts only rounds in which the server actually aggregated

    # One trainer per client, plus a single server-side trainer
    client_trainers = initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict)
    server_trainer = GKTServerTrainer(client_number, device, model_server, args, val_writer)

    # Start federated training rounds
    for _ in range(args.num_rounds):
        # Each client trains locally and ships its features/logits/labels to the server
        for client_idx in range(client_number):
            extracted_features, logits, labels, test_features, test_labels = client_trainers[client_idx].train()
            print(f"Client {client_idx} finished training.")
            server_trainer.add_local_trained_result(client_idx, extracted_features, logits, labels, test_features, test_labels)

        # Server aggregates once every client has reported in
        if server_trainer.check_whether_all_receive():
            print("All clients' results received by server.")
            server_trainer.train(round_idx)  # Train on the aggregated results
            round_idx += 1

        # Send global model updates back to clients
        # NOTE(review): source indentation was lost; this loop is placed at
        # round level (logits re-broadcast each round) — confirm against the
        # original FedGKT flow.
        for client_idx in range(client_number):
            global_logits = server_trainer.get_global_logits(client_idx)
            client_trainers[client_idx].update_large_model_logits(global_logits)
        print("Server sent updated logits back to clients.")
# Entry point of the script
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = train_params.add_parser_params(parser)

    # Federated learning mode is mandatory for this script. Use an explicit
    # check rather than `assert`, which is silently stripped under `python -O`.
    if args.is_fed != 1:
        raise ValueError("Federated learning requires 'args.is_fed' to be set to 1.")

    # Create the model directory if it does not exist.
    # NOTE(review): main() reassigns args.model_dir from HOME/spid before
    # writing to it, so this creates the parser-provided path, not necessarily
    # the directory used later — confirm the intended ordering.
    os.makedirs(args.model_dir, exist_ok=True)

    print(args)  # Print the parsed arguments for verification
    main(args)  # Start the main process