4ec0a23e73
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
191 lines
6.6 KiB
Python
191 lines
6.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import os
|
|
import shutil
|
|
import torch
|
|
|
|
def store_model(state, best_model, directory, filename='checkpoint.pth'):
    """
    Write a checkpoint to disk and optionally mirror it as the best model.

    The object is always serialized to ``directory/filename``. When
    ``best_model`` is truthy, that freshly written file is additionally
    copied to ``directory/best_model.pth``.

    Args:
        state (dict): Object handed to ``torch.save`` (the model's state dictionary).
        best_model (bool): Whether the current model is the best one so far.
        directory (str): Directory the checkpoint is written into.
        filename (str): File name of the checkpoint (default 'checkpoint.pth').
    """
    checkpoint_path = os.path.join(directory, filename)
    best_path = os.path.join(directory, 'best_model.pth')

    torch.save(state, checkpoint_path)

    if not best_model:
        return

    # Keep a second copy under a fixed name so the best checkpoint is easy to find.
    shutil.copy(checkpoint_path, best_path)
|
|
|
|
def save_main_client_model(state, best_model, directory):
    """
    Persist the main client's model, but only when it is the best one.

    Nothing is written and nothing is printed when ``best_model`` is falsy.

    Args:
        state (dict): Object handed to ``torch.save`` (the model's state dictionary).
        best_model (bool): Whether the current model is the best one so far.
        directory (str): Directory in which 'main_client_best.pth' is written.
    """
    if not best_model:
        return

    target_path = os.path.join(directory, 'main_client_best.pth')
    print("Saving the best main client model")
    torch.save(state, target_path)
|
|
|
|
def save_proxy_clients_model(state, best_model, directory):
    """
    Persist the proxy clients' model, but only when it is the best one.

    Nothing is written and nothing is printed when ``best_model`` is falsy.

    Args:
        state (dict): Object handed to ``torch.save`` (the model's state dictionary).
        best_model (bool): Whether the current model is the best one so far.
        directory (str): Directory in which 'proxy_clients_best.pth' is written.
    """
    if not best_model:
        return

    target_path = os.path.join(directory, 'proxy_clients_best.pth')
    print("Saving the best proxy client model")
    torch.save(state, target_path)
|
|
|
|
def save_individual_client_model(state, best_model, directory):
    """
    Persist an individual client's model, but only when it is the best one.

    Nothing is written and nothing is printed when ``best_model`` is falsy.

    Args:
        state (dict): Object handed to ``torch.save`` (the model's state dictionary).
        best_model (bool): Whether the current model is the best one so far.
        directory (str): Directory in which 'client_best.pth' is written.
    """
    if not best_model:
        return

    target_path = os.path.join(directory, 'client_best.pth')
    print("Saving the best client model")
    torch.save(state, target_path)
|
|
|
|
def save_server_model(state, best_model, directory):
    """
    Persist the server's model, but only when it is the best one.

    Nothing is written and nothing is printed when ``best_model`` is falsy.

    Args:
        state (dict): Object handed to ``torch.save`` (the model's state dictionary).
        best_model (bool): Whether the current model is the best one so far.
        directory (str): Directory in which 'server_best.pth' is written.
    """
    if not best_model:
        return

    target_path = os.path.join(directory, 'server_best.pth')
    print("Saving the best server model")
    torch.save(state, target_path)
