use-case-and-architecture/EdgeFLite/run_prox.py

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import os
import warnings
import torch
import torch.cuda.amp as autocast
from torch import nn
from torch.backends import cudnn
from tensorboardX import SummaryWriter
from config import *
from params import train_settings
from utils import label_smooth, metrics, scheduler, prefetch_loader
from model import net_splitter
from dataset import data_factory
import numpy as np
from tqdm import tqdm
from params.train_settings import save_hyperparams_to_json
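
# Overview: this file implements FedProx-style federated training for EdgeFLite. A
# global model is broadcast to a pool of client models; each client runs several local
# epochs with a proximal term tying its weights to the global model
# (client_train_update); the server then averages the client weights back into the
# global model (server_compute_average_weights); validate_model evaluates the ensemble
# output on the validation split. Automatic mixed precision (AMP) and gradient
# accumulation are optional and controlled through the training arguments.
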
# Set the visible GPU to use for training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Variable to store the best accuracy achieved during training
best_accuracy = 0


# Helper function to compute the average of a list
def compute_average(lst):
    return sum(lst) / len(lst)


# Main function to initialize the training process
def main(args):
    if args.gpu_index is not None:
        # Warn if a specific GPU is selected, disabling data parallelism
        warnings.warn("Specific GPU chosen, disabling data parallelism.")
    # Adjust loop factor based on training setup
    args.loop_factor = 1 if args.separate_training or args.single_branch else args.split_factor
    # Determine if decentralized training is required
    args.decentralized_training = args.world_size > 1 or args.multiprocessing_decentralized
    num_gpus = torch.cuda.device_count()
    args.num_gpus = num_gpus
    # If decentralized multiprocessing is enabled, spawn multiple processes
    if args.multiprocessing_decentralized:
        args.world_size = num_gpus * args.world_size
        torch.multiprocessing.spawn(worker_process, nprocs=num_gpus, args=(num_gpus, args))
    else:
        # Otherwise, proceed with single-GPU training
        print(f"INFO:PyTorch: Detected {num_gpus} GPU(s) available.")
        args.gpu_index = 0
        worker_process(args.gpu_index, num_gpus, args)


# Client-side training function for federated learning updates
def client_train_update(args, round_num, client_model, global_model, sched, opt, train_loader, epochs=5, scaler=None):
    client_model.train()
    for epoch in range(epochs):
        # Prefetch data for training
        loader = prefetch_loader.DataPrefetcher(train_loader)
        images, targets = loader.next()
        batch_idx = 0
        opt.zero_grad()
        while images is not None:
            # Apply learning rate scheduling
            sched(opt, batch_idx)
            # Use automatic mixed precision if enabled
            if args.amp_enabled:
                with autocast.autocast():
                    ensemble_out, model_outputs, loss_ce, loss_cot = client_model(images, targets=targets,
                                                                                  mode='train', epoch=epoch)
            else:
                ensemble_out, model_outputs, loss_ce, loss_cot = client_model(images, targets=targets,
                                                                              mode='train', epoch=epoch)
            # Compute accuracy for top-1 predictions
            batch_size = images.size(0)
            for j in range(args.loop_factor):
                top1_acc = metrics.accuracy(model_outputs[j], targets, topk=(1,))
            # Compute the proximal term for FedProx loss
            prox_term = sum((param - global_param).norm(2) for param, global_param in
                            zip(client_model.parameters(), global_model.parameters()))
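            # FedProx (Li et al., 2020) augments each client's local objective with a
            # proximal penalty that keeps the local weights w close to the current
            # global weights w_t:  min_w  F_k(w) + (mu / 2) * ||w - w_t||^2.
            # Note: the canonical formulation uses the squared L2 distance over all
            # parameters; the sum of per-parameter (unsquared) norms computed above is
            # kept here unchanged.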
            # Total loss: (cross-entropy + cot loss) scaled for gradient accumulation,
            # plus the FedProx proximal penalty weighted by mu / 2
            total_loss = (loss_ce + loss_cot) / args.accum_steps + (args.mu / 2) * prox_term
            # Advance the (1-based) batch counter used by the accumulation and
            # last-batch checks below
            batch_idx += 1
            # Backward pass with mixed precision scaling if enabled; the optimizer steps
            # only every accum_steps batches (gradient accumulation) or on the last batch
            if args.amp_enabled:
                scaler.scale(total_loss).backward()
                if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)):
                    scaler.step(opt)
                    scaler.update()
                    opt.zero_grad()
            else:
                total_loss.backward()
                if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)):
                    opt.step()
                    opt.zero_grad()
            images, targets = loader.next()
    return total_loss.item()


# Function to aggregate model weights from clients on the server
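# Aggregation here is plain, unweighted FedAvg: every parameter tensor of the global
# model becomes the element-wise mean of the corresponding client tensors,
# i.e. theta_global = (1 / K) * sum_k theta_k, after which the averaged weights are
# broadcast back to all clients. Weighting clients by local dataset size, as in the
# original FedAvg formulation, would be a straightforward extension.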
def server_compute_average_weights(global_model, client_models):
    global_state_dict = global_model.state_dict()
    # Average weights across all clients
    for key in global_state_dict.keys():
        global_state_dict[key] = torch.stack(
            [client_models[i].state_dict()[key].float() for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_state_dict)
    # Update clients with the averaged global model
    for model in client_models:
        model.load_state_dict(global_model.state_dict())


# Function to validate the model on the validation set
def validate_model(val_loader, model, args):
    model.eval()
    acc1_list, acc5_list, loss_ce_list = [], [], []
    # Perform validation without gradient calculation
    with torch.no_grad():
        for images, targets in val_loader:
            if args.gpu_index is not None:
                images = images.cuda(args.gpu_index, non_blocking=True)
                targets = targets.cuda(args.gpu_index, non_blocking=True)
            if args.amp_enabled:
                with autocast.autocast():
                    ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val')
            else:
                ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val')
            for j in range(args.loop_factor):
                acc1, acc5 = metrics.accuracy(model_outputs[j], targets, topk=(1, 5))
            avg_acc1, avg_acc5 = metrics.accuracy(ensemble_out, targets, topk=(1, 5))
            acc1_list.append(avg_acc1)
            acc5_list.append(avg_acc5)
            loss_ce_list.append(loss_ce)
    return compute_average(loss_ce_list), compute_average(acc1_list)


# Function to handle the worker process for training on a specific GPU
def worker_process(gpu_index, num_gpus, args):
    global best_accuracy
    args.gpu_index = gpu_index
    args.model_path = os.path.join(HOME, "models", "coremodel", str(args.model_id))
    # Create summary writer for validation if not using decentralized training
    if not args.decentralized_training or (args.multiprocessing_decentralized and args.rank % num_gpus == 0):
        val_summary_writer = SummaryWriter(log_dir=os.path.join(args.model_path, 'validation'))
    # Set the loss function based on the label smoothing option
    criterion = label_smooth.smooth_ce_loss(reduction='mean') if args.use_label_smooth else nn.CrossEntropyLoss()
    # Initialize the global model and client models
    global_model = net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion)
    client_models = [net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion)
                     for _ in range(args.num_clients)]
    # Save hyperparameters to JSON if required
    if args.save_summary:
        save_hyperparams_to_json(args)
        return
    # Move models to GPU
    global_model = global_model.cuda()
    for model in client_models:
        model.cuda()
        model.load_state_dict(global_model.state_dict())
    # Create optimizers for each client
    opt_list = [torch.optim.SGD(client.parameters(), lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=args.use_nesterov)
                for client in client_models]
    # Initialize gradient scaler if AMP is enabled
    scaler = torch.cuda.amp.GradScaler() if args.amp_enabled else None
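    # Under AMP, GradScaler multiplies the loss by a dynamic scale factor before
    # backward() so that small fp16 gradients do not underflow; scaler.step() unscales
    # the gradients before the optimizer update and scaler.update() adapts the factor.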
    # Enable cuDNN benchmarking to auto-select the fastest convolution algorithms
    cudnn.benchmark = True
    # Resume training from checkpoint if specified
    if args.resume_training:
        if os.path.isfile(args.resume_checkpoint):
            checkpoint = torch.load(args.resume_checkpoint,
                                    map_location=f'cuda:{args.gpu_index}' if args.gpu_index is not None else None)
            args.start_round = checkpoint['round']
            best_accuracy = checkpoint['best_acc1']
            global_model.load_state_dict(checkpoint['state_dict'])
            for opt in opt_list:
                opt.load_state_dict(checkpoint['optimizer'])
            if scaler is not None and "scaler" in checkpoint:
                scaler.load_state_dict(checkpoint['scaler'])
            for client_model in client_models:
                client_model.load_state_dict(global_model.state_dict())
        else:
            args.start_round = 0
    else:
        args.start_round = 0
    # Load training and validation data
    train_loader, _ = data_factory.load_data(args.data_dir, args.batch_size, args.split_factor,
                                             dataset_name=args.dataset_name, split="train",
                                             num_workers=args.num_workers,
                                             decentralized=args.decentralized_training)
    val_loader = data_factory.load_data(args.data_dir, args.eval_batch_size, args.split_factor,
                                        dataset_name=args.dataset_name, split="val",
                                        num_workers=args.num_workers)
    # Federated learning rounds
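    # Each round selects a subset of clients (optionally grouped into fixed clusters),
    # runs a local FedProx update on every selected client against the current global
    # model, and is then expected to average the client weights back into the global
    # model (server_compute_average_weights above) before the next round begins.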
    for round_num in range(args.start_round, args.num_rounds + 1):
        if args.fixed_cluster:
            # Select clients from fixed clusters for each round
            selected_clusters = np.random.permutation(args.num_clusters)[:args.num_clients]
            for i in tqdm(range(args.num_clients)):
                selected_clients = np.arange(start=selected_clusters[i] * args.split_factor,
                                             stop=(selected_clusters[i] + 1) * args.split_factor)
                for client in selected_clients:
                    loss = client_train