Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import torch.nn as nn
import torch.nn.functional as F


# Define the SmoothEntropyLoss class, which inherits from nn.Module
class SmoothEntropyLoss(nn.Module):
    def __init__(self, smoothing=0.1, reduction='mean'):
        # Initialize the parent class (nn.Module) and set the smoothing factor and reduction method
        super(SmoothEntropyLoss, self).__init__()
        self.smoothing = smoothing  # Label smoothing factor
        self.reduction_method = reduction  # Reduction method to apply to the loss
    def forward(self, predictions, targets):
        # Ensure that the batch sizes of predictions and targets match
        if predictions.shape[0] != targets.shape[0]:
            raise ValueError(f"Batch size of predictions ({predictions.shape[0]}) does not match targets ({targets.shape[0]}).")
        # Ensure that the predictions tensor has at least 2 dimensions (batch_size x num_classes)
        if predictions.dim() < 2:
            raise ValueError(f"Predictions should have at least 2 dimensions, got {predictions.dim()}.")
        # Get the number of classes from the last dimension of predictions (num_classes)
        num_classes = predictions.size(-1)
        # Convert targets (class indices) to one-hot encoded format
        target_one_hot = F.one_hot(targets, num_classes=num_classes).type_as(predictions)
        # Apply label smoothing: smooth the one-hot encoded targets by distributing
        # some probability mass across all classes
        smooth_targets = target_one_hot * (1.0 - self.smoothing) + (self.smoothing / num_classes)
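        # Worked example: with the default smoothing=0.1 and num_classes=10,
        # the true class receives 0.9 + 0.1/10 = 0.91 of the probability mass
        # and every other class receives 0.01.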
        # Compute the log probabilities of predictions using log-softmax (for numerical stability)
        log_probabilities = F.log_softmax(predictions, dim=-1)
        # Compute the per-sample loss by multiplying log probabilities with the
        # smoothed targets and summing across classes
        loss_per_sample = -torch.sum(log_probabilities * smooth_targets, dim=-1)
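        # With smoothing=0.0 the smoothed targets are plain one-hot vectors,
        # so this reduces to standard (unsmoothed) cross-entropy.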
        # Apply the specified reduction method to the computed loss
        if self.reduction_method == 'none':
            return loss_per_sample  # Return the unreduced loss for each sample
        elif self.reduction_method == 'sum':
            return torch.sum(loss_per_sample)  # Return the sum of the losses over all samples
        elif self.reduction_method == 'mean':
            return torch.mean(loss_per_sample)  # Return the mean loss over all samples
        else:
            raise ValueError(f"Invalid reduction option: {self.reduction_method}. Expected 'none', 'sum', or 'mean'.")