4ec0a23e73
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
206 lines
8.9 KiB
Python
206 lines
8.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import argparse
import os

import torch
import torch.nn as nn

from config import *  # Import configuration
from params import train_params  # Import training parameters
from model import coremodel, coremodelsl  # Import models
from utils import (  # Import utility functions
    label_smoothing, norm, metric, lr_scheduler, prefetch,
    save_hp_to_json, profile, clever_format,
)
from dataset import factory  # Import dataset factory
|
|
|
|
# Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES value the user already
# exported. Overwriting it unconditionally pins every run to a single device,
# which also makes torch.cuda.device_count() report 1 and so defeats the
# per-GPU spawn path in main().
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# Best top-1 validation accuracy seen so far; updated (via `global`) by
# execute_worker_process.
best_acc1 = 0
|
|
|
|
# Function to calculate the average of a list of values
|
|
def average(values):
|
|
"""Calculate average of a list."""
|
|
return sum(values) / len(values)
|
|
|
|
def merge_models(global_model_main, global_model_proxy, client_main_models, client_proxy_models):
    """Aggregate client weights into the global models by simple (unweighted) mean.

    Each entry of the global main/proxy ``state_dict`` is replaced with the
    element-wise mean of the matching entries of the client models. Every
    client is then overwritten with the updated global weights so that all
    participants start the next round from the same point.

    Args:
        global_model_main: global main model; updated in place.
        global_model_proxy: global proxy model; updated in place.
        client_main_models: non-empty iterable of client main models.
        client_proxy_models: non-empty iterable of client proxy models.

    Raises:
        ValueError: if either client collection is empty.
    """
    _average_into(global_model_main, client_main_models)
    _average_into(global_model_proxy, client_proxy_models)


def _average_into(global_model, client_models):
    """Average ``client_models`` into ``global_model``, then copy the result back to each client."""
    client_models = list(client_models)
    if not client_models:
        raise ValueError("merge_models() requires at least one client model")

    # Build each state_dict once. Calling state_dict() inside the per-key loop
    # would rebuild the whole dict for every key of every client.
    client_states = [client.state_dict() for client in client_models]
    global_state = global_model.state_dict()

    for key in global_state:
        original_dtype = global_state[key].dtype
        # Average in float32 so integer buffers (e.g. BatchNorm's
        # num_batches_tracked) can be meaned, then restore the original dtype.
        stacked = torch.stack([state[key].float() for state in client_states], 0)
        global_state[key] = stacked.mean(0).to(original_dtype)

    global_model.load_state_dict(global_state)

    # Synchronize every client with the freshly aggregated global weights.
    synced_state = global_model.state_dict()
    for client in client_models:
        client.load_state_dict(synced_state)
