use-case-and-architecture/EdgeFLite/helpers/normalization.py
# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import torch.nn as nn


class PassThrough(nn.Module):
    """
    A placeholder module that simply returns the input tensor unchanged.
    """

    def __init__(self, **kwargs):
        super(PassThrough, self).__init__()

    def forward(self, input_tensor):
        return input_tensor
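
# Example usage (a minimal sketch, not part of the original module): PassThrough can be
# dropped in wherever a normalization module is expected but no normalization is wanted.
# The names `identity` and `x` below are illustrative only.
#
#   identity = PassThrough()
#   x = torch.randn(2, 3, 8, 8)
#   assert torch.equal(identity(x), x)  # the input comes back unchanged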


class LayerNormalization2D(nn.Module):
    """
    A custom layer normalization module for 2D inputs (typically used for
    convolutional layers). It optionally applies learned scaling (weight)
    and shifting (bias) parameters.
    Arguments:
        epsilon: A small value to avoid division by zero.
        use_weight: Whether to learn and apply weight parameters.
        use_bias: Whether to learn and apply bias parameters.
    """

    def __init__(self, epsilon=1e-05, use_weight=True, use_bias=True, **kwargs):
        super(LayerNormalization2D, self).__init__()
        self.epsilon = epsilon
        self.use_weight = use_weight
        self.use_bias = use_bias

    def forward(self, input_tensor):
        # Lazily create the weight and bias parameters on the first forward pass,
        # once the input shape (C, H, W) is known. Without this, the boolean flags
        # would be passed to layer_norm directly when both are disabled.
        if (not isinstance(self.use_weight, nn.parameter.Parameter) and
                not isinstance(self.use_bias, nn.parameter.Parameter)):
            self._initialize_parameters(input_tensor)
        # Apply layer normalization over the (C, H, W) dimensions
        return nn.functional.layer_norm(input_tensor, input_tensor.shape[1:],
                                        weight=self.use_weight, bias=self.use_bias,
                                        eps=self.epsilon)

    def _initialize_parameters(self, input_tensor):
        """
        Initialize weight and bias parameters for layer normalization.
        Arguments:
            input_tensor: The input tensor to the normalization layer.
        """
        channels, height, width = input_tensor.shape[1:]
        param_shape = [channels, height, width]
        # Initialize the weight (scale) parameter if applicable, otherwise disable it.
        # Assigning None (rather than register_parameter on an attribute that already
        # exists as a plain bool, which raises a KeyError) keeps layer_norm happy.
        if self.use_weight:
            self.use_weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        else:
            self.use_weight = None
        # Initialize the bias (shift) parameter if applicable, otherwise disable it
        if self.use_bias:
            self.use_bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
        else:
            self.use_bias = None
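
# Example usage (a minimal sketch, assuming NCHW input): the weight/bias parameters are
# created lazily on the first forward pass and shaped to match the (C, H, W) of that
# input, so later inputs are expected to share the same spatial size. The shapes below
# are illustrative assumptions.
#
#   layer_norm_2d = LayerNormalization2D(epsilon=1e-5, use_weight=True, use_bias=True)
#   x = torch.randn(4, 16, 32, 32)
#   y = layer_norm_2d(x)  # first call creates weight/bias of shape [16, 32, 32]
#   assert y.shape == x.shape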


class NormalizationLayer(nn.Module):
    """
    A flexible normalization layer that supports different types of normalization
    (batch, group, layer, instance, or none). This class is a wrapper that selects
    the appropriate normalization technique based on the norm_type argument.
    Arguments:
        norm_type: The type of normalization to apply ('batch', 'group', 'layer', 'instance', or 'none').
        epsilon: A small value to avoid division by zero (Default: 1e-05).
        momentum: Momentum for updating running statistics (Default: 0.1, applicable for batch norm).
        use_weight: Whether to learn weight parameters (Default: True).
        use_bias: Whether to learn bias parameters (Default: True).
        track_stats: Whether to track running statistics (Default: True, applicable for batch norm).
        group_norm_groups: Number of groups to use for group normalization (Default: 32).
    """

    def __init__(self, norm_type='batch', epsilon=1e-05, momentum=0.1,
                 use_weight=True, use_bias=True, track_stats=True, group_norm_groups=32, **kwargs):
        super(NormalizationLayer, self).__init__()
        if norm_type not in ['batch', 'group', 'layer', 'instance', 'none']:
            raise ValueError('Unsupported norm_type: {}. Supported options: '
                             '"batch" | "group" | "layer" | "instance" | "none".'.format(norm_type))
        self.norm_type = norm_type
        self.epsilon = epsilon
        self.momentum = momentum
        self.use_weight = use_weight
        self.use_bias = use_bias
        self.affine = self.use_weight and self.use_bias  # Whether an affine transformation is applied
        self.track_stats = track_stats
        self.group_norm_groups = group_norm_groups

    def forward(self, num_features):
        """
        Construct and return the normalization module that matches norm_type.
        Note that this forward acts as a factory: it returns a configured
        normalization layer rather than normalized activations.
        Arguments:
            num_features: The number of input channels or features.
        Returns:
            A normalization layer corresponding to the norm_type.
        """
        if self.norm_type == 'batch':
            # Apply Batch Normalization
            normalizer = nn.BatchNorm2d(num_features=num_features, eps=self.epsilon,
                                        momentum=self.momentum, affine=self.affine,
                                        track_running_stats=self.track_stats)
        elif self.norm_type == 'group':
            # Apply Group Normalization
            normalizer = nn.GroupNorm(self.group_norm_groups, num_features,
                                      eps=self.epsilon, affine=self.affine)
        elif self.norm_type == 'layer':
            # Apply Layer Normalization
            normalizer = LayerNormalization2D(epsilon=self.epsilon, use_weight=self.use_weight, use_bias=self.use_bias)
        elif self.norm_type == 'instance':
            # Apply Instance Normalization
            normalizer = nn.InstanceNorm2d(num_features, eps=self.epsilon, affine=self.affine)
        else:
            # No normalization applied, just pass the input through
            normalizer = PassThrough()
        return normalizer
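

# A minimal, hedged self-check sketch (not part of the original module). It exercises the
# factory-style forward() of NormalizationLayer, which returns a configured normalization
# module for a given channel count; the channel count and group size used here are
# illustrative assumptions.
if __name__ == "__main__":
    x = torch.randn(4, 32, 16, 16)
    for kind in ['batch', 'group', 'layer', 'instance', 'none']:
        builder = NormalizationLayer(norm_type=kind, group_norm_groups=8)
        norm = builder(32)  # e.g. nn.BatchNorm2d(32) when kind == 'batch'
        out = norm(x)
        print(kind, type(norm).__name__, tuple(out.shape))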