4ec0a23e73
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
132 lines
5.5 KiB
Python
132 lines
5.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class PassThrough(nn.Module):
    """
    Identity module: ``forward`` hands back the very same tensor it receives.

    Any keyword arguments given to the constructor are accepted and discarded,
    so it can be built with the same call signature as a real layer.
    """

    def __init__(self, **kwargs):
        super().__init__()  # kwargs are intentionally ignored

    def forward(self, input_tensor):
        """Return ``input_tensor`` unchanged (the same object, no copy)."""
        output = input_tensor
        return output
|
|
|
|
|
|
class LayerNormalization2D(nn.Module):
    """
    Layer normalization over every non-batch dimension of the input.

    The input is normalized over ``input_tensor.shape[1:]`` (for a
    convolutional feature map of shape (N, C, H, W) that is C, H and W).
    The optional learnable scale (weight) and shift (bias) are created lazily
    on the first forward call, because their shape depends on the input.
    They are created on the input's device and with the input's dtype.

    After that first call, an enabled ``use_weight`` / ``use_bias`` flag is
    replaced by the corresponding ``nn.Parameter``. If only one of the two is
    enabled, the other is registered as a ``None`` parameter. If both are
    disabled, the flags are left as-is and no affine transform is applied.

    NOTE: the parameters only exist after the first forward pass. An optimizer
    constructed before that call will not see them, and their shape is fixed
    by the first input.

    Arguments:
        epsilon: A small value added to the variance to avoid division by zero.
        use_weight: Whether to learn and apply a scale (weight) parameter.
        use_bias: Whether to learn and apply a shift (bias) parameter.
        **kwargs: Accepted for call-signature compatibility and ignored.
    """

    def __init__(self, epsilon=1e-05, use_weight=True, use_bias=True, **kwargs):
        super(LayerNormalization2D, self).__init__()

        self.epsilon = epsilon
        # Boolean flags until the first forward call; afterwards they are
        # replaced by the learned parameters (see _initialize_parameters).
        self.use_weight = use_weight
        self.use_bias = use_bias

    def forward(self, input_tensor):
        """
        Normalize ``input_tensor`` over all dimensions except the first.

        Arguments:
            input_tensor: Input tensor; dimension 0 is treated as the batch.
        Returns:
            A tensor of the same shape as ``input_tensor``.
        """
        # Lazily create the parameters once the input shape is known.
        if (not isinstance(self.use_weight, nn.Parameter)
                and not isinstance(self.use_bias, nn.Parameter)
                and (self.use_weight or self.use_bias)):
            self._initialize_parameters(input_tensor)

        # A disabled flag (False) or an empty slot (None) must reach
        # layer_norm as None, because it expects a Tensor or None.
        weight = self.use_weight if isinstance(self.use_weight, torch.Tensor) else None
        bias = self.use_bias if isinstance(self.use_bias, torch.Tensor) else None

        return nn.functional.layer_norm(input_tensor, input_tensor.shape[1:],
                                        weight=weight, bias=bias,
                                        eps=self.epsilon)

    def _initialize_parameters(self, input_tensor):
        """
        Create the weight and bias parameters for layer normalization.

        The parameters have shape ``input_tensor.shape[1:]`` and are placed on
        the input's device with the input's dtype, so ``layer_norm`` accepts
        them together with the input.

        Arguments:
            input_tensor: The input tensor to the normalization layer.
        """
        param_shape = tuple(input_tensor.shape[1:])
        factory_kwargs = {'device': input_tensor.device, 'dtype': input_tensor.dtype}

        # Weight: ones, so the initial affine transform is the identity scale.
        if self.use_weight:
            self.use_weight = nn.Parameter(torch.ones(param_shape, **factory_kwargs),
                                           requires_grad=True)
        else:
            self._register_empty_parameter('use_weight')

        # Bias: zeros, so the initial affine transform adds no shift.
        if self.use_bias:
            self.use_bias = nn.Parameter(torch.zeros(param_shape, **factory_kwargs),
                                         requires_grad=True)
        else:
            self._register_empty_parameter('use_bias')

    def _register_empty_parameter(self, name):
        """Register ``name`` as a ``None`` parameter in place of its boolean flag."""
        # register_parameter raises KeyError when ``name`` is already a plain
        # attribute, which the boolean flag set in __init__ is; drop it first.
        if name in self.__dict__:
            delattr(self, name)
        self.register_parameter(name, None)
|
|
|
|
|
|
class NormalizationLayer(nn.Module):
    """
    Factory module that builds a normalization layer selected by ``norm_type``.

    ``forward`` performs no normalization itself. Given a channel/feature
    count, it constructs and returns a new normalization module on every call:
    batch, group, layer, instance, or an identity pass-through for 'none'.

    For 'batch', 'group' and 'instance', the affine transform is enabled only
    when both ``use_weight`` and ``use_bias`` are truthy. A single flag alone
    yields no affine parameters.

    Arguments:
        norm_type: One of 'batch', 'group', 'layer', 'instance', 'none' (Default: 'batch').
        epsilon: Value added to the denominator for numerical stability (Default: 1e-05).
        momentum: Running-statistics momentum, used only for 'batch' (Default: 0.1).
        use_weight: Whether a learnable scale is wanted (Default: True).
        use_bias: Whether a learnable shift is wanted (Default: True).
        track_stats: Whether to track running statistics, used only for 'batch' (Default: True).
        group_norm_groups: Number of groups for 'group' normalization (Default: 32).
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported options.
    """

    def __init__(self, norm_type='batch', epsilon=1e-05, momentum=0.1,
                 use_weight=True, use_bias=True, track_stats=True, group_norm_groups=32, **kwargs):
        super().__init__()

        supported = ('batch', 'group', 'layer', 'instance', 'none')
        if norm_type not in supported:
            raise ValueError('Unsupported norm_type: {}. Supported options: '
                             '"batch" | "group" | "layer" | "instance" | "none".'.format(norm_type))

        self.norm_type = norm_type
        self.epsilon = epsilon
        self.momentum = momentum
        self.use_weight = use_weight
        self.use_bias = use_bias
        # torch's batch/group/instance norms expose a single ``affine`` switch,
        # so both weight and bias are required to turn it on.
        self.affine = self.use_weight and self.use_bias
        self.track_stats = track_stats
        self.group_norm_groups = group_norm_groups

    def forward(self, num_features):
        """
        Build the normalization module configured by this factory.

        Arguments:
            num_features: The number of input channels or features.
        Returns:
            A freshly constructed normalization module matching ``norm_type``.
        """
        kind = self.norm_type
        eps = self.epsilon
        affine = self.affine

        if kind == 'batch':
            return nn.BatchNorm2d(num_features=num_features, eps=eps,
                                  momentum=self.momentum, affine=affine,
                                  track_running_stats=self.track_stats)
        if kind == 'group':
            return nn.GroupNorm(self.group_norm_groups, num_features,
                                eps=eps, affine=affine)
        if kind == 'layer':
            # The layer-norm variant ignores num_features and sizes itself lazily.
            return LayerNormalization2D(epsilon=eps, use_weight=self.use_weight,
                                        use_bias=self.use_bias)
        if kind == 'instance':
            return nn.InstanceNorm2d(num_features, eps=eps, affine=affine)

        # 'none': hand the input through untouched.
        return PassThrough()
|