use-case-and-architecture/EdgeFLite/helpers/evaluation_metrics.py

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import os
import shutil

import torch


def store_model(state, best_model, directory, filename='checkpoint.pth'):
    """
    Stores the model checkpoint in the specified directory. If it is the best
    model, it also saves a copy named 'best_model.pth'.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
        filename (str): Name of the checkpoint file (default 'checkpoint.pth').
    """
    save_path = os.path.join(directory, filename)
    torch.save(state, save_path)
    if best_model:
        # If the current model is the best, save another copy as 'best_model.pth'
        shutil.copy(save_path, os.path.join(directory, 'best_model.pth'))
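

# A minimal usage sketch for store_model; the `model`, `optimizer`, `epoch`,
# `is_best`, and `output_dir` names are illustrative, not part of this module:
#
#     state = {
#         'epoch': epoch,
#         'state_dict': model.state_dict(),
#         'optimizer': optimizer.state_dict(),
#     }
#     store_model(state, best_model=is_best, directory=output_dir)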


def save_main_client_model(state, best_model, directory):
    """
    Saves the model for the main client if it is the best one.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    if best_model:
        print("Saving the best main client model")
        torch.save(state, os.path.join(directory, 'main_client_best.pth'))


def save_proxy_clients_model(state, best_model, directory):
    """
    Saves the model for proxy clients if it is the best one.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    if best_model:
        print("Saving the best proxy client model")
        torch.save(state, os.path.join(directory, 'proxy_clients_best.pth'))


def save_individual_client_model(state, best_model, directory):
    """
    Saves the model for an individual client if it is the best one.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    if best_model:
        print("Saving the best client model")
        torch.save(state, os.path.join(directory, 'client_best.pth'))


def save_server_model(state, best_model, directory):
    """
    Saves the model for the server if it is the best one.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    if best_model:
        print("Saving the best server model")
        torch.save(state, os.path.join(directory, 'server_best.pth'))
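

# A hedged sketch of how the role-specific savers above might be called after
# a validation round; `main_state`, `proxy_state`, `server_state`, `is_best`,
# and `output_dir` are illustrative names assumed to exist in the caller:
#
#     save_main_client_model(main_state, is_best, output_dir)
#     save_proxy_clients_model(proxy_state, is_best, output_dir)
#     save_server_model(server_state, is_best, output_dir)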


class MetricTracker(object):
    """
    A helper class to track and compute the running average of a given metric.

    Args:
        metric_name (str): Name of the metric to track.
        fmt (str): Format spec used when printing metric values (default ':f').
    """
    def __init__(self, metric_name, fmt=':f'):
        self.metric_name = metric_name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Resets all metric counters."""
        self.current_value = 0
        self.total_sum = 0
        self.count = 0
        self.average = 0

    def update(self, value, n=1):
        """
        Updates the metric value.

        Args:
            value (float): New value of the metric.
            n (int): Weight or count for the value (default 1).
        """
        self.current_value = value
        self.total_sum += value * n
        self.count += n
        self.average = self.total_sum / self.count

    def __str__(self):
        """Returns the formatted metric string showing current value and average."""
        # A format spec such as ':f' cannot be spliced directly into an
        # f-string placeholder, so build the template first, then format it.
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(name=self.metric_name,
                               val=self.current_value,
                               avg=self.average)
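

# A minimal usage sketch for MetricTracker; the values are illustrative:
#
#     top1 = MetricTracker('Acc@1', fmt=':6.2f')
#     top1.update(87.5, n=64)   # per-batch accuracy, weighted by batch size
#     top1.update(90.0, n=64)
#     print(top1)               # -> "Acc@1  90.00 ( 88.75)"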


class ProgressLogger(object):
    """
    A class to log and display the progress of training/testing over multiple batches.

    Args:
        total_batches (int): Total number of batches.
        *metrics (MetricTracker): Metrics to log during the process.
        prefix (str): Prefix for the progress log (default "Progress:").
    """
    def __init__(self, total_batches, *metrics, prefix="Progress:"):
        self.batch_format = self._get_batch_format(total_batches)
        self.metrics = metrics
        self.prefix = prefix

    def log(self, batch_idx):
        """
        Logs the current progress of training/testing.

        Args:
            batch_idx (int): The current batch index.
        """
        output = [self.prefix + self.batch_format.format(batch_idx)]
        output += [str(metric) for metric in self.metrics]
        print(' | '.join(output))

    def _get_batch_format(self, total_batches):
        """Creates a format string such as '[ 12/191]' for the batch index."""
        num_digits = len(str(total_batches))
        # Build the index placeholder (e.g. '{:3d}') separately so the total
        # can be formatted into the template without unmatched braces.
        index_fmt = '{:' + str(num_digits) + 'd}'
        return '[' + index_fmt + '/' + index_fmt.format(total_batches) + ']'
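

# A minimal sketch of ProgressLogger with a MetricTracker inside a training
# loop; `train_loader` and `loss` are illustrative names:
#
#     losses = MetricTracker('Loss', fmt=':.4e')
#     progress = ProgressLogger(len(train_loader), losses, prefix='Epoch[1] ')
#     for i, (images, target) in enumerate(train_loader):
#         ...
#         losses.update(loss.item(), images.size(0))
#         if i % 10 == 0:
#             progress.log(i)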


def compute_accuracy(prediction, target, top_k=(1,)):
    """
    Computes the accuracy for the top-k predictions.

    Args:
        prediction (Tensor): Model predictions of shape (batch_size, num_classes).
        target (Tensor): Ground-truth labels of shape (batch_size,).
        top_k (tuple): Tuple of top-k values to consider for accuracy (default (1,)).

    Returns:
        List[Tensor]: List of accuracies (in percent) for each top-k value.
    """
    with torch.no_grad():
        max_k = max(top_k)
        batch_size = target.size(0)
        # Get the indices of the top-k predictions for each sample
        _, top_predictions = prediction.topk(max_k, 1, largest=True, sorted=True)
        top_predictions = top_predictions.t()
        # Compare top-k predictions with targets
        correct_predictions = top_predictions.eq(target.view(1, -1).expand_as(top_predictions))
        accuracy_results = []
        for k in top_k:
            # Count correct predictions within the top-k; reshape(-1) is used
            # because the sliced tensor is non-contiguous for k > 1
            correct_k = correct_predictions[:k].reshape(-1).float().sum(0, keepdim=True)
            accuracy_results.append(correct_k.mul_(100.0 / batch_size))
        return accuracy_results
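

# A minimal usage sketch for compute_accuracy; `model`, `images`, `target`,
# and the trackers are illustrative:
#
#     output = model(images)  # logits of shape (batch_size, num_classes)
#     acc1, acc5 = compute_accuracy(output, target, top_k=(1, 5))
#     top1.update(acc1.item(), images.size(0))
#     top5.update(acc5.item(), images.size(0))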


def count_model_parameters(model, trainable_only=False):
    """
    Counts the total number of parameters in the model.

    Args:
        model (nn.Module): The PyTorch model.
        trainable_only (bool): Whether to count only trainable parameters (default False).

    Returns:
        int: Total number of parameters in the model.
    """
    if trainable_only:
        # Count only the parameters that require gradients (trainable parameters)
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Count all parameters (trainable and non-trainable)
    return sum(p.numel() for p in model.parameters())
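

# A small runnable self-check of the helpers above on random data, assuming
# only this file and PyTorch; the tiny linear model is illustrative:
if __name__ == '__main__':
    import torch.nn as nn

    demo_model = nn.Linear(8, 10)
    print(f'Total parameters: {count_model_parameters(demo_model)}')
    print(f'Trainable parameters: {count_model_parameters(demo_model, trainable_only=True)}')

    # Random logits and labels stand in for real model outputs
    logits = torch.randn(32, 10)
    labels = torch.randint(0, 10, (32,))
    acc1, acc5 = compute_accuracy(logits, labels, top_k=(1, 5))

    top1 = MetricTracker('Acc@1', fmt=':6.2f')
    top1.update(acc1.item(), n=32)
    progress = ProgressLogger(10, top1, prefix='Self-check: ')
    progress.log(0)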