use-case-and-architecture/EdgeFLite/helpers/pace_controller.py

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import math


class CustomScheduler:
    def __init__(self, mode='cosine',
                 initial_lr=0.1,
                 num_epochs=100,
                 iters_per_epoch=300,
                 lr_milestones=None,
                 lr_step=100,
                 step_multiplier=0.1,
                 slow_start_epochs=0,
                 slow_start_lr=1e-4,
                 min_lr=1e-3,
                 multiplier=1.0,
                 lower_bound=-6.0,
                 upper_bound=3.0,
                 decay_factor=0.97,
                 decay_epochs=0.8,
                 staircase=True):
        """
        Initialize the learning rate scheduler.

        Parameters:
            mode (str): Mode for learning rate adjustment ('cosine', 'poly', 'HTD', 'step', 'exponential').
            initial_lr (float): Initial learning rate.
            num_epochs (int): Total number of epochs.
            iters_per_epoch (int): Number of iterations per epoch.
            lr_milestones (list): Epoch milestones for learning rate decay in 'step' mode.
            lr_step (int): Epoch step size for learning rate reduction in 'step' mode.
            step_multiplier (float): Multiplication factor for learning rate reduction in 'step' mode.
            slow_start_epochs (int): Number of slow start epochs for warm-up.
            slow_start_lr (float): Learning rate during warm-up.
            min_lr (float): Minimum learning rate limit.
            multiplier (float): Multiplication factor applied to parameter groups beyond the first one.
            lower_bound (float): Lower bound for the tanh function in 'HTD' mode.
            upper_bound (float): Upper bound for the tanh function in 'HTD' mode.
            decay_factor (float): Factor by which the learning rate decays in 'exponential' mode.
            decay_epochs (float): Number of epochs over which the learning rate decays in 'exponential' mode.
            staircase (bool): If True, apply step-wise learning rate decay in 'exponential' mode.
        """
        # Ensure valid mode selection
        assert mode in ['cosine', 'poly', 'HTD', 'step', 'exponential'], "Invalid mode."

        # Initialize learning rate settings
        self.initial_lr = initial_lr
        self.current_lr = initial_lr
        self.min_lr = min_lr
        self.mode = mode
        self.num_epochs = num_epochs
        self.iters_per_epoch = iters_per_epoch
        self.total_iterations = (num_epochs - slow_start_epochs) * iters_per_epoch
        self.slow_start_iters = slow_start_epochs * iters_per_epoch
        self.slow_start_lr = slow_start_lr
        self.multiplier = multiplier
        self.lr_step = lr_step
        self.lr_milestones = lr_milestones
        self.step_multiplier = step_multiplier
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.decay_factor = decay_factor
        self.decay_steps = decay_epochs * iters_per_epoch
        self.staircase = staircase
        print(f"INFO: Using {self.mode} learning rate scheduler with {slow_start_epochs} warm-up epochs.")

    def update_lr(self, optimizer, iteration, epoch):
        """Update the learning rate based on the current iteration and epoch."""
        current_iter = epoch * self.iters_per_epoch + iteration
        # During slow start, linearly increase the learning rate
        # (the slow_start_iters > 0 guard avoids a division by zero when no warm-up is configured)
        if self.slow_start_iters > 0 and current_iter <= self.slow_start_iters:
            lr = self.slow_start_lr + (self.initial_lr - self.slow_start_lr) * (current_iter / self.slow_start_iters)
        else:
            # After slow start, calculate the learning rate based on the selected mode
            lr = self._calculate_lr(current_iter - self.slow_start_iters)
        # Ensure the learning rate does not fall below the minimum limit
        self.current_lr = max(lr, self.min_lr)
        self._apply_lr(optimizer, self.current_lr)

    def _calculate_lr(self, adjusted_iter):
        """Calculate the learning rate based on the selected scheduling mode."""
        if self.mode == 'cosine':
            # Cosine annealing schedule
            return 0.5 * self.initial_lr * (1 + math.cos(math.pi * adjusted_iter / self.total_iterations))
        elif self.mode == 'poly':
            # Polynomial decay schedule
            return self.initial_lr * (1 - adjusted_iter / self.total_iterations) ** 0.9
        elif self.mode == 'HTD':
            # Hyperbolic tangent decay schedule
            ratio = adjusted_iter / self.total_iterations
            return 0.5 * self.initial_lr * (1 - math.tanh(self.lower_bound + (self.upper_bound - self.lower_bound) * ratio))
        elif self.mode == 'step':
            # Step decay schedule
            return self._step_lr(adjusted_iter)
        elif self.mode == 'exponential':
            # Exponential decay schedule
            power = math.floor(adjusted_iter / self.decay_steps) if self.staircase else adjusted_iter / self.decay_steps
            return self.initial_lr * (self.decay_factor ** power)
        else:
            raise NotImplementedError("Unknown learning rate mode.")

    def _step_lr(self, adjusted_iter):
        """Calculate the learning rate for the 'step' mode."""
        epoch = adjusted_iter // self.iters_per_epoch
        # Count how many milestones or steps have passed
        if self.lr_milestones:
            num_steps = sum(1 for milestone in self.lr_milestones if epoch >= milestone)
        else:
            num_steps = epoch // self.lr_step
        return self.initial_lr * (self.step_multiplier ** num_steps)
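
    # Example (illustrative): with lr_milestones=[30, 60, 90] and step_multiplier=0.1,
    # the learning rate is divided by 10 at post-warm-up epochs 30, 60, and 90;
    # without milestones it instead drops every `lr_step` epochs.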

    def _apply_lr(self, optimizer, lr):
        """Apply the calculated learning rate to the optimizer."""
        for i, param_group in enumerate(optimizer.param_groups):
            # Apply the multiplier to parameter groups beyond the first one
            param_group['lr'] = lr * (self.multiplier if i > 0 else 1.0)


def adjust_hyperparameters(args):
    """Adjust the learning rate and momentum based on the batch size."""
    print(f'Adjusting LR and momentum. Original LR: {args.lr}, Original momentum: {args.momentum}')
    # Reference batch size for scaling (only defined for CIFAR datasets here)
    if 'cifar' in args.dataset:
        standard_batch_size = 128
    else:
        raise NotImplementedError(f"No reference batch size defined for dataset: {args.dataset}")
    # Scale the learning rate linearly and the momentum exponentially with the batch-size ratio
    args.momentum = args.momentum ** (args.batch_size / standard_batch_size)
    args.lr *= (args.batch_size / standard_batch_size)
    print(f'Adjusted LR: {args.lr}, Adjusted momentum: {args.momentum}')
    return args
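
# Example (illustrative sketch, values are hypothetical): with
# argparse.Namespace(dataset='cifar10', batch_size=256, lr=0.1, momentum=0.9),
# the scaling above gives lr = 0.1 * (256 / 128) = 0.2 and
# momentum = 0.9 ** (256 / 128) = 0.81.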


def separate_parameters(model, weight_decay_for_norm=0):
    """Separate the model parameters into two groups: regular parameters and norm/bias parameters."""
    regular_params, norm_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            # Parameters related to normalization and biases are treated separately
            if 'norm' in name or 'bias' in name:
                norm_params.append(param)
            else:
                regular_params.append(param)
    # Return parameter groups with the corresponding weight decay for norm/bias parameters
    return [{'params': regular_params}, {'params': norm_params, 'weight_decay': weight_decay_for_norm}]
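

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original EdgeFLite
# pipeline): assumes PyTorch is installed. The toy model, epoch counts, and
# hyperparameters below are arbitrary placeholder values chosen only to show
# how separate_parameters() and CustomScheduler fit into a training loop.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from collections import OrderedDict

    from torch import nn, optim

    # Toy model whose normalization layer is named so that 'norm' appears in
    # its parameter names and separate_parameters() can pick it out.
    model = nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(32, 64)),
        ('norm1', nn.BatchNorm1d(64)),
        ('act', nn.ReLU()),
        ('fc2', nn.Linear(64, 10)),
    ]))

    # No weight decay on norm/bias parameters; 5e-4 (assumed) on the rest.
    param_groups = separate_parameters(model, weight_decay_for_norm=0.0)
    optimizer = optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=5e-4)

    scheduler = CustomScheduler(mode='cosine',
                                initial_lr=0.1,
                                num_epochs=3,
                                iters_per_epoch=5,
                                slow_start_epochs=1,
                                slow_start_lr=1e-4,
                                min_lr=1e-5)

    for epoch in range(3):
        for iteration in range(5):
            # In a real run, the forward/backward pass and optimizer.step()
            # would follow this call.
            scheduler.update_lr(optimizer, iteration, epoch)
        print(f"epoch {epoch}: lr = {scheduler.current_lr:.6f}")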