|
|
|
|
def client_update(args, round_idx, main_model, proxy_models, schedulers_main, schedulers_proxy, optimizers_main, optimizers_proxy, train_loader, epochs=5, streams=None):
    """Run local split training for one client and return the last batch loss.

    For every batch, the main model produces a list of outputs. These are
    copied into fresh leaf tensors (``main_fx``) and fed to the proxy model,
    which returns cross-entropy and co-training losses. The summed loss is
    back-propagated through the proxy model. The gradients left on each
    ``main_fx`` tensor are then pushed back through the main model's graph.

    Gradients are accumulated over ``args.iters_to_accumulate`` batches before
    each optimizer step (counting batches from 1). The last batch of an epoch
    always triggers a step, so a trailing partial group is not discarded.

    Args:
        args: namespace; reads ``iters_to_accumulate``.
        round_idx: current round index, forwarded to both LR schedulers.
        main_model: client main module.
        proxy_models: client proxy module.
        schedulers_main / schedulers_proxy: callables invoked as
            ``scheduler(optimizer, batch_idx, round_idx)`` before each batch.
        optimizers_main / optimizers_proxy: optimizer instances for the two modules.
        train_loader: data loader wrapped by ``prefetch.data_prefetcher``.
        epochs: number of local passes over ``train_loader``.
        streams: forwarded unchanged to both modules.

    Returns:
        float | None: loss of the last processed batch (already divided by
        ``iters_to_accumulate``), or ``None`` if no batch was processed.
    """
    main_model.train()
    proxy_models.train()

    accumulate = args.iters_to_accumulate
    last_loss = None

    for epoch in range(epochs):
        # The inner loop runs until the prefetcher returns ``images is None``.
        prefetcher = prefetch.data_prefetcher(train_loader)
        images, targets = prefetcher.next()
        batch_idx = 0

        # Start every epoch with clean gradients.
        optimizers_main.zero_grad()
        optimizers_proxy.zero_grad()

        while images is not None:
            schedulers_main(optimizers_main, batch_idx, round_idx)
            schedulers_proxy(optimizers_proxy, batch_idx, round_idx)

            # Main-model forward. The outputs are cloned and detached into leaf
            # tensors so the proxy backward stops there and fills their ``.grad``.
            outputs, y_a, y_b, lam = main_model(images, target=targets, mode='train', epoch=epoch, streams=streams)
            main_fx = [output.clone().detach().requires_grad_(True) for output in outputs]

            _, _, ce_loss, cot_loss = proxy_models(main_fx, y_a, y_b, lam, target=targets, mode='train', epoch=epoch, streams=streams)

            # Scale so that gradients summed over the group form an average.
            total_loss = (ce_loss + cot_loss) / accumulate
            total_loss.backward()

            # Continue back-propagation through the main model using the
            # gradients the proxy produced for each detached output.
            for output, fx in zip(outputs, main_fx):
                output.backward(fx.grad)

            # 1-based batch count. Step every ``accumulate`` batches, and also on
            # the final batch so the trailing partial group is applied.
            seen = batch_idx + 1
            if seen % accumulate == 0 or seen == len(train_loader):
                optimizers_main.step()
                optimizers_main.zero_grad()
                optimizers_proxy.step()
                optimizers_proxy.zero_grad()

            # Keep a detached reference; .item() is deferred to avoid a device
            # sync on every batch.
            last_loss = total_loss.detach()

            images, targets = prefetcher.next()
            batch_idx += 1

    return None if last_loss is None else last_loss.item()
