# use-case-and-architecture/EdgeFLite/resnet_federated.py
# -*- coding: utf-8 -*-
# @Author: Weisen Pan
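"""Federated training entry point for EdgeFLite.

Implements a FedGKT-style loop: each client trains a compact local model and
uploads extracted features/logits to the server; the server trains a larger
model on the aggregated knowledge and returns global logits to the clients.
"""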
import argparse
import os
import torch
from dataset import factory
from params import train_params
from fedml_service.data_cleaning.cifar10.data_loader import load_partition_data_cifar10
from fedml_service.data_cleaning.cifar100.data_loader import load_partition_data_cifar100
from fedml_service.data_cleaning.skin_dataset.data_loader import load_partition_data_skin_dataset
from fedml_service.data_cleaning.pillbase.data_loader import load_partition_data_pillbase
from fedml_service.model.cv.resnet_gkt.resnet import wide_resnet16_8_gkt, wide_resnet_model_50_2_gkt, resnet110_gkt
from fedml_service.decentralized.fedgkt.GKTTrainer import GKTTrainer
from fedml_service.decentralized.fedgkt.GKTServerTrainer import GKTServerTrainer
from params.train_params import save_hp_to_json
from config import HOME
from tensorboardX import SummaryWriter
# Select the CUDA device for training, falling back to CPU if no GPU is available
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Initialize a TensorBoard writer for validation logging
def initialize_writers(args):
    log_dir = os.path.join(args.model_dir, 'val')  # Log directory inside the model directory
    return SummaryWriter(log_dir=log_dir)  # SummaryWriter writes TensorBoard event files
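# Minimal usage sketch (the tag name is hypothetical): metrics can be logged
# through the returned writer with the standard tensorboardX API, e.g.
#   val_writer = initialize_writers(args)
#   val_writer.add_scalar("val/top1_acc", accuracy, global_step=round_idx)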

# Initialize dataset and data loaders
def initialize_dataset(args, data_split_factor):
    # Fetch the per-client training data loaders and the training sampler
    train_data_local_dict, train_sampler = factory.obtain_data_loader(
        args.data,
        split_factor=data_split_factor,
        batch_size=args.batch_size,
        crop_size=args.crop_size,
        dataset=args.dataset,
        split="train",  # Training split
        is_decentralized=args.is_decentralized,
        is_autoaugment=args.is_autoaugment,
        randaa=args.randaa,
        is_cutout=args.is_cutout,
        erase_p=args.erase_p,
        num_workers=args.workers,
        is_fed=args.is_fed,
        num_clusters=args.num_clusters,
        cifar10_non_iid=args.cifar10_non_iid,
        cifar100_non_iid=args.cifar100_non_iid
    )
    # Fetch the global test data loader
    test_data_global = factory.obtain_data_loader(
        args.data,
        batch_size=args.eval_batch_size,
        crop_size=args.crop_size,
        dataset=args.dataset,
        split="val",  # Validation split
        num_workers=args.workers,
        cifar10_non_iid=args.cifar10_non_iid,
        cifar100_non_iid=args.cifar100_non_iid
    )
    return train_data_local_dict, test_data_global  # Per-client train loaders and the global test loader
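# Sketch of how the returned objects are typically consumed, assuming
# train_data_local_dict maps a client index to that client's DataLoader:
#   train_loaders, test_loader = initialize_dataset(args, data_split_factor)
#   for images, targets in train_loaders[client_idx]:
#       ...  # iterate one client's local batches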

# Select the data loader and models based on the dataset
def setup_models(args):
    if args.dataset == "cifar10":
        return load_partition_data_cifar10, wide_resnet16_8_gkt()  # Models for CIFAR-10
    elif args.dataset == "cifar100":
        return load_partition_data_cifar100, resnet110_gkt()  # Models for CIFAR-100
    elif args.dataset == "skin_dataset":
        return load_partition_data_skin_dataset, wide_resnet_model_50_2_gkt()  # Models for the skin dataset
    elif args.dataset == "pill_base":
        return load_partition_data_pillbase, wide_resnet_model_50_2_gkt()  # Models for the pill base dataset
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")  # Reject unsupported datasets
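# Note: each *_gkt() constructor is assumed to return a (client_model, server_model)
# pair, which main() unpacks below; the compact client network runs on the edge
# devices while the larger server network is trained centrally.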

# Initialize a trainer for each client in the federated learning setup
def initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict):
    client_trainers = []
    # Each client gets its own GKTTrainer bound to its client index
    for i in range(client_number):
        trainer = GKTTrainer(
            client_idx=i,
            train_data_local_dict=train_data_local_dict,
            test_data_local_dict=test_data_local_dict,
            device=device,
            model_client=model_client,
            args=args
        )
        client_trainers.append(trainer)  # Add the client trainer to the list
    return client_trainers
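# All trainers share the same data dictionaries and client model definition; each
# GKTTrainer is assumed to select its own data shard via client_idx.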

# Main function to initialize and run the federated learning process
def main(args):
    # Set the model directory based on the home directory and spid
    args.model_dir = os.path.join(str(HOME), "models/coremodel", str(args.spid))
    os.makedirs(args.model_dir, exist_ok=True)  # Ensure the re-derived model directory exists before writing to it
    # Save hyperparameters unless in summary or evaluation mode
    if not args.is_summary and not args.evaluate:
        save_hp_to_json(args)
    # Initialize the TensorBoard writer for logging
    val_writer = initialize_writers(args)
    data_split_factor = args.loop_factor if args.is_diff_data_train else 1  # Data split factor depends on the training mode
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized  # Decide whether decentralized learning is needed
    print(f"INFO: PyTorch: => The number of views of train data is '{data_split_factor}'")
    # Load the dataset and initialize the data loaders
    train_data_local_dict, test_data_global = initialize_dataset(args, data_split_factor)
    # Set up the models for the clients and the server
    data_loader, (model_client, model_server) = setup_models(args)
    client_number = args.num_clusters * args.split_factor  # Total number of clients
    # Load the partitioned data for federated learning
    train_data_num, test_data_num, train_data_global, _, _, _, test_data_local_dict, class_num = data_loader(
        args.dataset, args.data, 'homo', 0.5, client_number, args.batch_size
    )
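    # Assumed FedML-style loader signature: 'homo' selects a homogeneous (IID)
    # partition across clients, and 0.5 is the Dirichlet alpha that would govern
    # a heterogeneous ('hetero') partition; it is unused for 'homo'.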
    dataset_info = [train_data_num, test_data_num, train_data_global, test_data_global, train_data_local_dict, test_data_local_dict, class_num]  # Kept for reference/debugging
    print("Server and clients initialized.")
    round_idx = 0  # Index of the server-side training round
    # Initialize the client trainers and the server trainer
    client_trainers = initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict)
    server_trainer = GKTServerTrainer(client_number, device, model_server, args, val_writer)
    # Start the federated training rounds
    for current_round in range(args.num_rounds):
        # Each client performs local training and sends its results to the server
        for client_idx in range(client_number):
            extracted_features, logits, labels, test_features, test_labels = client_trainers[client_idx].train()
            print(f"Client {client_idx} finished training.")
            server_trainer.add_local_trained_result(client_idx, extracted_features, logits, labels, test_features, test_labels)
        # Once the server has received all clients' results, run the server-side update;
        # round_idx therefore only advances when a server update actually happens
        if server_trainer.check_whether_all_receive():
            print("All clients' results received by server.")
            server_trainer.train(round_idx)  # Server trains on the aggregated results
            round_idx += 1
        # Send the global model updates (logits) back to the clients
        for client_idx in range(client_number):
            global_logits = server_trainer.get_global_logits(client_idx)
            client_trainers[client_idx].update_large_model_logits(global_logits)
        print("Server sent updated logits back to clients.")

# Entry point of the script
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = train_params.add_parser_params(parser)
    # Ensure that federated learning mode is enabled
    assert args.is_fed == 1, "Federated learning requires 'args.is_fed' to be set to 1."
    # Create the model directory if it does not already exist
    os.makedirs(args.model_dir, exist_ok=True)
    print(args)  # Print the parsed arguments for verification
    main(args)  # Start the main process