# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import math


class CustomScheduler:

    def __init__(self, mode='cosine',
                 initial_lr=0.1,
                 num_epochs=100,
                 iters_per_epoch=300,
                 lr_milestones=None,
                 lr_step=100,
                 step_multiplier=0.1,
                 slow_start_epochs=0,
                 slow_start_lr=1e-4,
                 min_lr=1e-3,
                 multiplier=1.0,
                 lower_bound=-6.0,
                 upper_bound=3.0,
                 decay_factor=0.97,
                 decay_epochs=0.8,
                 staircase=True):
"""
|
|
Initialize the learning rate scheduler.
|
|
|
|
Parameters:
|
|
mode (str): Mode for learning rate adjustment ('cosine', 'poly', 'HTD', 'step', 'exponential').
|
|
initial_lr (float): Initial learning rate.
|
|
num_epochs (int): Total number of epochs.
|
|
iters_per_epoch (int): Number of iterations per epoch.
|
|
lr_milestones (list): Epoch milestones for learning rate decay in 'step' mode.
|
|
lr_step (int): Epoch step size for learning rate reduction in 'step' mode.
|
|
step_multiplier (float): Multiplication factor for learning rate reduction in 'step' mode.
|
|
slow_start_epochs (int): Number of slow start epochs for warm-up.
|
|
slow_start_lr (float): Learning rate during warm-up.
|
|
min_lr (float): Minimum learning rate limit.
|
|
multiplier (float): Multiplication factor for applying to different parameter groups.
|
|
lower_bound (float): Lower bound for the tanh function in 'HTD' mode.
|
|
upper_bound (float): Upper bound for the tanh function in 'HTD' mode.
|
|
decay_factor (float): Factor by which learning rate decays in 'exponential' mode.
|
|
decay_epochs (float): Number of epochs over which learning rate decays in 'exponential' mode.
|
|
staircase (bool): If True, apply step-wise learning rate decay in 'exponential' mode.
|
|
"""
|
|
        # Ensure valid mode selection
        assert mode in ['cosine', 'poly', 'HTD', 'step', 'exponential'], "Invalid mode."

        # Initialize learning rate settings
        self.initial_lr = initial_lr
        self.current_lr = initial_lr
        self.min_lr = min_lr
        self.mode = mode
        self.num_epochs = num_epochs
        self.iters_per_epoch = iters_per_epoch
        self.total_iterations = (num_epochs - slow_start_epochs) * iters_per_epoch
        self.slow_start_iters = slow_start_epochs * iters_per_epoch
        self.slow_start_lr = slow_start_lr
        self.multiplier = multiplier
        self.lr_step = lr_step
        self.lr_milestones = lr_milestones
        self.step_multiplier = step_multiplier
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.decay_factor = decay_factor
        self.decay_steps = decay_epochs * iters_per_epoch
        self.staircase = staircase

        print(f"INFO: Using {self.mode} learning rate scheduler with {slow_start_epochs} warm-up epochs.")

    def update_lr(self, optimizer, iteration, epoch):
        """Update the learning rate based on the current iteration and epoch."""
        current_iter = epoch * self.iters_per_epoch + iteration

        # During slow start, linearly increase the learning rate.
        # Guard against division by zero when warm-up is disabled (slow_start_epochs=0).
        if self.slow_start_iters > 0 and current_iter <= self.slow_start_iters:
            lr = self.slow_start_lr + (self.initial_lr - self.slow_start_lr) * (current_iter / self.slow_start_iters)
        else:
            # After slow start, calculate the learning rate based on the selected mode
            lr = self._calculate_lr(current_iter - self.slow_start_iters)

        # Ensure the learning rate does not fall below the minimum limit
        self.current_lr = max(lr, self.min_lr)
        self._apply_lr(optimizer, self.current_lr)

    def _calculate_lr(self, adjusted_iter):
        """Calculate the learning rate based on the selected scheduling mode."""
        if self.mode == 'cosine':
            # Cosine annealing schedule
            return 0.5 * self.initial_lr * (1 + math.cos(math.pi * adjusted_iter / self.total_iterations))
        elif self.mode == 'poly':
            # Polynomial decay schedule
            return self.initial_lr * (1 - adjusted_iter / self.total_iterations) ** 0.9
        elif self.mode == 'HTD':
            # Hyperbolic tangent decay (HTD) schedule
            ratio = adjusted_iter / self.total_iterations
            return 0.5 * self.initial_lr * (1 - math.tanh(self.lower_bound + (self.upper_bound - self.lower_bound) * ratio))
        elif self.mode == 'step':
            # Step decay schedule
            return self._step_lr(adjusted_iter)
        elif self.mode == 'exponential':
            # Exponential decay schedule
            power = math.floor(adjusted_iter / self.decay_steps) if self.staircase else adjusted_iter / self.decay_steps
            return self.initial_lr * (self.decay_factor ** power)
        else:
            raise NotImplementedError("Unknown learning rate mode.")

    def _step_lr(self, adjusted_iter):
        """Calculate the learning rate for the 'step' mode."""
        epoch = adjusted_iter // self.iters_per_epoch
        # Count how many milestones or steps have passed
        if self.lr_milestones:
            num_steps = sum(1 for milestone in self.lr_milestones if epoch >= milestone)
        else:
            num_steps = epoch // self.lr_step
        return self.initial_lr * (self.step_multiplier ** num_steps)

    def _apply_lr(self, optimizer, lr):
        """Apply the calculated learning rate to the optimizer."""
        for i, param_group in enumerate(optimizer.param_groups):
            # Apply the multiplier to every parameter group after the first one
            param_group['lr'] = lr * (self.multiplier if i > 0 else 1.0)

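
# Usage sketch (illustrative, not part of the original training pipeline): the
# stand-in optimizer below only exposes `param_groups`, which is the single
# attribute that `update_lr()` touches, so the schedule can be exercised without
# any deep-learning framework. With a real torch.optim optimizer the calls are
# identical: invoke `update_lr(optimizer, iteration, epoch)` once per training
# iteration, before `optimizer.step()`.
class _StandInOptimizer:
    """Minimal optimizer stand-in holding a single parameter group."""

    def __init__(self, lr):
        self.param_groups = [{'lr': lr}]


def _example_scheduler_usage():
    """Run a cosine schedule with one warm-up epoch and return the final LR."""
    optimizer = _StandInOptimizer(lr=0.1)
    scheduler = CustomScheduler(mode='cosine', initial_lr=0.1,
                                num_epochs=10, iters_per_epoch=100,
                                slow_start_epochs=1, slow_start_lr=1e-4)
    for epoch in range(10):
        for iteration in range(100):
            # Update the LR before each (hypothetical) optimizer step.
            scheduler.update_lr(optimizer, iteration, epoch)
    # By the last iteration the cosine term has decayed almost to zero, so the
    # learning rate is clamped at min_lr (1e-3 by default).
    return optimizer.param_groups[0]['lr']
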
def adjust_hyperparameters(args):
    """Adjust the learning rate and momentum based on the batch size."""
    print(f'Adjusting LR and momentum. Original LR: {args.lr}, Original momentum: {args.momentum}')
    # Reference batch size used for scaling (only CIFAR datasets are supported here)
    if 'cifar' in args.dataset:
        standard_batch_size = 128
    else:
        raise NotImplementedError(f"No reference batch size defined for dataset '{args.dataset}'.")
    # Scale momentum and learning rate with the batch-size ratio
    args.momentum = args.momentum ** (args.batch_size / standard_batch_size)
    args.lr *= (args.batch_size / standard_batch_size)
    print(f'Adjusted LR: {args.lr}, Adjusted momentum: {args.momentum}')
    return args

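
# Example (illustrative; the argparse.Namespace fields below simply mirror what
# adjust_hyperparameters() reads, not a confirmed CLI of this project): doubling
# the CIFAR batch size from the reference 128 to 256 doubles the learning rate
# (0.1 -> 0.2) and squares the momentum (0.9 -> 0.81).
def _example_adjust_hyperparameters():
    from argparse import Namespace
    args = Namespace(dataset='cifar100', batch_size=256, lr=0.1, momentum=0.9)
    return adjust_hyperparameters(args)
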
def separate_parameters(model, weight_decay_for_norm=0):
    """Separate the model parameters into two groups: regular parameters and normalization/bias parameters."""
    regular_params, norm_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            # Parameters belonging to normalization layers and biases are treated separately
            if 'norm' in name or 'bias' in name:
                norm_params.append(param)
            else:
                regular_params.append(param)
    # Return parameter groups with the given weight decay attached to the norm/bias group
    return [{'params': regular_params}, {'params': norm_params, 'weight_decay': weight_decay_for_norm}]
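

# Usage sketch (assumes a PyTorch model; torch.optim.SGD accepts a list of
# parameter-group dicts, and a per-group 'weight_decay' overrides the optimizer
# default): apply weight decay only to the regular weights while leaving
# normalization and bias parameters undecayed.
def _example_build_optimizer(model):
    import torch
    param_groups = separate_parameters(model, weight_decay_for_norm=0.0)
    # The first group inherits weight_decay=5e-4; the second keeps its explicit 0.0.
    return torch.optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=5e-4)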