"""
|
|
Author: Weisen Pan
|
|
Date: 2023-10-24
|
|
"""
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
|
|
class WarmUpLR(_LRScheduler):
    """Linear warm-up: scale the learning rate from 0 up to the base LR
    over the first `total_iters` iterations."""

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Fraction of warm-up completed; the 1e-8 term guards against division by zero.
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]


class DownLR(_LRScheduler):
    """Linear decay: scale the learning rate from the base LR down to 0
    over `total_iters` iterations."""

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Fraction of iterations remaining; the 1e-8 term guards against division by zero.
        return [base_lr * (self.total_iters - self.last_epoch) / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
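

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the model, learning rate, and
    # iteration count below are hypothetical placeholders, not values taken
    # from the original training code). WarmUpLR is typically stepped once per
    # batch during the warm-up phase, after which a main scheduler takes over.
    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    iters_per_epoch = 100  # assumed number of batches in the warm-up epoch
    warmup_scheduler = WarmUpLR(optimizer, total_iters=iters_per_epoch)

    for _ in range(iters_per_epoch):
        optimizer.step()          # parameter update for the batch (no-op here)
        warmup_scheduler.step()   # LR ramps linearly from ~0 toward 0.1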