# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os  # needed for os.environ and os.path.join below (may also be re-exported by `from config import *`)
import torch
import torch.nn as nn
import numpy as np
import argparse
import warnings
from tqdm import tqdm
from tensorboardX import SummaryWriter

from dataset import factory
from config import *
from model import coremodelsl
from utils import label_smoothing, norm, metric, lr_scheduler, prefetch
from params import train_params
from params.train_params import save_hp_to_json

# Set the visible GPU devices for training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Global best accuracy, used to track performance across training
best_acc1 = 0


def average(values):
    """Calculate the average of a list of values."""
    return sum(values) / len(values)


def combine_model_weights(global_model_client, global_model_server, client_models, server_models):
    """
    Aggregate weights from client and server models using the mean method.

    This function updates the global model weights by averaging the weights
    from all client and server models.
    """
    # Get the state dictionaries (weights) for both client and server global models
    client_state_dict = global_model_client.state_dict()
    server_state_dict = global_model_server.state_dict()

    # Average the weights across all client models
    for key in client_state_dict.keys():
        client_state_dict[key] = torch.stack(
            [model.state_dict()[key].float() for model in client_models], dim=0
        ).mean(0)
    global_model_client.load_state_dict(client_state_dict)

    # Average the weights across all server models
    for key in server_state_dict.keys():
        server_state_dict[key] = torch.stack(
            [model.state_dict()[key].float() for model in server_models], dim=0
        ).mean(0)
    global_model_server.load_state_dict(server_state_dict)

    # Load the updated global weights back into the client models
    for model in client_models:
        model.load_state_dict(global_model_client.state_dict())

    # Load the updated global weights back into the server models
    for model in server_models:
        model.load_state_dict(global_model_server.state_dict())


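# Aggregation note: the update above is an unweighted, FedAvg-style mean.
# Assuming K participating models with parameters w_1..w_K, each global
# parameter is replaced by
#     w_global = (1 / K) * sum_{k=1..K} w_k
# Weighting each client by its local dataset size is a common variant, but it
# is not what combine_model_weights() implements.

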
def client_training(args, round_num, client_model, server_model, scheduler_client, scheduler_server,
                    optimizer_client, optimizer_server, data_loader, epochs=5, streams=None):
    """
    Perform client-side model training for the given number of epochs.

    The client model performs the forward pass and sends intermediate outputs
    to the server model for further computation.
    """
    client_model.train()
    server_model.train()

    for epoch in range(epochs):
        # Prefetch data to improve data loading speed
        prefetcher = prefetch.data_prefetcher(data_loader)
        images, target = prefetcher.next()
        i = 0
        optimizer_client.zero_grad()
        optimizer_server.zero_grad()

        while images is not None:
            # Adjust learning rates using the schedulers
            scheduler_client(optimizer_client, i, round_num)
            scheduler_server(optimizer_server, i, round_num)
            i += 1

            # Forward pass on the client model
            outputs_client, y_a, y_b, lam = client_model(images, target=target, mode='train', epoch=epoch, streams=streams)
            client_fx = [output.clone().detach().requires_grad_(True) for output in outputs_client]

            # Forward pass on the server model and compute losses
            ensemble_output, outputs_server, ce_loss, cot_loss = server_model(
                client_fx, y_a, y_b, lam, target=target, mode='train', epoch=epoch, streams=streams)
            total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate
            total_loss.backward()

            # Backpropagate the server-side gradients through the client model
            for fx, grad in zip(outputs_client, client_fx):
                fx.backward(grad.grad)

            # Perform an optimization step once the accumulation condition is met
            if i % args.iters_to_accumulate == 0 or i == len(data_loader):
                optimizer_client.step()
                optimizer_server.step()
                optimizer_client.zero_grad()
                optimizer_server.zero_grad()

            # Fetch the next batch of data
            images, target = prefetcher.next()

    return total_loss.item()


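# Split-learning note: `client_fx` holds detached copies of the client
# activations with requires_grad=True, so the server's backward pass deposits
# gradients in `grad.grad`; calling `fx.backward(grad.grad)` then continues
# backpropagation through the client model. The loss is divided by
# `args.iters_to_accumulate` so accumulated updates match the scale of a
# single larger batch.

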
def validate_model(val_loader, client_model, server_model, args, streams=None):
    """
    Validate the performance of the client and server models.

    This function performs forward passes without updating the model weights
    and computes validation accuracy and loss.
    """
    client_model.eval()
    server_model.eval()

    acc1_list, acc5_list, ce_loss_list = [], [], []

    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass on the client model
            outputs_client = client_model(images, target=target, mode='val')
            client_fx = [output.clone().detach().requires_grad_(True) for output in outputs_client]

            # Forward pass on the server model
            ensemble_output, outputs_server, ce_loss = server_model(client_fx, target=target, mode='val')

            # Calculate accuracy and losses
            acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5))
            acc1_list.append(acc1)
            acc5_list.append(acc5)
            ce_loss_list.append(ce_loss)

    # Calculate the average accuracy and loss over the validation dataset
    avg_acc1 = average(acc1_list)
    avg_acc5 = average(acc5_list)
    avg_ce_loss = average(ce_loss_list)

    return avg_ce_loss, avg_acc1, avg_acc5


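# Usage sketch (illustrative only, assuming an `args` namespace and a
# validation loader obtained via factory.obtain_data_loader):
#   val_ce, val_acc1, val_acc5 = validate_model(val_loader, client_models[0],
#                                               server_models[0], args)

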
def main(args):
    """
    Main entry point for the federated learning process.

    Initializes models, handles multiprocessing setup, and starts training.
    """
    if args.gpu is not None:
        warnings.warn("A specific GPU has been chosen. Data parallelism is disabled.")

    # Set the loop factor based on the training configuration
    args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor
    ngpus_per_node = torch.cuda.device_count()
    args.ngpus_per_node = ngpus_per_node

    if args.multiprocessing_decentralized:
        # Spawn one process per GPU in the decentralized setup
        args.world_size = ngpus_per_node * args.world_size
        torch.multiprocessing.spawn(execute_worker_process, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Use only a single GPU in the non-decentralized setup
        args.gpu = 0
        execute_worker_process(args.gpu, ngpus_per_node, args)


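# Note on the decentralized path: torch.multiprocessing.spawn invokes
# execute_worker_process(i, ngpus_per_node, args) with i being the process
# index, so each spawned worker receives its own GPU index as `gpu`.

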
def execute_worker_process(gpu, ngpus_per_node, args):
    """
    Worker function that handles model initialization, training, and validation.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print(f"Using GPU {args.gpu} for training.")

    # Create a TensorBoard writer for logging validation metrics
    if not args.multiprocessing_decentralized or (args.multiprocessing_decentralized and args.rank % ngpus_per_node == 0):
        val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val'))

    # Define the loss criterion: label smoothing or plain cross-entropy
    criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing else nn.CrossEntropyLoss()

    # Initialize the global client and server models
    global_model_client = coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion)
    global_model_server = coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion)

    # Initialize client and server models for each selected client
    client_models = [coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) for _ in range(args.num_selected)]
    server_models = [coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) for _ in range(args.num_selected)]

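    # Note: CoreModelClient is assumed here to hold the client-side partition of
    # the split network and coremodelProxyClient the server-side partition (see
    # model/coremodelsl for the actual definitions).
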
    # Save the hyperparameters to a JSON file
    save_hp_to_json(args)

    # Move the global models and all client/server models to the GPU
    global_model_client = global_model_client.cuda()
    global_model_server = global_model_server.cuda()
    for model in client_models + server_models:
        model.cuda()

    # Load the global model weights into each client and server model
    for model in client_models:
        model.load_state_dict(global_model_client.state_dict())
    for model in server_models:
        model.load_state_dict(global_model_server.state_dict())

    # Initialize learning-rate schedulers for the clients and servers
    schedulers_clients = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds, len(factory.obtain_data_loader(args.data)), args.lr_milestones, args.lr_multiplier) for _ in range(args.num_selected)]
    schedulers_servers = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds, len(factory.obtain_data_loader(args.data)), args.lr_milestones, args.lr_multiplier) for _ in range(args.num_selected)]

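    # The fourth argument to lr_scheduler is the number of batches per epoch,
    # taken from len(factory.obtain_data_loader(args.data)); the loaders built
    # on the two lines above are used only for their length.
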
    # Start the training and validation loop for the specified number of rounds
    for r in range(args.start_round, args.num_rounds + 1):
        # Randomly select client indices for training in each round
        client_indices = np.random.permutation(args.num_clusters * args.loop_factor)[:args.num_selected * args.loop_factor]