use-case-and-architecture/EdgeFLite/run_splitfed.py
# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import os
import torch
import torch.nn as nn
import numpy as np
import argparse
import warnings
from tqdm import tqdm
from tensorboardX import SummaryWriter
from dataset import factory
from config import *
from model import coremodelsl
from utils import label_smoothing, norm, metric, lr_scheduler, prefetch
from params import train_params
from params.train_params import save_hp_to_json

# Set the visible GPU devices for training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Global best top-1 accuracy, tracked across training rounds
best_acc1 = 0

def average(values):
"""Calculate the average of a list of values."""
return sum(values) / len(values)
def combine_model_weights(global_model_client, global_model_server, client_models, server_models):
"""
Aggregate weights from client and server models using the mean method.
This function updates the global model weights by averaging the weights
from all client and server models.
"""
# Get the state dictionaries (weights) for both client and server models
client_state_dict = global_model_client.state_dict()
server_state_dict = global_model_server.state_dict()
# Average the weights across all client models
for key in client_state_dict.keys():
client_state_dict[key] = torch.stack([model.state_dict()[key].float() for model in client_models], dim=0).mean(0)
global_model_client.load_state_dict(client_state_dict)
# Average the weights across all server models
for key in server_state_dict.keys():
server_state_dict[key] = torch.stack([model.state_dict()[key].float() for model in server_models], dim=0).mean(0)
global_model_server.load_state_dict(server_state_dict)
# Load the updated global model weights back into the client models
for model in client_models:
model.load_state_dict(global_model_client.state_dict())
# Load the updated global model weights back into the server models
for model in server_models:
model.load_state_dict(global_model_server.state_dict())
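
# Illustrative sketch (not called anywhere in this script): the same key-by-key mean
# aggregation applied to two tiny nn.Linear models. The demo modules and sizes are
# hypothetical, purely to show what combine_model_weights does with the state dicts.
def _fedavg_mean_sketch():
    demo_models = [nn.Linear(4, 2), nn.Linear(4, 2)]
    global_demo = nn.Linear(4, 2)
    # Average each parameter tensor across the local models.
    avg_state = {
        key: torch.stack([m.state_dict()[key].float() for m in demo_models], dim=0).mean(0)
        for key in global_demo.state_dict().keys()
    }
    global_demo.load_state_dict(avg_state)
    # Re-seed every local model with the averaged weights, as done above.
    for m in demo_models:
        m.load_state_dict(global_demo.state_dict())
    return global_demo
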
def client_training(args, round_num, client_model, server_model, scheduler_client, scheduler_server, optimizer_client, optimizer_server, data_loader, epochs=5, streams=None):
"""
Perform client-side model training for the given number of epochs.
The client model performs the forward pass and sends intermediate outputs
to the server model for further computation.
"""
client_model.train()
server_model.train()
for epoch in range(epochs):
# Prefetch data to improve data loading speed
prefetcher = prefetch.data_prefetcher(data_loader)
images, target = prefetcher.next()
i = 0
optimizer_client.zero_grad()
optimizer_server.zero_grad()
while images is not None:
# Adjust learning rates using the schedulers
scheduler_client(optimizer_client, i, round_num)
scheduler_server(optimizer_server, i, round_num)
i += 1
# Forward pass on the client model
outputs_client, y_a, y_b, lam = client_model(images, target=target, mode='train', epoch=epoch, streams=streams)
client_fx = [outputs.clone().detach().requires_grad_(True) for outputs in outputs_client]
# Forward pass on the server model and compute losses
ensemble_output, outputs_server, ce_loss, cot_loss = server_model(client_fx, y_a, y_b, lam, target=target, mode='train', epoch=epoch, streams=streams)
total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate
total_loss.backward()
# Backpropagate the gradients to the client model
for fx, grad in zip(outputs_client, client_fx):
fx.backward(grad.grad)
# Perform optimization step when the accumulation condition is met
if i % args.iters_to_accumulate == 0 or i == len(data_loader):
optimizer_client.step()
optimizer_server.step()
optimizer_client.zero_grad()
optimizer_server.zero_grad()
# Fetch the next batch of data
images, target = prefetcher.next()
return total_loss.item()
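
# Illustrative sketch (not called anywhere in this script): the split-learning gradient
# hand-off used in client_training, reduced to two tiny nn.Linear halves. The modules,
# shapes, and data are hypothetical; the point is that the server-side backward fills
# the detached copy's .grad, which is then pushed into the client graph via backward().
def _split_backward_sketch():
    client_part, server_part = nn.Linear(8, 4), nn.Linear(4, 2)
    x, y = torch.randn(3, 8), torch.randint(0, 2, (3,))
    smashed = client_part(x)                              # client forward pass
    fx = smashed.clone().detach().requires_grad_(True)    # what gets sent to the server
    loss = nn.CrossEntropyLoss()(server_part(fx), y)
    loss.backward()                                       # populates fx.grad on the server side
    smashed.backward(fx.grad)                             # continue backprop on the client side
    return client_part.weight.grad is not None
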
def validate_model(val_loader, client_model, server_model, args, streams=None):
"""
Validate the performance of client and server models.
This function performs forward passes without updating the model weights
and computes validation accuracy and loss.
"""
client_model.eval()
server_model.eval()
acc1_list, acc5_list, ce_loss_list = [], [], []
with torch.no_grad():
for i, (images, target) in enumerate(val_loader):
if args.gpu is not None:
images = images.cuda(args.gpu, non_blocking=True)
target = target.cuda(args.gpu, non_blocking=True)
# Forward pass on the client model
outputs_client = client_model(images, target=target, mode='val')
client_fx = [output.clone().detach().requires_grad_(True) for output in outputs_client]
# Forward pass on the server model
ensemble_output, outputs_server, ce_loss = server_model(client_fx, target=target, mode='val')
# Calculate accuracy and losses
acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5))
acc1_list.append(acc1)
acc5_list.append(acc5)
ce_loss_list.append(ce_loss)
# Calculate average accuracy and loss over the validation dataset
avg_acc1 = average(acc1_list)
avg_acc5 = average(acc5_list)
avg_ce_loss = average(ce_loss_list)
return avg_ce_loss, avg_acc1, avg_acc5
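
# Illustrative sketch (not called anywhere in this script): a plain top-k accuracy helper,
# assuming utils.metric.accuracy follows the common recipe below and returns one percentage
# per k. This is a reader-facing reference, not the project's own implementation.
def _topk_accuracy_sketch(output, target, topk=(1, 5)):
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)   # (N, maxk) predicted classes
    correct = pred.t().eq(target.view(1, -1).expand_as(pred.t()))   # (maxk, N) hit mask
    return [correct[:k].reshape(-1).float().sum().mul_(100.0 / target.size(0)) for k in topk]
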
def main(args):
"""
The main entry point for the federated learning process.
Initializes models, handles multiprocessing setup, and starts training.
"""
if args.gpu is not None:
warnings.warn("A specific GPU has been chosen. Data parallelism is disabled.")
# Set loop factor based on training configuration
args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor
ngpus_per_node = torch.cuda.device_count()
args.ngpus_per_node = ngpus_per_node
if args.multiprocessing_decentralized:
# Spawn a process for each GPU in decentralized setup
args.world_size = ngpus_per_node * args.world_size
torch.multiprocessing.spawn(execute_worker_process, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
else:
# Use only a single GPU in non-decentralized setup
args.gpu = 0
execute_worker_process(args.gpu, ngpus_per_node, args)
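
# Example invocation (hedged): the CLI flags are defined in params.train_params, which is
# not shown in this file, so the flag names below are assumptions for illustration only, e.g.:
#   python run_splitfed.py --data cifar10 --num_selected 4 --num_rounds 100 --lr 0.1
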
def execute_worker_process(gpu, ngpus_per_node, args):
"""
Worker function that handles model initialization, training, and validation.
"""
global best_acc1
args.gpu = gpu
if args.gpu is not None:
print(f"Using GPU {args.gpu} for training.")
# Create tensorboard writer for logging validation metrics
if not args.multiprocessing_decentralized or (args.multiprocessing_decentralized and args.rank % ngpus_per_node == 0):
val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val'))
# Define loss criterion with label smoothing or cross-entropy
criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing else nn.CrossEntropyLoss()
# Initialize global client and server models
global_model_client = coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion)
global_model_server = coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion)
# Initialize client and server models for each selected client
client_models = [coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) for _ in range(args.num_selected)]
server_models = [coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) for _ in range(args.num_selected)]
# Save hyperparameters to a JSON file
save_hp_to_json(args)
# Move global models and client/server models to GPU
global_model_client = global_model_client.cuda()
global_model_server = global_model_server.cuda()
for model in client_models + server_models:
model.cuda()
# Load global model weights into each client and server model
for model in client_models:
model.load_state_dict(global_model_client.state_dict())
for model in server_models:
model.load_state_dict(global_model_server.state_dict())
# Initialize learning rate schedulers for clients and servers
schedulers_clients = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds, len(factory.obtain_data_loader(args.data)), args.lr_milestones, args.lr_multiplier) for _ in range(args.num_selected)]
schedulers_servers = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds, len(factory.obtain_data_loader(args.data)), args.lr_milestones, args.lr_multiplier) for _ in range(args.num_selected)]
# Start the training and validation loop for the specified number of rounds
for r in range(args.start_round, args.num_rounds + 1):
# Randomly select client indices for training in each round
client_indices = np.random.permutation(args.num_clusters * args.loop_factor)[:args.num_selected * args.loop_factor]