use-case-and-architecture/EdgeFLite/data_collection/cifar100_noniid.py
Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

221 lines
8.6 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
#### Get CIFAR-100 dataset in X and Y form
import random

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
# torchvision ships its preprocessing ops as the `transforms` submodule; alias
# it to the name used throughout this module.
from torchvision import transforms as apply_transformations

from .cifar10_non_iid import *
# Set random seeds for reproducibility.
# Both the global NumPy RNG and the stdlib `random` RNG are seeded at import
# time because split_cf100_real_world_images draws from both
# (np.random.shuffle / np.random.randint and random.randint).
# NOTE(review): torch's RNG is not seeded here, so DataLoader(shuffle=True)
# ordering is presumably not covered by these seeds - confirm it is seeded elsewhere.
np.random.seed(68)
random.seed(68)
def get_cifar100(data_dir):
    '''
    Fetch CIFAR-100 through torchvision and expose it as plain numpy arrays.

    Parameters:
        data_dir (str): Root directory handed to torchvision; the dataset is
            requested with download=True for both splits.
    Returns:
        x_train (ndarray): Training images with axis 3 moved to position 1.
        y_train (ndarray): Training labels.
        x_test (ndarray): Test images, same axis order as x_train.
        y_test (ndarray): Test labels.
    '''
    def _as_arrays(dataset):
        # Reorder axes (0, 3, 1, 2) - presumably NHWC -> NCHW - and turn the
        # target list into an ndarray.
        return dataset.data.transpose((0, 3, 1, 2)), np.array(dataset.targets)

    train_set = torchvision.datasets.CIFAR100(data_dir, train=True, download=True)
    test_set = torchvision.datasets.CIFAR100(data_dir, train=False, download=True)

    x_train, y_train = _as_arrays(train_set)
    x_test, y_test = _as_arrays(test_set)
    return x_train, y_train, x_test, y_test
def split_cf100_real_world_images(data, labels, n_clients=100, verbose=True):
    '''
    Split data and labels among n_clients to simulate a non-IID distribution.

    Every class is cut into a random number of shards and every client draws a
    random number of shards, each from a different class, so clients differ
    both in how many samples they hold and in which classes they see.

    Parameters:
        data (ndarray): Dataset images [n_data x shape].
        labels (ndarray): Dataset labels [n_data]. Labels are assumed to be the
            integers 0 .. n_classes - 1.
        n_clients (int): Number of clients to split the data among.
        verbose (bool): Print the per-client class histogram if True.
    Returns:
        clients_split (ndarray): Object array of shape (n_clients, 2); row i
            holds [client_data, client_labels] for client i.
    Raises:
        ValueError: If labels is empty, or if a client cannot be given the
            number of distinct classes drawn for it.
    '''
    data = np.asarray(data)
    labels = np.asarray(labels)

    def divide_into_sections(n, m):
        '''Return m positive integers, randomly spread, that sum to max(n, m).'''
        result = [1] * m
        for _ in range(n - m):
            result[random.randint(0, m - 1)] += 1
        return result

    n_classes = len(np.unique(labels))  # Number of unique classes
    if n_classes == 0:
        raise ValueError("Cannot split an empty label array.")

    # Shuffle the class order so shard counts are not tied to the label value.
    classes = list(range(n_classes))
    np.random.shuffle(classes)
    label_indices = [list(np.where(labels == class_)[0]) for class_ in classes]

    # Number of shards (one per distinct class) each client will receive.
    # The exclusive upper bound stays at 100 for datasets with >= 100 classes
    # and is capped by the class count otherwise, so a client is never asked
    # for more distinct classes than exist.
    draw_limit = max(2, min(100, n_classes))
    tmp = [np.random.randint(1, draw_limit) for _ in range(n_clients)]
    total_partition = sum(tmp)

    # Decide how many shards each class is cut into; the largest counts go to
    # the classes that come first in the shuffled order.
    class_partition = sorted(divide_into_sections(total_partition, n_classes), reverse=True)
    class_partition_split = {
        class_: [list(shard) for shard in np.array_split(label_indices[idx], class_partition[idx])]
        for idx, class_ in enumerate(classes)
    }

    clients_split = []
    for n in tmp:
        indices = []
        j = 0
        # Walk the classes (most remaining shards first) and take one shard
        # from each class that still has any, until the client has n shards.
        while n > 0 and j < len(classes):
            class_ = classes[j]
            if class_partition_split[class_]:
                indices.extend(class_partition_split[class_].pop())
                n -= 1
            j += 1
        # Ran out of classes with shards left before the client was served.
        if n > 0:
            raise ValueError("Unable to fulfill the client partition criteria.")
        clients_split.append([data[indices], labels[indices]])
        # Re-sort classes by remaining shards so the next client draws from
        # the fullest classes first.
        classes = sorted(classes, key=lambda x: len(class_partition_split[x]), reverse=True)

    if verbose:
        display_data_split(clients_split)

    # Clients hold different numbers of samples, so the result is ragged;
    # fill an object array explicitly instead of letting np.array guess
    # (recent NumPy versions reject ragged input without dtype=object).
    packed = np.empty((len(clients_split), 2), dtype=object)
    for row, (client_x, client_y) in enumerate(clients_split):
        packed[row, 0] = client_x
        packed[row, 1] = client_y
    return packed
