# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
from torchvision import transforms  # torchvision exposes `transforms`; the `apply_transformations` alias does not exist

from .cifar import CIFAR10, CIFAR100  # Local CIFAR-10 and CIFAR-100 dataset classes
from .autoaugment import CIFAR10Policy  # AutoAugment policy for CIFAR-10

# NOTE: the federated helpers `obtain_data_loaders_train` and
# `obtain_data_loaders_test` used below are assumed to be defined or imported
# elsewhere in this package; their definitions are not shown here.

__all__ = ['obtain_data_loader']  # Public API of this module
def obtain_data_loader(
        data_dir,                 # Directory where the data is stored
        split_factor=1,           # Data-partitioning factor, used mainly in federated learning
        batch_size=128,           # Batch size for loading data
        crop_size=32,             # Size to which input images are cropped
        dataset='cifar10',        # Dataset to use (CIFAR-10 by default)
        split="train",            # Split type: 'train', 'val', or 'test'
        is_decentralized=False,   # Whether to use distributed (multi-process) training
        is_autoaugment=1,         # Whether to apply AutoAugment
        randaa=None,              # Placeholder for a randomized-augmentation policy
        is_cutout=True,           # Whether to apply cutout (realized via random erasing)
        erase_p=0.5,              # Probability of applying random erasing
        num_workers=8,            # Number of data-loading worker processes
        pin_memory=True,          # Use pinned memory for faster host-to-GPU transfers
        is_fed=False,             # Whether to build federated (per-client) loaders
        num_clusters=20,          # Number of clients in federated learning
        cifar10_non_iid=False,    # Non-IID option for CIFAR-10 ('quantity_skew' or False)
        cifar100_non_iid=False,   # Non-IID option for CIFAR-100
):
    """Return the data loader(s) for the requested dataset and split."""
    # AutoAugment and `randaa` cannot be enabled together.
    assert not (is_autoaugment and randaa is not None)

    # Loader settings for multiprocessing.
    kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory}
    assert split in ['train', 'val', 'test']  # Ensure a valid split
    # CIFAR-10 dataset
    if dataset == 'cifar10':
        # Non-IID 'quantity skew' case: clients receive unequal shares of the data.
        if cifar10_non_iid == 'quantity_skew':
            non_iid = 'quantity_skew'
            # Training split
            if 'train' in split:
                print(f"INFO:PyTorch: Using quantity_skew CIFAR10 dataset, batch size {batch_size} and crop size is {crop_size}.")
                traindir = data_dir  # Training data directory
                # Training-time transformations
                train_transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    CIFAR10Policy(),  # AutoAugment policy
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 channel statistics
                    transforms.RandomErasing(p=erase_p, scale=(0.125, 0.2), ratio=(0.99, 1.0), value=0, inplace=False),
                ])
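                # Note on the RandomErasing settings above: `scale=(0.125, 0.2)`
                # erases 12.5-20% of the image area, and `ratio=(0.99, 1.0)` keeps
                # the patch nearly square, approximating classic cutout.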
                train_sampler = None
                print('INFO:PyTorch: creating quantity_skew CIFAR10 train dataloader...')

                # Federated learning: build one loader per client.
                if is_fed:
                    train_loader = obtain_data_loaders_train(
                        traindir,
                        nclients=num_clusters * split_factor,  # Total number of federated clients
                        batch_size=batch_size,
                        verbose=True,
                        apply_transformations_train=train_transform,
                        non_iid=non_iid,  # Non-IID partitioning type
                        split_factor=split_factor
                    )
                else:
                    # quantity_skew is only provided through the federated loaders.
                    assert is_fed
                return train_loader, train_sampler
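                # Hypothetical usage: here `train_loader` is a list with one
                # DataLoader per client, so a federated round might iterate it as
                #
                #   loaders, _ = obtain_data_loader('./data', split='train', is_fed=True,
                #                                   cifar10_non_iid='quantity_skew')
                #   for client_id, loader in enumerate(loaders):
                #       for images, targets in loader:
                #           ...  # local update for this client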
            else:
                # Validation/test split
                valdir = data_dir  # Validation data directory
                # Evaluation-time transformations
                val_transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 channel statistics
                ])
                # Create the per-client test loaders.
                val_loader = obtain_data_loaders_test(
                    valdir,
                    nclients=num_clusters * split_factor,  # Total number of federated clients
                    batch_size=batch_size,
                    verbose=True,
                    apply_transformations_eval=val_transform,
                    non_iid=non_iid,
                    split_factor=1
                )
                return val_loader
        else:
            # Standard (IID) CIFAR-10 case
            if 'train' in split:
                print(f"INFO:PyTorch: Using CIFAR10 dataset, batch size {batch_size} and crop size is {crop_size}.")
                traindir = data_dir  # Training data directory
                # Training-time transformations
                train_transform = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    CIFAR10Policy(),  # AutoAugment policy
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 channel statistics
                    transforms.RandomErasing(p=erase_p, scale=(0.125, 0.2), ratio=(0.99, 1.0), value=0, inplace=False),
                ])
                # Create the CIFAR-10 training dataset.
                train_dataset = CIFAR10(
                    traindir, train=True, apply_transformation=train_transform,
                    target_apply_transformation=None, download=True, split_factor=split_factor
                )
                train_sampler = None  # No sampler by default

                # Distributed training setup: shard the dataset across processes.
                if is_decentralized:
                    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
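                # NOTE: with DistributedSampler, the training loop should call
                # `train_sampler.set_epoch(epoch)` at the start of each epoch so
                # that shuffling differs across epochs.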

                print('INFO:PyTorch: creating CIFAR10 train dataloader...')
                if is_fed:
                    # Federated learning: evenly partition the training set,
                    # giving any remainder to the last client.
                    images_per_client = int(train_dataset.data.shape[0] / (num_clusters * split_factor))
                    print(f"Images per client: {images_per_client}")
                    data_split = [images_per_client for _ in range(num_clusters * split_factor - 1)]
                    data_split.append(len(train_dataset) - images_per_client * (num_clusters * split_factor - 1))
                    # Split the dataset into per-client subsets (fixed seed for reproducibility).
                    traindata_split = torch.utils.data.random_split(train_dataset, data_split, generator=torch.Generator().manual_seed(68))
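                    # Worked example (assuming the standard 50,000 CIFAR-10 training
                    # images, num_clusters=20, split_factor=1): images_per_client is
                    # 50000 // 20 = 2500, so data_split becomes [2500] * 19 plus a
                    # final chunk of 50000 - 19 * 2500 = 2500.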
                    # Create one data loader per client subset.
                    train_loader = [torch.utils.data.DataLoader(
                        x, batch_size=batch_size, shuffle=(train_sampler is None),
                        drop_last=True, sampler=train_sampler, **kwargs
                    ) for x in traindata_split]
                else:
                    # Standard single loader over the full training set.
                    train_loader = torch.utils.data.DataLoader(
                        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
                        drop_last=True, sampler=train_sampler, **kwargs
                    )
                return train_loader, train_sampler
            else:
                # Validation/test split
                valdir = data_dir  # Validation data directory
                # Evaluation-time transformations
                val_transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 channel statistics
                ])
                # Create the CIFAR-10 validation dataset.
                val_dataset = CIFAR10(valdir, train=False, apply_transformation=val_transform,
                                      target_apply_transformation=None, download=True, split_factor=1)
                print('INFO:PyTorch: creating CIFAR10 validation dataloader...')
                # Create the validation data loader.
                val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **kwargs)
                return val_loader
    # Loader logic for CIFAR-100, other datasets, or additional decentralized
    # setups can be added here following the same pattern.
    else:
        raise NotImplementedError(f"The DataLoader for {dataset} is not implemented.")
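

# Minimal usage sketch (hypothetical): exercises the standard CIFAR-10 path.
# Assumes a './data' directory and that the local CIFAR10 class downloads data
# like torchvision's CIFAR10 does; because of the relative imports above, run
# it in a package context (e.g. `python -m <package>.<module>`).
if __name__ == '__main__':
    train_loader, train_sampler = obtain_data_loader(
        './data', dataset='cifar10', split='train', batch_size=128)
    val_loader = obtain_data_loader(
        './data', dataset='cifar10', split='val', batch_size=128)
    images, targets = next(iter(train_loader))
    print(f"train batch: {tuple(images.shape)}, labels: {tuple(targets.shape)}")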