use-case-and-architecture/EdgeFLite/train_EdgeFLite.py
Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

206 lines
8.9 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import argparse
import torch.nn as nn
from config import * # Import configuration
from params import train_params # Import training parameters
from model import coremodel, coremodelsl # Import models
from utils import ( # Import utility functions
label_smoothing, norm, metric, lr_scheduler, prefetch,
save_hp_to_json, profile, clever_format
)
from dataset import factory # Import dataset factory
# Explicit import: this module uses `os` directly and should not depend on
# `from config import *` happening to re-export it.
import os

# Pin this process to the first GPU. NOTE: this unconditionally overrides any
# CUDA_VISIBLE_DEVICES value already present in the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Best top-1 validation accuracy seen so far; updated in execute_worker_process.
best_acc1 = 0
# Function to calculate the average of a collection of values
def average(values):
    """Return the arithmetic mean of ``values``.

    Args:
        values: Any finite iterable (list, tuple, generator, ...) whose items
            support ``+`` with each other and division by an ``int``, e.g.
            Python numbers or single-element tensors.

    Returns:
        ``sum(values) / number_of_values``.

    Raises:
        ValueError: If ``values`` is empty.
    """
    # Materialise first so one-shot iterators (which have no len()) also work.
    items = list(values)
    if not items:
        raise ValueError("average() requires at least one value")
    return sum(items) / len(items)
# Helper: average the weights of `client_models` into `global_model`.
def _average_state_into(global_model, client_models):
    """Load the element-wise mean of ``client_models``' state into ``global_model``.

    Raises:
        ValueError: If ``client_models`` is empty.
    """
    if not client_models:
        raise ValueError("merge_models() needs at least one client model per global model")
    # Snapshot every client's state_dict once, not once per key.
    client_states = [client.state_dict() for client in client_models]
    global_state = global_model.state_dict()
    for key, reference in global_state.items():
        stacked = torch.stack([state[key].float() for state in client_states], 0)
        # Average in float32, then cast back so integer buffers (e.g. BatchNorm's
        # num_batches_tracked) and reduced-precision tensors keep their dtype.
        global_state[key] = stacked.mean(0).to(reference.dtype)
    global_model.load_state_dict(global_state)


# Helper: copy the global weights into every client model.
def _broadcast_state(global_model, client_models):
    """Load ``global_model``'s current state into each model in ``client_models``."""
    global_state = global_model.state_dict()
    for client in client_models:
        client.load_state_dict(global_state)


# Function to aggregate the models from multiple clients into a global model
def merge_models(global_model_main, global_model_proxy, client_main_models, client_proxy_models):
    """Federated-average client weights into the global models, then re-sync clients.

    Every entry of each global model's ``state_dict`` (parameters and buffers)
    is replaced by the unweighted mean of the matching entry across the
    corresponding client models. Afterwards every client is overwritten with
    the new global weights.

    Args:
        global_model_main: Global main model; updated in place.
        global_model_proxy: Global proxy model; updated in place.
        client_main_models: Non-empty list of client main models; overwritten.
        client_proxy_models: Non-empty list of client proxy models; overwritten.

    Raises:
        ValueError: If either client list is empty.
    """
    _average_state_into(global_model_main, client_main_models)
    _average_state_into(global_model_proxy, client_proxy_models)
    _broadcast_state(global_model_main, client_main_models)
    _broadcast_state(global_model_proxy, client_proxy_models)
