# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
from torch.optim import Optimizer


class CustomRMSprop(Optimizer):
    """
    Implements a modified version of the RMSprop algorithm with TensorFlow-style epsilon handling.

    Main differences in this implementation:
    1. Epsilon is incorporated within the square root operation.
    2. The moving average of squared gradients is initialized to 1.
    3. The momentum buffer accumulates updates scaled by the learning rate.
    """
    def __init__(self, params, lr=0.01, alpha=0.99, eps=1e-8, momentum=0,
                 weight_decay=0, centered=False, decoupled_decay=False,
                 lr_in_momentum=True):
        """
        Initializes the optimizer with the provided parameters.

        Arguments:
        - params: iterable of parameters to optimize or dicts defining parameter groups
        - lr: learning rate (default: 0.01)
        - alpha: smoothing constant for the moving average (default: 0.99)
        - eps: small value to prevent division by zero (default: 1e-8)
        - momentum: momentum factor (default: 0)
        - weight_decay: weight decay (L2 penalty) (default: 0)
        - centered: if True, compute centered RMSprop (default: False)
        - decoupled_decay: if True, decouples weight decay from the gradient update (default: False)
        - lr_in_momentum: if True, applies the learning rate within the momentum buffer (default: True)
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight decay: {weight_decay}")
        if alpha < 0.0:
            raise ValueError(f"Invalid alpha value: {alpha}")

        # Store the optimizer defaults
        defaults = {
            'lr': lr,
            'alpha': alpha,
            'eps': eps,
            'momentum': momentum,
            'centered': centered,
            'weight_decay': weight_decay,
            'decoupled_decay': decoupled_decay,
            'lr_in_momentum': lr_in_momentum,
        }
        super().__init__(params, defaults)

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
        - closure: A closure that reevaluates the model and returns the loss.
        """
        # Get the loss value if a closure is provided
        loss = closure() if closure is not None else None

        # Iterate over parameter groups
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            alpha = group['alpha']
            eps = group['eps']

            # Iterate over parameters in the group
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data  # Get gradient data
                if grad.is_sparse:
                    raise RuntimeError("RMSprop does not support sparse gradients.")

                # Get the state of the parameter
                state = self.state[p]

                # Initialize state if it doesn't exist
                if not state:
                    state['step'] = 0
                    state['square_avg'] = torch.ones_like(p.data)  # Initialize moving average of squared gradients to 1
                    if momentum > 0:
                        state['momentum_buffer'] = torch.zeros_like(p.data)  # Initialize momentum buffer
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(p.data)  # Initialize moving average of gradients if centered

                square_avg = state['square_avg']
                one_minus_alpha = 1 - alpha
                state['step'] += 1  # Update the step count

                # Apply weight decay
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        p.data.mul_(1 - lr * weight_decay)  # Apply decoupled weight decay
                    else:
                        grad.add_(p.data, alpha=weight_decay)  # Apply traditional weight decay

                # Update the moving average of squared gradients
                square_avg.add_((grad ** 2) - square_avg, alpha=one_minus_alpha)

                # Compute the denominator for gradient update
                if group['centered']:
                    grad_avg = state['grad_avg']
                    grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
                    avg = (square_avg - grad_avg ** 2).add_(eps).sqrt_()  # Centered RMSprop (eps inside the sqrt)
                else:
                    avg = square_avg.add(eps).sqrt_()  # Standard RMSprop; out-of-place add so eps does not corrupt the running state

                # Apply momentum if needed
                if momentum > 0:
                    buf = state['momentum_buffer']
                    if group['lr_in_momentum']:
                        buf.mul_(momentum).addcdiv_(grad, avg, value=lr)  # Apply learning rate inside momentum buffer
                        p.data.add_(-buf)
                    else:
                        buf.mul_(momentum).addcdiv_(grad, avg)  # Standard momentum update
                        p.data.add_(buf, alpha=-lr)
                else:
                    p.data.addcdiv_(grad, avg, value=-lr)  # Update parameter without momentum

        return loss  # Return the loss if closure was provided
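

# Minimal usage sketch (illustrative only): the linear model and random data below are
# hypothetical and not part of the original module.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 1)
    optimizer = CustomRMSprop(model.parameters(), lr=0.01, momentum=0.9, centered=True)

    inputs, targets = torch.randn(8, 10), torch.randn(8, 1)
    loss = (model(inputs) - targets).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"loss: {loss.item():.4f}")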