# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import argparse
import warnings

import torch
import setproctitle
from torch import nn
from torch import distributed as decentralized  # torch.distributed, used for decentralized training
from torch.backends import cudnn  # Optimizes performance for convolutional networks
from tensorboardX import SummaryWriter  # For logging metrics and results to TensorBoard
import torch.cuda.amp as amp  # For mixed precision training

from config import *  # Custom configuration module
from params import train_params  # Training parameters
from params.train_params import save_hp_to_json  # Function to save hyperparameters to JSON
from utils import label_smoothing, norm, summary, metric, lr_scheduler, prefetch  # Utility functions
from model import coremodel  # Core model implementation
from dataset import factory  # Dataset and data loader factory

# Global variable to store the best accuracy obtained during training
best_acc1 = 0


def main(args):
    # Warn if a specific GPU is chosen, as this disables data parallelism
    if args.gpu is not None:
        warnings.warn("Selecting a specific GPU will disable data parallelism.")

    # Adjust the loop factor based on the training configuration
    args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor
    # Check whether decentralized training is needed
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized

    # Get the number of GPUs available on this node
    num_gpus = torch.cuda.device_count()
    args.ngpus_per_node = num_gpus
    print(f"INFO:PyTorch: GPUs available on this node: {num_gpus}")

    # If multiprocessing is needed for decentralized training
    if args.multiprocessing_decentralized:
        # Adjust the world size to account for one process per GPU
        args.world_size *= num_gpus
        # Spawn one worker process per GPU
        torch.multiprocessing.spawn(execute_worker_process, nprocs=num_gpus, args=(num_gpus, args))
    else:
        # If using a single GPU
        print("INFO:PyTorch: Using GPU 0 for single GPU training")
        args.gpu = 0
        # Call the main worker for single-GPU training
        execute_worker_process(args.gpu, num_gpus, args)


def execute_worker_process(gpu, num_gpus, args):
    global best_acc1
    args.gpu = gpu
    # Set the directory where models will be saved
    args.model_dir = os.path.join(HOME, "models", "coremodel", str(args.spid))

    # Initialize the decentralized training process group if needed
    if args.is_decentralized:
        print("INFO:PyTorch: Initializing process group for decentralized training.")
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_decentralized:
            args.rank = args.rank * num_gpus + gpu
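            # With one worker process per GPU, the global rank is the node rank
            # scaled by the number of GPUs per node plus the local GPU index,
            # so every process in the job receives a unique rank.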
        decentralized.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)

    # Report which GPU is used for training or evaluation
    if args.gpu is not None:
        mode = "evaluation" if args.evaluate else "training"
        print(f"INFO:PyTorch: GPU {args.gpu} in use for {mode} (Rank: {args.rank})")

    # Set the process title for easier identification in system process monitors
    setproctitle.setproctitle(f"{args.proc_name}centralized_rank{args.rank}")

    # Initialize a SummaryWriter for TensorBoard logging
    val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val'))
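    # The logged scalars can be inspected with the standard TensorBoard CLI,
    # e.g. `tensorboard --logdir <model_dir>/val`.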

    # Use label smoothing if enabled, otherwise use standard cross-entropy loss
    criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing else nn.CrossEntropyLoss()

    # Instantiate the model
    model = coremodel.coremodel(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion)
    print(f"INFO:PyTorch: Model '{args.arch}' has {metric.get_the_number_of_params(model)} parameters")

    # If only a summary is requested, print the model and exit
    if args.is_summary:
        print(model)
        return

    # Save the model configuration and hyperparameters
    summary.save_model_to_json(args, model)

    # Convert BatchNorm layers to synchronized BatchNorm for decentralized training
    if args.is_decentralized and args.world_size > 1 and args.is_syncbn:
        print("INFO:PyTorch: Converting BatchNorm to SyncBatchNorm")
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
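        # SyncBatchNorm computes batch-norm statistics across all processes
        # instead of per GPU, which helps when the per-GPU batch is small.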

    # Set up the model for GPU-based training
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        args.batch_size = int(args.batch_size / num_gpus)  # Adjust batch size for multiple GPUs
        args.workers = int((args.workers + num_gpus - 1) / num_gpus)  # Adjust number of data-loading workers
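        # With one process per GPU, the adjusted batch_size and workers above are
        # per-process values; the total batch processed per optimizer step is the
        # per-process batch size multiplied by the world size.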
        # Use decentralized (distributed) data-parallel training
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
    else:
        # Use standard DataParallel for multi-GPU training
        model = nn.DataParallel(model).cuda()

    # Create the optimizer
    optimizer = create_optimizer(args, model)
    # Set up the gradient scaler for mixed precision training, if enabled
    scaler = amp.GradScaler() if args.is_amp else None
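    # With AMP, the scaler multiplies the loss before backward() so small gradients
    # are not flushed to zero in float16, unscales them inside scaler.step(), and
    # adjusts the scale factor in scaler.update(); train_one_epoch below follows
    # this pattern.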

    # If resuming from a checkpoint, load the model and optimizer state
    if args.resume:
        load_checkpoint(args, model, optimizer, scaler)

    cudnn.benchmark = True  # Let cuDNN auto-tune convolution algorithms for fixed-size inputs

    # Set up data loader parameters
    data_loader_params = {
        'split_factor': args.loop_factor if args.is_diff_data_train else 1,
        'batch_size': args.batch_size,
        'crop_size': args.crop_size,
        'dataset': args.dataset,
        'is_decentralized': args.is_decentralized,
        'num_workers': args.workers,
        'randaa': args.randaa,
        'is_autoaugment': args.is_autoaugment,
        'is_cutout': args.is_cutout,
        'erase_p': args.erase_p,
    }

    # Get the training and validation data loaders
    train_loader, train_sampler = factory.obtain_data_loader(args.data, split="train", **data_loader_params)
    val_loader = factory.obtain_data_loader(args.data, split="val", batch_size=args.eval_batch_size, crop_size=args.crop_size, num_workers=args.workers)

    # Set up the learning rate scheduler
    scheduler = lr_scheduler.create_scheduler(args, len(train_loader))

    # If only evaluating, run validation and exit
    if args.evaluate:
        validate(val_loader, model, args)
        return

    # Begin training and evaluation
    train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer, scheduler, scaler, val_writer, args, num_gpus)


# Function to create the optimizer
def create_optimizer(args, model):
    param_groups = model.parameters() if args.is_wd_all else lr_scheduler.get_parameter_groups(model)
    # Select the optimizer based on the input arguments
    if args.optimizer == 'SGD':
        return torch.optim.SGD(param_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.is_nesterov)
    elif args.optimizer == 'AdamW':
        return torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-4, weight_decay=args.weight_decay)
    elif args.optimizer == 'RMSprop':
        return torch.optim.RMSprop(param_groups, lr=args.lr, alpha=0.9, momentum=0.9, weight_decay=args.weight_decay)
    else:
        # Raise an error if an unsupported optimizer is selected
        raise NotImplementedError(f"Optimizer {args.optimizer} not implemented")


# Function to load a checkpoint and resume training
def load_checkpoint(args, model, optimizer, scaler):
    if os.path.isfile(args.resume):
        print(f"INFO:PyTorch: Loading checkpoint from '{args.resume}'")
        loc = f'cuda:{args.gpu}' if args.gpu is not None else None
        checkpoint = torch.load(args.resume, map_location=loc)
        args.start_epoch = checkpoint['epoch']
        global best_acc1
        best_acc1 = checkpoint['best_acc1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Restore the gradient scaler state only when AMP is in use
        if scaler is not None and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
        print(f"INFO:PyTorch: Checkpoint loaded, epoch {checkpoint['epoch']}")
    else:
        print(f"INFO:PyTorch: No checkpoint found at '{args.resume}'")


# Function to train and evaluate the model over multiple epochs
def train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer, scheduler, scaler, val_writer, args, num_gpus):
    global best_acc1
    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.is_decentralized:
            train_sampler.set_epoch(epoch)

        train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args)

        # Evaluate the model every 'eval_per_epoch' epochs
        if (epoch + 1) % args.eval_per_epoch == 0:
            acc_all = validate(val_loader, model, args)
            is_best = acc_all[0] > best_acc1  # Track whether this is the best accuracy so far
            best_acc1 = max(acc_all[0], best_acc1)
            # Save the model checkpoint
            save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best)


# Function to perform one training epoch
def train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args):
    metric_storage = create_metric_storage(args.loop_factor)
    model.train()  # Set the model to training mode
    data_loader = prefetch.data_prefetcher(train_loader)  # Use data prefetching to improve efficiency
    images, target = data_loader.next()
    batch_idx = 0

    optimizer.zero_grad()  # Reset gradients before the first accumulation window
    while images is not None:
        # Adjust the learning rate via the scheduler
        scheduler(optimizer, epoch)

        # Forward pass, with mixed precision if enabled
        if args.is_amp:
            with amp.autocast():
                ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target, mode='train', epoch=epoch)
        else:
            ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target, mode='train', epoch=epoch)

        # Combine the losses and normalize by the number of accumulation steps
        total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate
        val_writer.add_scalar('average_training_loss', total_loss.item(), global_step=epoch)

        # Backward pass, with gradient scaling if mixed precision is enabled
        if args.is_amp:
            scaler.scale(total_loss).backward()
        else:
            total_loss.backward()

        # Step the optimizer and reset gradients once per accumulation window
        if (batch_idx + 1) % args.iters_to_accumulate == 0:
            if args.is_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        batch_idx += 1
        images, target = data_loader.next()  # Fetch the next batch of data


# Function to save the model checkpoint
def save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best):
    ckpt = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
    }
    if args.is_amp:
        ckpt['scaler'] = scaler.state_dict()
    metric.save_checkpoint(ckpt, is_best, args.model_dir, filename=f"checkpoint_{epoch}.pth.tar")


# Function to validate the model on the validation dataset
def validate(val_loader, model, args):
    metric_storage = create_metric_storage(args.loop_factor)
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass, with mixed precision if enabled
            if args.is_amp:
                with amp.autocast():
                    ensemble_output, outputs, ce_loss = model(images, target=target, mode='val')
            else:
                ensemble_output, outputs, ce_loss = model(images, target=target, mode='val')

            batch_size = images.size(0)
            acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5))

            metric_storage.update(acc1, acc5, ce_loss, batch_size)

    return metric_storage.results()


# Helper function to create storage for metrics during training and validation
def create_metric_storage(loop_factor):
    # Initialize top-1 accuracy meters (one per loop_factor entry) plus an average meter
    top1_all = [metric.AverageMeter(f'Acc@1_{i}', ':6.2f') for i in range(loop_factor)]
    avg_top1 = metric.AverageMeter('Avg_Acc@1', ':6.2f')
    return metric.ProgressMeter(len(top1_all), top1_all, avg_top1)


# Main entry point for the script
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Centralized Training')
    args = train_params.add_parser_params(parser)  # Register training parameters and parse the command line
    assert args.is_fed == 0, "Centralized training requires args.is_fed to be False"
    os.makedirs(args.model_dir, exist_ok=True)  # Create the model directory if it doesn't exist
    main(args)
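
# Launch note (illustrative): the script is run directly, e.g.
#   python <this_file>.py <flags defined in params/train_params.py>
# When multiprocessing for decentralized training is enabled, main() spawns one
# worker per visible GPU via torch.multiprocessing.spawn; otherwise a single
# worker runs on GPU 0. Exact flag names are defined in params/train_params.py.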