4ec0a23e73
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
224 lines
9.7 KiB
Python
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import warnings

import torch
import torch.cuda.amp as autocast
from torch import nn
from torch.backends import cudnn
from tensorboardX import SummaryWriter

from config import *
from params import train_settings
from utils import label_smooth, metrics, scheduler, prefetch_loader
from model import net_splitter
from dataset import data_factory

import numpy as np
from tqdm import tqdm

from params.train_settings import save_hyperparams_to_json

# Set the visible GPU to use for training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
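# Note: CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA context
# is initialized, i.e. before the first CUDA call made by this process.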

# Variable to store the best accuracy achieved during training
best_accuracy = 0


# Helper function to compute the average of a list
def compute_average(lst):
    return sum(lst) / len(lst)


# Main function to initialize the training process
def main(args):
    if args.gpu_index is not None:
        # Warn if a specific GPU is selected, disabling data parallelism
        warnings.warn("Specific GPU chosen, disabling data parallelism.")

    # Adjust loop factor based on training setup
    args.loop_factor = 1 if args.separate_training or args.single_branch else args.split_factor
    # Determine if decentralized training is required
    args.decentralized_training = args.world_size > 1 or args.multiprocessing_decentralized
    num_gpus = torch.cuda.device_count()
    args.num_gpus = num_gpus

    # If decentralized multiprocessing is enabled, spawn multiple processes
    if args.multiprocessing_decentralized:
        args.world_size = num_gpus * args.world_size
        torch.multiprocessing.spawn(worker_process, nprocs=num_gpus, args=(num_gpus, args))
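        # Note: torch.multiprocessing.spawn passes the process index (0..nprocs-1) as the first
        # positional argument, so each spawned worker_process receives its GPU index automatically.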
    else:
        # Otherwise, proceed with single-GPU training
        print(f"INFO:PyTorch: Detected {num_gpus} GPU(s) available.")
        args.gpu_index = 0
        worker_process(args.gpu_index, num_gpus, args)


# Client-side training function for federated learning updates
def client_train_update(args, round_num, client_model, global_model, sched, opt, train_loader, epochs=5, scaler=None):
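    """Run `epochs` local epochs of training on a single client.

    Applies a FedProx-style proximal term pulling `client_model` toward
    `global_model`, supports gradient accumulation via `args.accum_steps`,
    and uses mixed precision through `scaler` when `args.amp_enabled` is set.
    Returns the loss of the final processed batch.
    """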
    client_model.train()

    for epoch in range(epochs):
        # Prefetch data for training
        loader = prefetch_loader.DataPrefetcher(train_loader)
        images, targets = loader.next()
        batch_idx = 0
        opt.zero_grad()

        while images is not None:
            # Count the current batch (used by the scheduler and the accumulation checks below)
            batch_idx += 1
            # Apply learning rate scheduling
            sched(opt, batch_idx)

            # Use automatic mixed precision if enabled
            if args.amp_enabled:
                with autocast.autocast():
                    ensemble_out, model_outputs, loss_ce, loss_cot = client_model(
                        images, targets=targets, mode='train', epoch=epoch)
            else:
                ensemble_out, model_outputs, loss_ce, loss_cot = client_model(
                    images, targets=targets, mode='train', epoch=epoch)

            # Compute accuracy for top-1 predictions
            batch_size = images.size(0)
            for j in range(args.loop_factor):
                top1_acc = metrics.accuracy(model_outputs[j], targets, topk=(1,))

            # Compute the proximal term for the FedProx loss
            prox_term = sum((param - global_param).norm(2) for param, global_param in
                            zip(client_model.parameters(), global_model.parameters()))
            # Compute the total loss (cross-entropy + contrastive loss + proximal term)
            total_loss = (loss_ce + loss_cot) / args.accum_steps + (args.mu / 2) * prox_term
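            # FedProx (Li et al.) defines the local objective as F_k(w) + (mu / 2) * ||w - w_global||^2;
            # here the proximal term sums the (unsquared) L2 norm of each parameter tensor's difference
            # from the global model, and the task losses are divided by accum_steps so that gradients
            # accumulated over several batches average correctly.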

            # Backward pass with mixed precision scaling if enabled
            if args.amp_enabled:
                scaler.scale(total_loss).backward()
                if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)):
                    scaler.step(opt)
                    scaler.update()
                    opt.zero_grad()
            else:
                total_loss.backward()
                if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)):
                    opt.step()
                    opt.zero_grad()
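            # The optimizer steps once every accum_steps batches (and on the final batch),
            # so the effective batch size is batch_size * accum_steps.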

            images, targets = loader.next()

    return total_loss.item()


# Function to aggregate model weights from clients on the server
def server_compute_average_weights(global_model, client_models):
    global_state_dict = global_model.state_dict()
    # Average weights across all clients
    for key in global_state_dict.keys():
        global_state_dict[key] = torch.stack(
            [client_models[i].state_dict()[key].float() for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_state_dict)

    # Update clients with the averaged global model
    for model in client_models:
        model.load_state_dict(global_model.state_dict())
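

# The aggregation above is an unweighted mean, which implicitly assumes every client holds a
# similar amount of data. As a hedged illustration (not used anywhere in this script), a
# FedAvg-style variant would weight each client by its dataset size; `client_sizes` is a
# hypothetical argument that this file does not track.
def server_compute_weighted_average_weights(global_model, client_models, client_sizes):
    total = float(sum(client_sizes))
    global_state_dict = global_model.state_dict()
    for key in global_state_dict.keys():
        # Weight each client's tensor by its share of the total number of training samples
        global_state_dict[key] = torch.stack(
            [client_models[i].state_dict()[key].float() * (client_sizes[i] / total)
             for i in range(len(client_models))], 0).sum(0)
    global_model.load_state_dict(global_state_dict)
    # Broadcast the aggregated weights back to every client
    for model in client_models:
        model.load_state_dict(global_model.state_dict())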


# Function to validate the model on the validation set
def validate_model(val_loader, model, args):
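    """Evaluate `model` on the validation set without computing gradients.

    Returns the average cross-entropy loss and the average top-1 accuracy
    of the ensemble output over `val_loader`.
    """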
    model.eval()
    acc1_list, acc5_list, loss_ce_list = [], [], []

    # Perform validation without gradient calculation
    with torch.no_grad():
        for images, targets in val_loader:
            if args.gpu_index is not None:
                images = images.cuda(args.gpu_index, non_blocking=True)
                targets = targets.cuda(args.gpu_index, non_blocking=True)

            if args.amp_enabled:
                with autocast.autocast():
                    ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val')
            else:
                ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val')

            for j in range(args.loop_factor):
                acc1, acc5 = metrics.accuracy(model_outputs[j], targets, topk=(1, 5))

            avg_acc1, avg_acc5 = metrics.accuracy(ensemble_out, targets, topk=(1, 5))
            acc1_list.append(avg_acc1)
            acc5_list.append(avg_acc5)
            loss_ce_list.append(loss_ce)

    return compute_average(loss_ce_list), compute_average(acc1_list)


# Function to handle the worker process for training on a specific GPU
def worker_process(gpu_index, num_gpus, args):
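    """Run federated training on a single GPU.

    Builds the global and per-client models, optionally resumes from a
    checkpoint, loads the data, and then runs the federated rounds in which
    selected clients are trained locally each round.
    """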
    global best_accuracy
    args.gpu_index = gpu_index
    args.model_path = os.path.join(HOME, "models", "coremodel", str(args.model_id))

    # Create summary writer for validation if not using decentralized training
    if not args.decentralized_training or (args.multiprocessing_decentralized and args.rank % num_gpus == 0):
        val_summary_writer = SummaryWriter(log_dir=os.path.join(args.model_path, 'validation'))

    # Set the loss function based on the label smoothing option
    criterion = label_smooth.smooth_ce_loss(reduction='mean') if args.use_label_smooth else nn.CrossEntropyLoss()
    # Initialize the global model and client models
    global_model = net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion)
    client_models = [net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion)
                     for _ in range(args.num_clients)]

    # Save hyperparameters to JSON if required
    if args.save_summary:
        save_hyperparams_to_json(args)
        return

    # Move models to GPU
    global_model = global_model.cuda()
    for model in client_models:
        model.cuda()
        model.load_state_dict(global_model.state_dict())

    # Create optimizers for each client
    opt_list = [torch.optim.SGD(client.parameters(), lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=args.use_nesterov)
                for client in client_models]

    # Initialize gradient scaler if AMP is enabled
    scaler = torch.cuda.amp.GradScaler() if args.amp_enabled else None
    # Let cuDNN benchmark and select the fastest convolution algorithms for fixed input sizes
    cudnn.benchmark = True

    # Resume training from a checkpoint if specified
    if args.resume_training:
        if os.path.isfile(args.resume_checkpoint):
            checkpoint = torch.load(args.resume_checkpoint,
                                    map_location=f'cuda:{args.gpu_index}' if args.gpu_index is not None else None)
            args.start_round = checkpoint['round']
            best_accuracy = checkpoint['best_acc1']
            global_model.load_state_dict(checkpoint['state_dict'])
            for opt in opt_list:
                opt.load_state_dict(checkpoint['optimizer'])
            if "scaler" in checkpoint and scaler is not None:
                scaler.load_state_dict(checkpoint['scaler'])
            for client_model in client_models:
                client_model.load_state_dict(global_model.state_dict())
        else:
            args.start_round = 0
    else:
        args.start_round = 0

    # Load training and validation data
    train_loader, _ = data_factory.load_data(args.data_dir, args.batch_size, args.split_factor,
                                             dataset_name=args.dataset_name, split="train",
                                             num_workers=args.num_workers,
                                             decentralized=args.decentralized_training)
    val_loader = data_factory.load_data(args.data_dir, args.eval_batch_size, args.split_factor,
                                        dataset_name=args.dataset_name, split="val",
                                        num_workers=args.num_workers)

    # Federated learning rounds
    for round_num in range(args.start_round, args.num_rounds + 1):
        if args.fixed_cluster:
            # Select clients from fixed clusters for each round
            selected_clusters = np.random.permutation(args.num_clusters)[:args.num_clients]
            for i in tqdm(range(args.num_clients)):
                selected_clients = np.arange(start=selected_clusters[i] * args.split_factor,
                                             stop=(selected_clusters[i] + 1) * args.split_factor)
                for client in selected_clients:
                    loss = client_train