# Function to perform client-side training updates
def client_update(args, round_idx, main_model, proxy_models, schedulers_main, schedulers_proxy,
                  optimizers_main, optimizers_proxy, train_loader, epochs=5, streams=None):
    """Run ``epochs`` local training epochs for one client's main/proxy model pair.

    For every batch, the main model runs forward. Its outputs are copied into
    detached leaf tensors (``main_fx``) that feed the proxy model. The proxy
    loss is back-propagated through the proxy, and the gradients collected on
    ``main_fx`` are then replayed through the main model's outputs. Gradients
    are accumulated over ``args.iters_to_accumulate`` batches between optimizer
    steps, and any remainder is flushed on the last batch of each epoch.

    Args:
        args: Options namespace; only ``args.iters_to_accumulate`` is read here.
        round_idx: Current communication round, forwarded to the schedulers.
        main_model: Client main model, called with ``mode='train'``.
        proxy_models: Proxy model that consumes the main model's outputs.
        schedulers_main: Callable invoked as ``(optimizer, batch_idx, round_idx)``
            before every batch.
        schedulers_proxy: Same as ``schedulers_main``, for the proxy optimizer.
        optimizers_main: Optimizer instance for ``main_model``.
        optimizers_proxy: Optimizer instance for ``proxy_models``.
        train_loader: Loader wrapped by ``prefetch.data_prefetcher``; the
            prefetcher yields ``(None, ...)`` once exhausted.
        epochs: Number of local epochs.
        streams: Forwarded unchanged to both models.

    Returns:
        float: The last processed batch's total loss, already divided by
        ``args.iters_to_accumulate``.

    Raises:
        ValueError: If no batch was processed (``epochs <= 0`` or an empty loader).
    """
    accumulate = args.iters_to_accumulate
    last_loss = None

    main_model.train()
    proxy_models.train()
    for epoch in range(epochs):
        prefetcher = prefetch.data_prefetcher(train_loader)
        images, targets = prefetcher.next()
        batch_idx = 0
        optimizers_main.zero_grad()
        optimizers_proxy.zero_grad()
        while images is not None:
            schedulers_main(optimizers_main, batch_idx, round_idx)
            schedulers_proxy(optimizers_proxy, batch_idx, round_idx)

            outputs, y_a, y_b, lam = main_model(images, target=targets, mode='train', epoch=epoch, streams=streams)
            # Cut the autograd graph here: the proxy back-propagates only into
            # these leaves, whose .grad is then pushed through the main model.
            main_fx = [output.clone().detach().requires_grad_(True) for output in outputs]
            _ensemble_output, _proxy_outputs, ce_loss, cot_loss = proxy_models(
                main_fx, y_a, y_b, lam, target=targets, mode='train', epoch=epoch, streams=streams)

            total_loss = (ce_loss + cot_loss) / accumulate
            total_loss.backward()
            for output, detached in zip(outputs, main_fx):
                output.backward(detached.grad)
            # Keep a detached handle; .item() (a device sync) is called once at the end.
            last_loss = total_loss.detach()

            # Step once `accumulate` batches have been summed (counting this one),
            # and always on the epoch's final batch so no gradients are left over.
            if (batch_idx + 1) % accumulate == 0 or batch_idx + 1 == len(train_loader):
                optimizers_main.step()
                optimizers_main.zero_grad()
                optimizers_proxy.step()
                optimizers_proxy.zero_grad()

            images, targets = prefetcher.next()
            batch_idx += 1

    if last_loss is None:
        raise ValueError("client_update() processed no batches; check `epochs` and `train_loader`")
    return last_loss.item()
