use-case-and-architecture/EdgeFLite/helpers/optimizer_rmsprop.py
# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
from torch.optim import Optimizer


class CustomRMSprop(Optimizer):
    """
    Implements a modified version of the RMSprop algorithm with TensorFlow-style epsilon handling.

    Main differences in this implementation:
        1. Epsilon is incorporated within the square root operation.
        2. The moving average of squared gradients is initialized to 1.
        3. The momentum buffer accumulates updates scaled by the learning rate.
    """

    def __init__(self, params, lr=0.01, alpha=0.99, eps=1e-8, momentum=0,
                 weight_decay=0, centered=False, decoupled_decay=False,
                 lr_in_momentum=True):
        """
        Initializes the optimizer with the provided parameters.

        Arguments:
            - params: iterable of parameters to optimize or dicts defining parameter groups
            - lr: learning rate (default: 0.01)
            - alpha: smoothing constant for the moving average (default: 0.99)
            - eps: small value added to prevent division by zero (default: 1e-8)
            - momentum: momentum factor (default: 0)
            - weight_decay: weight decay (L2 penalty) (default: 0)
            - centered: if True, compute centered RMSprop (default: False)
            - decoupled_decay: if True, decouple weight decay from the gradient update (default: False)
            - lr_in_momentum: if True, apply the learning rate inside the momentum buffer (default: True)
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight decay: {weight_decay}")
        if alpha < 0.0:
            raise ValueError(f"Invalid alpha value: {alpha}")

        # Store the optimizer defaults
        defaults = {
            'lr': lr,
            'alpha': alpha,
            'eps': eps,
            'momentum': momentum,
            'centered': centered,
            'weight_decay': weight_decay,
            'decoupled_decay': decoupled_decay,
            'lr_in_momentum': lr_in_momentum,
        }
        super().__init__(params, defaults)

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
            - closure: a closure that reevaluates the model and returns the loss.
        """
        # Get the loss value if a closure is provided
        loss = closure() if closure is not None else None

        # Iterate over parameter groups
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            alpha = group['alpha']
            eps = group['eps']

            # Iterate over parameters in the group
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data  # Get gradient data
                if grad.is_sparse:
                    raise RuntimeError("RMSprop does not support sparse gradients.")

                # Get the state of the parameter
                state = self.state[p]

                # Initialize state if it doesn't exist
                if not state:
                    state['step'] = 0
                    # Moving average of squared gradients starts at 1 (TensorFlow-style)
                    state['square_avg'] = torch.ones_like(p.data)
                    if momentum > 0:
                        state['momentum_buffer'] = torch.zeros_like(p.data)  # Momentum buffer
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(p.data)  # Moving average of gradients (centered variant)

                square_avg = state['square_avg']
                one_minus_alpha = 1 - alpha
                state['step'] += 1  # Update the step count

                # Apply weight decay
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        p.data.mul_(1 - lr * weight_decay)  # Decoupled weight decay
                    else:
                        grad.add_(p.data, alpha=weight_decay)  # Traditional L2 weight decay

                # Update the moving average of squared gradients
                square_avg.add_((grad ** 2) - square_avg, alpha=one_minus_alpha)

                # Compute the denominator for the gradient update
                if group['centered']:
                    grad_avg = state['grad_avg']
                    grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
                    avg = (square_avg - grad_avg ** 2).add_(eps).sqrt_()  # Centered RMSprop (eps inside sqrt)
                else:
                    # Use out-of-place add() so the running square_avg buffer is not corrupted by eps
                    avg = square_avg.add(eps).sqrt_()  # Standard RMSprop (eps inside sqrt)
                # Apply momentum if needed
                if momentum > 0:
                    buf = state['momentum_buffer']
                    if group['lr_in_momentum']:
                        buf.mul_(momentum).addcdiv_(grad, avg, value=lr)  # Learning rate applied inside the momentum buffer
                        p.data.sub_(buf)
                    else:
                        buf.mul_(momentum).addcdiv_(grad, avg)  # Standard momentum update
                        p.data.add_(buf, alpha=-lr)
                else:
                    p.data.addcdiv_(grad, avg, value=-lr)  # Update parameter without momentum

        return loss  # Return the loss if a closure was provided
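

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The toy model, data,
# and hyperparameter values below are hypothetical and only illustrate how
# CustomRMSprop plugs into a standard PyTorch training loop.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    # Hypothetical two-layer regression model and random data, for illustration only
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
    inputs = torch.randn(64, 10)
    targets = torch.randn(64, 1)

    # momentum, weight_decay, and decoupled_decay are optional knobs; the values here are assumptions
    optimizer = CustomRMSprop(model.parameters(), lr=0.01, alpha=0.99,
                              momentum=0.9, weight_decay=1e-4,
                              decoupled_decay=True)
    criterion = nn.MSELoss()

    for step_idx in range(5):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        print(f"step {step_idx}: loss={loss.item():.4f}")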