|
|
|
|
class MetricTracker(object):
    """
    A helper class to track the latest value and running average of a metric.

    Args:
        metric_name (str): Name of the metric to track.
        fmt (str): Format spec used when printing values, *including* the
            leading colon (default ':f'). An empty string prints values with
            their default formatting.
    """

    def __init__(self, metric_name, fmt=':f'):
        self.metric_name = metric_name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Resets all metric counters to zero."""
        self.current_value = 0
        self.total_sum = 0
        self.count = 0
        self.average = 0

    def update(self, value, n=1):
        """
        Records a new observation and refreshes the running average.

        Args:
            value (float): New value of the metric.
            n (int): Weight or count for the value (default 1).
        """
        self.current_value = value
        self.total_sum += value * n
        self.count += n
        # A zero accumulated count (e.g. an empty batch reported first) would
        # divide by zero; leave the average untouched in that case.
        if self.count:
            self.average = self.total_sum / self.count

    def __str__(self):
        """Returns '<name> <current> (<average>)' with both numbers formatted by ``fmt``."""
        # ``fmt`` already carries its leading ':' so it is spliced directly
        # into the replacement fields; an f-string cannot express this.
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(
            name=self.metric_name,
            val=self.current_value,
            avg=self.average,
        )
|
|
|
|
class ProgressLogger(object):
    """
    A class to log and display the progress of training/testing over multiple batches.

    Args:
        total_batches (int): Total number of batches.
        *metrics: Objects whose ``str()`` is appended to each log line
            (e.g. MetricTracker instances).
        prefix (str): Prefix for the progress log (default "Progress:").
    """

    def __init__(self, total_batches, *metrics, prefix="Progress:"):
        self.batch_format = self._get_batch_format(total_batches)
        self.metrics = metrics
        self.prefix = prefix

    def log(self, batch_idx):
        """
        Prints one progress line: prefix, '[idx/total]', then every metric,
        joined by ' | '.

        Args:
            batch_idx (int): The current batch index.
        """
        output = [self.prefix + self.batch_format.format(batch_idx)]
        output += [str(metric) for metric in self.metrics]
        print(' | '.join(output))

    def _get_batch_format(self, total_batches):
        """
        Creates the '[{:Nd}/total]' template used to display the batch index.

        The index field is left open (to be filled by ``log``) and padded to
        the width of ``total_batches`` so successive lines stay aligned.
        """
        num_digits = len(str(total_batches))
        index_field = '{:' + str(num_digits) + 'd}'
        # Only the total is substituted here. Calling .format() on the whole
        # concatenation would bind to the last literal alone and fail on the
        # unmatched '}'.
        return '[' + index_field + '/' + index_field.format(total_batches) + ']'
|
|
|
|
def compute_accuracy(prediction, target, top_k=(1,)):
    """
    Computes the accuracy for the top-k predictions.

    Args:
        prediction (Tensor): Model scores, one row per sample; the top-k is
            taken along dimension 1.
        target (Tensor): Ground truth labels, one per sample.
        top_k (tuple): Tuple of top-k values to consider for accuracy (default (1,)).

    Returns:
        List[Tensor]: One single-element tensor per entry of ``top_k`` holding
        the accuracy as a percentage of the batch.
    """
    with torch.no_grad():
        max_k = max(top_k)
        batch_size = target.size(0)

        # Indices of the max_k highest scores per sample, transposed so that
        # row i holds every sample's (i+1)-th best guess.
        _, top_predictions = prediction.topk(max_k, 1, largest=True, sorted=True)
        top_predictions = top_predictions.t()

        # Compare each rank's guesses against the targets.
        correct_predictions = top_predictions.eq(target.view(1, -1).expand_as(top_predictions))

        accuracy_results = []
        for k in top_k:
            # The transposed tensor is not contiguous, so a slice of more than
            # one row cannot be flattened with view(); reshape() copies when needed.
            correct_k = correct_predictions[:k].reshape(-1).float().sum(0, keepdim=True)
            accuracy_results.append(correct_k.mul_(100.0 / batch_size))
        return accuracy_results
|
|
|
|
def count_model_parameters(model, trainable_only=False):
    """
    Return how many scalar parameters a model holds.

    Args:
        model (nn.Module): The PyTorch model whose ``parameters()`` are counted.
        trainable_only (bool): When True, only parameters with
            ``requires_grad`` set are included (default False).

    Returns:
        int: Sum of ``numel()`` over the selected parameters.
    """
    selected = model.parameters()
    if trainable_only:
        # Drop frozen parameters before counting.
        selected = (param for param in selected if param.requires_grad)

    total = 0
    for param in selected:
        total += param.numel()
    return total
|