4ec0a23e73
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
221 lines
8.6 KiB
Python
221 lines
8.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
#### Get CIFAR-100 dataset in X and Y form
|
|
import torchvision
|
|
import numpy as np
|
|
import random
|
|
import torch
|
|
from torchvision import apply_transformations
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from .cifar10_non_iid import *
|
|
|
|
# Seed both global RNGs used by this module (NumPy's and the stdlib's) with
# the same fixed value so that the client splits are reproducible.
_SEED = 68
np.random.seed(_SEED)
random.seed(_SEED)
|
|
|
|
def get_cifar100(data_dir):
    '''
    Fetch the CIFAR-100 train and test sets through torchvision and return
    them as numpy arrays.

    Parameters:
        data_dir (str): Directory handed to torchvision as the dataset root
            (both sets are requested with download=True).

    Returns:
        x_train (ndarray): Training images with axes reordered by (0, 3, 1, 2).
        y_train (ndarray): Training labels.
        x_test (ndarray): Test images with axes reordered by (0, 3, 1, 2).
        y_test (ndarray): Test labels.
    '''
    def _as_arrays(dataset):
        # Move the last image axis to position 1 (presumably HWC -> CHW) and
        # turn the target list into a numpy array.
        return dataset.data.transpose((0, 3, 1, 2)), np.array(dataset.targets)

    train_set = torchvision.datasets.CIFAR100(data_dir, train=True, download=True)
    test_set = torchvision.datasets.CIFAR100(data_dir, train=False, download=True)

    x_train, y_train = _as_arrays(train_set)
    x_test, y_test = _as_arrays(test_set)
    return x_train, y_train, x_test, y_test
|
|
|
|
def split_cf100_real_world_images(data, labels, n_clients=100, verbose=True):
    '''
    Splits data and labels among n_clients to simulate a non-IID distribution.

    Every client draws a random number of classes (1..99) and receives one
    chunk of samples from that many distinct classes. Randomness comes from
    the global `np.random` and `random` generators.

    Parameters:
        data (ndarray): Dataset images [n_data x shape].
        labels (ndarray): Dataset labels [n_data]. Class ids are assumed to be
            0 .. n_classes - 1.
        n_clients (int): Number of clients to split the data among.
        verbose (bool): Print detailed information if True.

    Returns:
        clients_split (ndarray): Object array of shape (n_clients, 2); row i
            holds client i's data array and label array.

    Raises:
        ValueError: If a client cannot be given chunks from as many distinct
            classes as it drew.
    '''
    def divide_into_sections(n, m):
        '''Return m integers >= 1; they sum to n whenever n >= m.'''
        result = [1] * m
        for _ in range(n - m):
            result[random.randint(0, m - 1)] += 1
        return result

    # Shuffle and partition classes
    n_classes = len(set(labels))  # Number of unique classes
    classes = list(range(n_classes))
    np.random.shuffle(classes)  # Shuffle class indices
    # label_indices[k] holds the sample indices of classes[k]
    label_indices = [list(np.where(labels == class_)[0]) for class_ in classes]

    # Number of classes for each client (randomized)
    tmp = [np.random.randint(1, 100) for _ in range(n_clients)]
    total_partition = sum(tmp)

    # Number of chunks each class is cut into; largest first, so the shuffled
    # class order starts out sorted by descending chunk count.
    class_partition = sorted(divide_into_sections(total_partition, len(classes)), reverse=True)

    # Split each class' indices according to the partition
    class_partition_split = {}
    for idx, class_ in enumerate(classes):
        class_partition_split[class_] = [
            list(chunk) for chunk in np.array_split(label_indices[idx], class_partition[idx])
        ]

    clients_split = []
    for i in range(n_clients):
        n = tmp[i]  # Number of classes still to assign to this client
        indices = []
        j = 0

        # Take one chunk from each class that still has chunks left, walking
        # the classes in order. The bound on j stops the walk at the end of
        # the class list so that the shortage is reported below.
        while n > 0 and j < len(classes):
            class_ = classes[j]
            if class_partition_split[class_]:
                indices.extend(class_partition_split[class_].pop())
                n -= 1
            j += 1

        # Raise error if client partition criteria cannot be met
        if n > 0:
            raise ValueError("Unable to fulfill the client partition criteria.")

        clients_split.append([data[indices], labels[indices]])

        # Re-sort classes so those with the most remaining chunks come first
        classes = sorted(classes, key=lambda x: len(class_partition_split[x]), reverse=True)

    # Verbose option to print split information
    if verbose:
        display_data_split(clients_split)

    # The per-client arrays differ in length and rank, so they are stored in
    # an explicit object array; np.array() on such a ragged list is rejected
    # by recent NumPy releases.
    packed = np.empty((len(clients_split), 2), dtype=object)
    for row, (client_data, client_labels) in enumerate(clients_split):
        packed[row, 0] = client_data
        packed[row, 1] = client_labels
    return packed
