use-case-and-architecture/EdgeFLite/run_local.py
# -*- coding: utf-8 -*-
# @Author: Weisen Pan
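"""Local (centralized, non-federated) training entry point for EdgeFLite.

Builds the coremodel network, optionally wraps it for decentralized multi-GPU
training, and runs the standard train/validate loop. This script requires
args.is_fed == 0; federated runs are configured elsewhere in the project.
"""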
import os
import torch
import argparse
import warnings
import setproctitle
from torch import nn
import torch.distributed as decentralized  # torch.distributed, aliased to keep the project's "decentralized" naming
from torch.backends import cudnn # Optimizes performance for convolutional networks
from tensorboardX import SummaryWriter # For logging metrics and results to TensorBoard
import torch.cuda.amp as amp # For mixed precision training
from config import * # Custom configuration module
from params import train_params # Training parameters
from utils import label_smoothing, norm, summary, metric, lr_scheduler, prefetch # Utility functions
from model import coremodel # Core model implementation
from dataset import factory # Dataset and data loader factory
from params.train_params import save_hp_to_json # Function to save hyperparameters to JSON
# Global variable to store the best accuracy obtained during training
best_acc1 = 0
def main(args):
    # Warn if a specific GPU is chosen, as this disables data parallelism
    if args.gpu is not None:
        warnings.warn("Selecting a specific GPU will disable data parallelism.")

    # Adjust the loop factor based on the training configuration
    args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor

    # Decide whether decentralized (multi-process) training is needed
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized

    # Get the number of GPUs available on this node
    num_gpus = torch.cuda.device_count()
    args.ngpus_per_node = num_gpus
    print(f"INFO:PyTorch: GPUs available on this node: {num_gpus}")

    if args.multiprocessing_decentralized:
        # Scale the world size by the number of GPUs and spawn one worker process per GPU
        args.world_size *= num_gpus
        torch.multiprocessing.spawn(execute_worker_process, nprocs=num_gpus, args=(num_gpus, args))
    else:
        # Single-GPU training on GPU 0
        print("INFO:PyTorch: Using GPU 0 for single GPU training")
        args.gpu = 0
        execute_worker_process(args.gpu, num_gpus, args)
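# Per-process worker. When launched via torch.multiprocessing.spawn above, `gpu`
# is the local process index supplied by spawn; in the single-GPU path it is
# simply args.gpu (0).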
def execute_worker_process(gpu, num_gpus, args):
    global best_acc1
    args.gpu = gpu

    # Directory where model checkpoints will be saved
    args.model_dir = os.path.join(HOME, "models", "coremodel", str(args.spid))

    # Initialize the process group for decentralized training if needed
    if args.is_decentralized:
        print("INFO:PyTorch: Initializing process group for decentralized training.")
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_decentralized:
            args.rank = args.rank * num_gpus + gpu
        decentralized.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)

    # Report which GPU is used for training or evaluation
    if args.gpu is not None:
        if args.evaluate:
            print(f"INFO:PyTorch: GPU {args.gpu} in use for evaluation (Rank: {args.rank})")
        else:
            print(f"INFO:PyTorch: GPU {args.gpu} in use for training (Rank: {args.rank})")

    # Set the process title for easier identification in system process monitors
    setproctitle.setproctitle(f"{args.proc_name}centralized_rank{args.rank}")

    # TensorBoard writer for validation metrics
    val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val'))

    # Use label smoothing if enabled, otherwise standard cross-entropy loss
    criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing else nn.CrossEntropyLoss()

    # Instantiate the model
    model = coremodel.coremodel(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion)
    print(f"INFO:PyTorch: Model '{args.arch}' has {metric.get_the_number_of_params(model)} parameters")

    # If only a summary is requested, print the model and exit
    if args.is_summary:
        print(model)
        return

    # Save the model configuration and hyperparameters
    summary.save_model_to_json(args, model)
    # Convert BatchNorm layers to synchronized BatchNorm for decentralized training
    if args.is_decentralized and args.world_size > 1 and args.is_syncbn:
        print("INFO:PyTorch: Converting BatchNorm to SyncBatchNorm")
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Set up the model for GPU-based training
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        args.batch_size = int(args.batch_size / num_gpus)  # Per-GPU batch size
        args.workers = int((args.workers + num_gpus - 1) / num_gpus)  # Per-GPU data-loading workers
        # Wrap the model for decentralized (distributed) data-parallel training
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
    else:
        # Use standard DataParallel across all visible GPUs
        model = nn.DataParallel(model).cuda()

    # Create the optimizer
    optimizer = create_optimizer(args, model)

    # Gradient scaler for mixed-precision training, if enabled
    scaler = amp.GradScaler() if args.is_amp else None

    # If resuming from a checkpoint, load model and optimizer state
    if args.resume:
        load_checkpoint(args, model, optimizer, scaler)

    cudnn.benchmark = True  # Let cuDNN auto-tune the fastest convolution algorithms
    # Data loader parameters
    data_loader_params = {
        'split_factor': args.loop_factor if args.is_diff_data_train else 1,
        'batch_size': args.batch_size,
        'crop_size': args.crop_size,
        'dataset': args.dataset,
        'is_decentralized': args.is_decentralized,
        'num_workers': args.workers,
        'randaa': args.randaa,
        'is_autoaugment': args.is_autoaugment,
        'is_cutout': args.is_cutout,
        'erase_p': args.erase_p,
    }

    # Training and validation data loaders
    train_loader, train_sampler = factory.obtain_data_loader(args.data, split="train", **data_loader_params)
    val_loader = factory.obtain_data_loader(args.data, split="val", batch_size=args.eval_batch_size,
                                            crop_size=args.crop_size, num_workers=args.workers)

    # Learning rate scheduler
    scheduler = lr_scheduler.create_scheduler(args, len(train_loader))

    # If only evaluating, run validation and exit
    if args.evaluate:
        validate(val_loader, model, args)
        return

    # Begin training and evaluation
    train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer, scheduler, scaler, val_writer, args, num_gpus)
# Create the optimizer for the given model
def create_optimizer(args, model):
    param_groups = model.parameters() if args.is_wd_all else lr_scheduler.get_parameter_groups(model)
    # Select the optimizer based on input arguments
    if args.optimizer == 'SGD':
        return torch.optim.SGD(param_groups, lr=args.lr, momentum=args.momentum,
                               weight_decay=args.weight_decay, nesterov=args.is_nesterov)
    elif args.optimizer == 'AdamW':
        return torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999),
                                 eps=1e-4, weight_decay=args.weight_decay)
    elif args.optimizer == 'RMSprop':
        return torch.optim.RMSprop(param_groups, lr=args.lr, alpha=0.9, momentum=0.9,
                                   weight_decay=args.weight_decay)
    else:
        # Raise an error if an unsupported optimizer is selected
        raise NotImplementedError(f"Optimizer {args.optimizer} not implemented")
# Load a checkpoint and resume training
def load_checkpoint(args, model, optimizer, scaler):
    global best_acc1
    if os.path.isfile(args.resume):
        print(f"INFO:PyTorch: Loading checkpoint from '{args.resume}'")
        loc = f'cuda:{args.gpu}' if args.gpu is not None else None
        checkpoint = torch.load(args.resume, map_location=loc)
        args.start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Restore the AMP scaler state only when mixed precision is in use
        if "scaler" in checkpoint and scaler is not None:
            scaler.load_state_dict(checkpoint['scaler'])
        print(f"INFO:PyTorch: Checkpoint loaded, epoch {checkpoint['epoch']}")
    else:
        print(f"INFO:PyTorch: No checkpoint found at '{args.resume}'")
# Train and evaluate the model over multiple epochs
def train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer, scheduler, scaler, val_writer, args, num_gpus):
    global best_acc1
    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.is_decentralized:
            train_sampler.set_epoch(epoch)
        train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args)
        # Evaluate the model every 'eval_per_epoch' epochs
        if (epoch + 1) % args.eval_per_epoch == 0:
            acc_all = validate(val_loader, model, args)
            is_best = acc_all[0] > best_acc1  # Track whether this is the best accuracy so far
            best_acc1 = max(acc_all[0], best_acc1)
            # Save the model checkpoint
            save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best)
# Perform one training epoch
def train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args):
    metric_storage = create_metric_storage(args.loop_factor)
    model.train()  # Set the model to training mode
    data_loader = prefetch.data_prefetcher(train_loader)  # Prefetch batches to improve efficiency
    images, target = data_loader.next()
    optimizer.zero_grad()  # Reset gradients before the first accumulation window
    step = 0
    while images is not None:
        # Adjust the learning rate according to the scheduler
        scheduler(optimizer, epoch)

        # Forward pass, with mixed precision if enabled
        if args.is_amp:
            with amp.autocast():
                ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target, mode='train', epoch=epoch)
        else:
            ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target, mode='train', epoch=epoch)

        # Total loss, scaled for gradient accumulation
        total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate
        val_writer.add_scalar('average_training_loss', total_loss, global_step=epoch)

        # Backward pass; step the optimizer and clear gradients once per accumulation window
        if args.is_amp:
            scaler.scale(total_loss).backward()
            if (step + 1) % args.iters_to_accumulate == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        else:
            total_loss.backward()
            if (step + 1) % args.iters_to_accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()

        step += 1
        images, target = data_loader.next()  # Fetch the next batch
# Save the model checkpoint
def save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best):
    ckpt = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
    }
    if args.is_amp:
        ckpt['scaler'] = scaler.state_dict()
    metric.save_checkpoint(ckpt, is_best, args.model_dir, filename=f"checkpoint_{epoch}.pth.tar")
# Validate the model on the validation dataset
def validate(val_loader, model, args):
    metric_storage = create_metric_storage(args.loop_factor)
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
            # Forward pass, with mixed precision if enabled
            if args.is_amp:
                with amp.autocast():
                    ensemble_output, outputs, ce_loss = model(images, target=target, mode='val')
            else:
                ensemble_output, outputs, ce_loss = model(images, target=target, mode='val')
            batch_size = images.size(0)
            acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5))
            metric_storage.update(acc1, acc5, ce_loss, batch_size)
    return metric_storage.results()
# Create a container for accuracy metrics used during training and validation
def create_metric_storage(loop_factor):
    top1_all = [metric.AverageMeter(f'Acc@1_{i}', ':6.2f') for i in range(loop_factor)]
    avg_top1 = metric.AverageMeter('Avg_Acc@1', ':6.2f')
    return metric.ProgressMeter(len(top1_all), top1_all, avg_top1)
# Main entry point for the script
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Centralized Training')
    args = train_params.add_parser_params(parser)  # Add parameters to the argument parser
    assert args.is_fed == 0, "Centralized training requires args.is_fed to be False"
    os.makedirs(args.model_dir, exist_ok=True)  # Create the model directory if it doesn't exist
    main(args)
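# Example invocation (the flag names below are illustrative assumptions; the
# actual options are defined in params/train_params.py):
#   python run_local.py --dataset cifar10 --batch_size 128 --optimizer SGD --lr 0.1 --epochs 200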