# Function to validate the models on a validation set
def validate(val_loader, main_model, proxy_models, args, streams=None):
    """Evaluate the main/proxy model pair on ``val_loader``.

    Both models are put in eval mode and run under ``torch.no_grad()``. The
    main model's outputs are passed to the proxy model, and the proxy's
    ensemble output is scored with ``metric.accuracy``.

    Args:
        val_loader: Iterable of ``(images, targets)`` tensor batches; both are
            moved to the current CUDA device.
        main_model: Main model, called with ``mode='val'``.
        proxy_models: Proxy model returning ``(ensemble_output, proxy_outputs, ce_loss)``.
        args: Options namespace; only ``args.loop_factor`` is read here.
        streams: Currently unused.

    Returns:
        tuple: ``(avg_ce_loss, avg_acc1, top1_metrics)``. The two averages are
        weighted by batch size, so a short final batch does not skew them.
        ``top1_metrics`` is a list of ``args.loop_factor`` ``AverageMeter`` objects.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    main_model.eval()
    proxy_models.eval()
    # NOTE(review): these meters are created but never updated in this function,
    # so callers receive empty meters; confirm whether per-branch top-1 was intended.
    top1_metrics = [metric.AverageMeter(f"Acc@1_{i}", ":6.2f") for i in range(args.loop_factor)]

    loss_sum = 0.0
    acc1_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.cuda(), targets.cuda()
            batch_size = targets.size(0)
            outputs = main_model(images, target=targets, mode='val')
            # No backward pass happens here, so plain detached copies suffice.
            main_fx = [output.clone().detach() for output in outputs]
            ensemble_output, _proxy_outputs, ce_loss = proxy_models(main_fx, target=targets, mode='val')
            acc1, _acc5 = metric.accuracy(ensemble_output, targets, topk=(1, 5))
            # Assumes acc1 and ce_loss are per-batch means (TODO confirm against
            # metric.accuracy / the criterion); re-weight by batch size to get a
            # true per-sample average over the whole set.
            acc1_sum = acc1_sum + acc1 * batch_size
            loss_sum = loss_sum + ce_loss * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("validate() received an empty val_loader")
    avg_ce_loss = loss_sum / num_samples
    avg_acc1 = acc1_sum / num_samples
    return avg_ce_loss, avg_acc1, top1_metrics
# Entry routine: derive run-wide settings, then launch one or more workers.
def main(args):
    """Populate derived fields on ``args`` and start training.

    Writes ``args.loop_factor``, ``args.is_decentralized`` and
    ``args.ngpus_per_node``. When ``args.multiprocessing_decentralized`` is set,
    ``args.world_size`` is multiplied by the visible GPU count and one
    ``execute_worker_process`` is spawned per GPU. Otherwise a single worker
    runs in this process with ``args.gpu = 0``.
    """
    single_branch = args.is_train_sep or args.is_single_branch
    args.loop_factor = 1 if single_branch else args.split_factor
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized

    gpu_count = torch.cuda.device_count()
    args.ngpus_per_node = gpu_count

    if not args.multiprocessing_decentralized:
        # In-process path: a single worker pinned to GPU 0.
        args.gpu = 0
        execute_worker_process(args.gpu, gpu_count, args)
        return

    # Multi-process path: one worker per visible GPU.
    args.world_size *= gpu_count
    torch.multiprocessing.spawn(execute_worker_process, nprocs=gpu_count, args=(gpu_count, args))
# Main worker function to handle training with multiple GPUs or single GPU
def execute_worker_process(gpu, ngpus_per_node, args):
    """Build the global and client models and run federated training rounds.

    Each round runs over ``range(args.start_round, args.num_rounds + 1)``:
    1. Every selected client trains locally via ``client_update`` with its own
       optimizers.
    2. ``merge_models`` averages the client weights into the global models and
       re-syncs the clients.
    3. The global models are evaluated with ``validate``.

    The module-level ``best_acc1`` keeps the best validation top-1 seen.

    Args:
        gpu: Local GPU index; stored on ``args.gpu``.
        ngpus_per_node: GPU count on this node. Unused here, but required by the
            ``torch.multiprocessing.spawn`` call in ``main``.
        args: Parsed training options.
    """
    global best_acc1
    args.gpu = gpu
    setproctitle.setproctitle(f"{args.proc_name}_EdgeFLite_rank{args.rank}")

    if args.is_label_smoothing:
        criterion = label_smoothing.label_smoothing_CE(reduction='mean')
    else:
        criterion = nn.CrossEntropyLoss()

    def build_main():
        # One freshly initialised main model on the current CUDA device.
        return coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda()

    def build_proxy():
        # One freshly initialised proxy model on the current CUDA device.
        return coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda()

    # Construction order matches the original (globals, then all client mains,
    # then all client proxies) so seeded initialisation is unchanged.
    main_model = build_main()
    proxy_model = build_proxy()
    client_main_models = [build_main() for _ in range(args.num_selected)]
    client_proxy_models = [build_proxy() for _ in range(args.num_selected)]
    # Start every client from the same global weights.
    for client in client_main_models:
        client.load_state_dict(main_model.state_dict())
    for client in client_proxy_models:
        client.load_state_dict(proxy_model.state_dict())

    # client_update needs optimizer *instances* bound to each client's parameters;
    # passing the torch.optim.SGD class itself cannot work (no bound step/zero_grad).
    # NOTE(review): assumes `args.lr` exists and that momentum/weight decay (if any)
    # live in `args.momentum` / `args.weight_decay`; confirm against params.train_params.
    def build_optimizer(model):
        return torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=getattr(args, "momentum", 0.0),
            weight_decay=getattr(args, "weight_decay", 0.0),
        )

    # Created once: merge_models copies weights into the existing parameter
    # tensors, so these optimizers stay bound to the right parameters across rounds.
    client_main_optimizers = [build_optimizer(model) for model in client_main_models]
    client_proxy_optimizers = [build_optimizer(model) for model in client_proxy_models]

    train_loader = factory.obtain_data_loader(args.data, batch_size=args.batch_size, dataset=args.dataset, split="train", num_workers=args.workers)
    val_loader = factory.obtain_data_loader(args.data, batch_size=args.eval_batch_size, dataset=args.dataset, split="val", num_workers=args.workers)

    # NOTE(review): lr_scheduler.lr_scheduler is passed through unchanged and must be
    # callable as scheduler(optimizer, batch_idx, round_idx). If it is a class that
    # needs construction, build one instance per client above instead.
    # NOTE(review): every client shares the same train_loader, so all clients see
    # identical data; a per-client partition is probably intended.
    for r in range(args.start_round, args.num_rounds + 1):
        for client_main, client_proxy, opt_main, opt_proxy in zip(
                client_main_models, client_proxy_models,
                client_main_optimizers, client_proxy_optimizers):
            client_update(args, r, client_main, client_proxy,
                          lr_scheduler.lr_scheduler, lr_scheduler.lr_scheduler,
                          opt_main, opt_proxy, train_loader)
        # Fold the locally trained weights into the global models (and re-sync clients);
        # without this the validated global models would never change.
        merge_models(main_model, proxy_model, client_main_models, client_proxy_models)
        test_loss, acc, _top1 = validate(val_loader, main_model, proxy_model, args)
        best_acc1 = max(acc, best_acc1)
        print(f"Round {r}: test_loss={float(test_loss):.4f} "
              f"acc1={float(acc):.2f} best_acc1={float(best_acc1):.2f}")
# Script entry point: build the CLI parser, parse options, and start training.
if __name__ == '__main__':
    cli_parser = argparse.ArgumentParser(description="Training EdgeFLite")
    parsed_args = train_params.add_parser_params(cli_parser)
    main(parsed_args)