|
|
|
|
def display_data_split(clients_split):
    '''Print, for every client, how many of its samples fall into each class id.'''
    print("Data split:")
    for client_id, client in enumerate(clients_split):
        client_labels = client[1]
        # Candidate class ids run from 0 up to the largest label this client holds.
        class_ids = np.arange(np.max(client_labels) + 1)
        # Compare every class id against every label, then count matches per class.
        matches = client_labels.reshape(1, -1) == class_ids.reshape(-1, 1)
        per_class = matches.sum(axis=1)
        print(f" - Client {client_id}: {per_class}")
    print()
|
|
|
|
def get_default_data_apply_transformations_cf100(train=True, verbose=True):
    '''
    Return the default data transformations for CIFAR-100.

    Parameters:
        train (bool): Currently unused; both pipelines are always built and returned.
        verbose (bool): Print the steps of the training pipeline if True.

    Returns:
        apply_transformations_train (Compose): Training transformations
            (random crop with padding, random horizontal flip, normalization).
        apply_transformations_eval (Compose): Evaluation (test) transformations
            (normalization only).
    '''
    # torchvision exposes its transformation classes in the `transforms`
    # module; import it locally so this function is self-contained.
    from torchvision import transforms

    # NOTE(review): these mean/std values match the commonly quoted CIFAR-10
    # statistics - confirm they are intended for CIFAR-100.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    # Transformations for training data (with augmentation)
    apply_transformations_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Transformations for test data (no augmentation)
    apply_transformations_eval = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Verbose option to print the training pipeline steps
    if verbose:
        print("\nData preprocessing:")
        # Compose keeps its steps in the `transforms` attribute.
        for step in apply_transformations_train.transforms:
            print(f' - {step}')
        print()

    return apply_transformations_train, apply_transformations_eval
|
|
|
|
def obtain_data_loaders_train_cf100(data_dir, n_clients, batch_size, classes_per_client=10, verbose=True,
                                    apply_transformations_train=None, apply_transformations_eval=None,
                                    non_iid=None, split_factor=1):
    '''
    Return data loaders for training on CIFAR-100, one per client.

    Parameters:
        data_dir (str): Directory where the CIFAR-100 dataset will be saved.
        n_clients (int): Number of clients for splitting the dataset.
        batch_size (int): Batch size for each data loader.
        classes_per_client (int): Currently unused by this function.
        verbose (bool): Print detailed information if True.
        apply_transformations_train (Compose): Transformations for training data.
        apply_transformations_eval (Compose): Currently unused by this function.
        non_iid (str): Strategy to create a non-IID dataset split. Only
            'quantity_skew' is implemented.
        split_factor (float): Forwarded to CustomImageDataset.

    Returns:
        client_loaders (list): Data loaders for each client.

    Raises:
        ValueError: If non_iid is anything other than 'quantity_skew'.
    '''
    # Fail fast, before any download: without a recognised strategy there is
    # no split to build the loaders from.
    if non_iid != 'quantity_skew':
        raise ValueError(
            f"Unsupported non_iid strategy {non_iid!r}; only 'quantity_skew' is implemented for CIFAR-100."
        )

    x_train, y_train, _, _ = get_cifar100(data_dir)

    # Verbose option to print dataset statistics
    if verbose:
        print_image_data_stats_train(x_train, y_train)

    # Split data among the clients according to the quantity_skew strategy
    split = split_cf100_real_world_images(x_train, y_train, n_clients=n_clients, verbose=verbose)

    # shuffle_list is not defined in this file (presumably provided by the
    # star import from .cifar10_non_iid).
    split_tmp = shuffle_list(split)

    # Create DataLoaders for each client
    client_loaders = [
        DataLoader(CustomImageDataset(x, y, apply_transformations_train, split_factor=split_factor),
                   batch_size=batch_size, shuffle=True)
        for x, y in split_tmp
    ]

    return client_loaders
|
|
|
|
def obtain_data_loaders_test_cf100(data_dir, batch_size, verbose=True, apply_transformations_eval=None):
    '''
    Return the data loader for testing on CIFAR-100.

    Parameters:
        data_dir (str): Directory where the CIFAR-100 dataset will be saved.
        batch_size (int): Batch size for the test data loader.
        verbose (bool): Print detailed information if True.
        apply_transformations_eval (Compose): Transformations for evaluation data.

    Returns:
        test_loader (DataLoader): Test data loader (not shuffled).
    '''
    _, _, x_test, y_test = get_cifar100(data_dir)

    # Verbose option to print dataset statistics
    if verbose:
        print_image_data_stats_test(x_test, y_test)

    # Create DataLoader for the test dataset, honouring the caller's batch
    # size; split_factor stays fixed at 1 for the test set.
    test_dataset = CustomImageDataset(x_test, y_test, apply_transformations_eval, split_factor=1)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return test_loader
|