|
|
|
|
def validate(val_loader, main_model, proxy_models, args, streams=None):
    """Evaluate the main + proxy model pair on ``val_loader``.

    Both modules are switched to eval mode and run under ``torch.no_grad()``.
    Each batch is moved to the current CUDA device. The main model's outputs
    are copied and passed to the proxy model, and the proxy's ensemble output
    is scored with ``metric.accuracy``.

    Per-batch metrics are combined as a batch-size-weighted mean, so a smaller
    final batch no longer skews the result. This assumes ``metric.accuracy``
    and the proxy's ``ce_loss`` are per-batch means (TODO confirm).

    Args:
        val_loader: iterable of ``(images, targets)`` batches.
        main_model: main module to evaluate.
        proxy_models: proxy module to evaluate.
        args: namespace; reads ``loop_factor``.
        streams: accepted for signature parity with ``client_update``.
            NOTE(review): it is not forwarded to either module.

    Returns:
        tuple: ``(avg_ce_loss, avg_acc1, top1_metrics)``. ``top1_metrics`` is a
        list of ``args.loop_factor`` freshly created meters that this function
        does not update.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    main_model.eval()
    proxy_models.eval()

    top1_metrics = [metric.AverageMeter(f"Acc@1_{i}", ":6.2f") for i in range(args.loop_factor)]

    acc1_sum, ce_loss_sum, num_samples = 0, 0, 0

    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.cuda(), targets.cuda()

            outputs = main_model(images, target=targets, mode='val')
            main_fx = [output.clone().detach().requires_grad_(True) for output in outputs]

            ensemble_output, _, ce_loss = proxy_models(main_fx, target=targets, mode='val')

            acc1, _ = metric.accuracy(ensemble_output, targets, topk=(1, 5))

            # Weight each batch's mean by its size.
            batch_size = targets.size(0)
            acc1_sum += acc1 * batch_size
            ce_loss_sum += ce_loss * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("validate() received an empty val_loader")

    return ce_loss_sum / num_samples, acc1_sum / num_samples, top1_metrics
|
|
|
|
def main(args):
    """Derive run-wide settings on ``args`` and launch the training worker(s).

    Sets ``args.loop_factor``, ``args.is_decentralized`` and
    ``args.ngpus_per_node``. When ``args.multiprocessing_decentralized`` is
    truthy, ``args.world_size`` is multiplied by the local GPU count and one
    worker process per GPU is started with ``torch.multiprocessing.spawn``.
    Otherwise a single worker runs in this process with ``args.gpu = 0``.
    """
    single_branch = args.is_train_sep or args.is_single_branch
    args.loop_factor = 1 if single_branch else args.split_factor
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized

    gpu_count = torch.cuda.device_count()
    args.ngpus_per_node = gpu_count

    if not args.multiprocessing_decentralized:
        # In-process path: one worker on device index 0.
        args.gpu = 0
        execute_worker_process(args.gpu, gpu_count, args)
        return

    # One worker process per local GPU.
    args.world_size *= gpu_count
    torch.multiprocessing.spawn(execute_worker_process, nprocs=gpu_count, args=(gpu_count, args))
|
|
|
|
# Main worker function to handle training with multiple GPUs or single GPU
def execute_worker_process(gpu, ngpus_per_node, args):
    """Per-process worker: build models and loaders, then run the round loop.

    Called directly by ``main`` for single-process runs, or once per GPU by
    ``torch.multiprocessing.spawn``, which passes the process index as ``gpu``.

    Args:
        gpu: device index; stored on ``args.gpu``.
        ngpus_per_node: number of visible CUDA devices (not used in this body).
        args: parsed training namespace; mutated (``args.gpu``).

    Side effects:
        Sets the process title and updates the module-level ``best_acc1``.

    NOTE(review): as written, the round loop cannot run successfully:
      * ``client_update`` expects one main module, one proxy module, scheduler
        callables and optimizer *instances*. Here it is passed the client model
        *lists* and the classes ``lr_scheduler.lr_scheduler`` / ``torch.optim.SGD``.
        Its first statement, ``main_model.train()``, fails on a list.
      * ``merge_models`` is never called, so ``main_model`` / ``proxy_model``
        (what ``validate`` evaluates) never receive the clients' updates.
      * ``best_acc1`` is tracked but never saved or reported; ``test_loss`` and
        ``top1`` are unused.
    """
    global best_acc1
    args.gpu = gpu

    # Set the process title. ``setproctitle`` is not in the visible imports;
    # presumably it comes from ``from config import *`` (verify).
    setproctitle.setproctitle(f"{args.proc_name}_EdgeFLite_rank{args.rank}")

    # Loss criterion handed to every model: label-smoothed CE or plain CE.
    if args.is_label_smoothing:
        criterion = label_smoothing.label_smoothing_CE(reduction='mean')
    else:
        criterion = nn.CrossEntropyLoss()

    # Global main/proxy models: the source for client sync below and the
    # models evaluated by ``validate``.
    main_model = coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda()
    proxy_model = coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda()

    # One main/proxy pair per selected client (``args.num_selected`` of each).
    client_main_models = [coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda() for _ in range(args.num_selected)]
    client_proxy_models = [coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda() for _ in range(args.num_selected)]

    # Start every client from the global weights.
    for client in client_main_models:
        client.load_state_dict(main_model.state_dict())
    for client in client_proxy_models:
        client.load_state_dict(proxy_model.state_dict())

    # Training and validation loaders.
    # NOTE(review): every client shares this single ``train_loader``; there is
    # no per-client data partition here.
    train_loader = factory.obtain_data_loader(args.data, batch_size=args.batch_size, dataset=args.dataset, split="train", num_workers=args.workers)
    val_loader = factory.obtain_data_loader(args.data, batch_size=args.eval_batch_size, dataset=args.dataset, split="val", num_workers=args.workers)

    # Round loop; ``args.num_rounds`` itself is included (range end is +1).
    for r in range(args.start_round, args.num_rounds + 1):
        # NOTE(review): argument mismatch (see docstring). This passes model
        # lists and un-instantiated scheduler/optimizer classes where
        # client_update expects single modules and instances.
        client_update(args, r, client_main_models, client_proxy_models, lr_scheduler.lr_scheduler, lr_scheduler.lr_scheduler, torch.optim.SGD, torch.optim.SGD, train_loader)

        # Evaluate the global models.
        # NOTE(review): no merge_models call precedes this.
        test_loss, acc, top1 = validate(val_loader, main_model, proxy_model, args)

        # Track the best top-1 accuracy seen so far (module-level global).
        best_acc1 = max(acc, best_acc1)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: let train_params register and parse the EdgeFLite
    # command-line options, then start training with the parsed namespace.
    args = train_params.add_parser_params(argparse.ArgumentParser(description="Training EdgeFLite"))
    main(args)
|