def display_data_split(clients_split):
    '''
    Print, for each client, how many samples it holds of every class.

    Parameters:
        clients_split (sequence): One [data, labels] pair per client; only the
            labels (index 1) are read.

    The histogram of a client runs from class 0 up to the largest label that
    client holds. A client without any samples is printed as an empty
    histogram instead of raising.
    '''
    print("Data split:")
    for i, client in enumerate(clients_split):
        client_labels = np.asarray(client[1])
        # np.max is undefined on an empty array, so an empty client gets zero bins.
        n_bins = np.max(client_labels) + 1 if client_labels.size else 0
        # Compare every label against every bin value and count matches per bin.
        split = np.sum(client_labels.reshape(1, -1) == np.arange(n_bins).reshape(-1, 1), axis=1)
        print(f" - Client {i}: {split}")
    print()
def get_default_data_apply_transformations_cf100(train=True, verbose=True):
    '''
    Build the default CIFAR-100 preprocessing pipelines.

    Parameters:
        train (bool): Accepted for interface compatibility; both pipelines are
            always built and returned regardless of its value.
        verbose (bool): Print every step of the training pipeline if True.
    Returns:
        apply_transformations_train (Compose): Training pipeline (random crop
            with padding 4, random horizontal flip, tensor conversion,
            normalization).
        apply_transformations_eval (Compose): Evaluation pipeline (tensor
            conversion and normalization only).
    '''
    # NOTE(review): these values match the commonly quoted CIFAR-10 channel
    # statistics - confirm whether CIFAR-100 statistics were intended.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    # Training pipeline: augmentation followed by normalization.
    apply_transformations_train = apply_transformations.Compose([
        apply_transformations.ToPILImage(),
        apply_transformations.RandomCrop(32, padding=4),
        apply_transformations.RandomHorizontalFlip(),
        apply_transformations.ToTensor(),
        apply_transformations.Normalize(channel_mean, channel_std),
    ])

    # Evaluation pipeline: no augmentation.
    apply_transformations_eval = apply_transformations.Compose([
        apply_transformations.ToPILImage(),
        apply_transformations.ToTensor(),
        apply_transformations.Normalize(channel_mean, channel_std),
    ])

    if verbose:
        print("\nData preprocessing:")
        # Compose keeps its ordered steps in the `transforms` attribute.
        for step in apply_transformations_train.transforms:
            print(f' - {step}')
        print()

    return apply_transformations_train, apply_transformations_eval
def obtain_data_loaders_train_cf100(data_dir, n_clients, batch_size, classes_per_client=10, verbose=True,
                                    apply_transformations_train=None, apply_transformations_eval=None, non_iid=None, split_factor=1):
    '''
    Return one training data loader per client for CIFAR-100.

    Parameters:
        data_dir (str): Directory where the CIFAR-100 dataset will be saved.
        n_clients (int): Number of clients for splitting the dataset.
        batch_size (int): Batch size for each data loader.
        classes_per_client (int): Not used by this function; kept for interface compatibility.
        verbose (bool): Print dataset statistics and the client split if True.
        apply_transformations_train (Compose): Transformations handed to each client's dataset.
        apply_transformations_eval (Compose): Not used by this function; kept for interface compatibility.
        non_iid (str): Strategy to create the non-IID split. Only
            'quantity_skew' is implemented here.
        split_factor (float): Forwarded unchanged to CustomImageDataset.
    Returns:
        client_loaders (list): Data loaders for each client.
    Raises:
        ValueError: If non_iid is not a supported strategy.
    '''
    # Reject unsupported strategies before downloading anything; otherwise the
    # split would stay None and fail later with an unrelated-looking error.
    if non_iid != 'quantity_skew':
        raise ValueError(
            f"Unsupported non_iid strategy {non_iid!r}; expected 'quantity_skew'."
        )

    x_train, y_train, _, _ = get_cifar100(data_dir)

    # Verbose option to print dataset statistics
    if verbose:
        print_image_data_stats_train(x_train, y_train)

    split = split_cf100_real_world_images(x_train, y_train, n_clients=n_clients, verbose=verbose)
    # shuffle_list is not defined in this file (presumably it comes from the
    # cifar10_non_iid wildcard import and reorders the client entries).
    split_tmp = shuffle_list(split)

    # Create DataLoaders for each client
    client_loaders = [
        DataLoader(CustomImageDataset(x, y, apply_transformations_train, split_factor=split_factor),
                   batch_size=batch_size, shuffle=True)
        for x, y in split_tmp
    ]
    return client_loaders
def obtain_data_loaders_test_cf100(data_dir, batch_size, verbose=True, apply_transformations_eval=None):
    '''
    Return the data loader for testing on CIFAR-100.

    Parameters:
        data_dir (str): Directory where the CIFAR-100 dataset will be saved.
        batch_size (int): Batch size for the test data loader.
        verbose (bool): Print dataset statistics if True.
        apply_transformations_eval (Compose): Transformations handed to the test dataset.
    Returns:
        test_loader (DataLoader): Test data loader (not shuffled).
    '''
    _, _, x_test, y_test = get_cifar100(data_dir)

    # Verbose option to print dataset statistics
    if verbose:
        print_image_data_stats_test(x_test, y_test)

    # Honour the caller's batch_size instead of a fixed value; the test set is
    # kept in order and uses split_factor=1.
    test_dataset = CustomImageDataset(x_test, y_test, apply_transformations_eval, split_factor=1)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_loader