use-case-and-architecture/EdgeFLite/data_collection/dataset_factory.py

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
from torchvision import transforms
from .cifar import CIFAR10, CIFAR100  # Import CIFAR10 and CIFAR100 datasets
from .autoaugment import CIFAR10Policy  # Import CIFAR10 augmentation policy
# Non-IID loader helpers used below; the module name is an assumption inferred
# from how obtain_data_loaders_train/_test are called in this file.
from .cifar10_non_iid import obtain_data_loaders_train, obtain_data_loaders_test
__all__ = ['obtain_data_loader'] # Define the public API of this module
def obtain_data_loader(
data_dir, # Directory where the data is stored
split_factor=1, # Used for data partitioning, especially in federated learning
batch_size=128, # Batch size for loading data
crop_size=32, # Size to crop the input images
dataset='cifar10', # Dataset to use (CIFAR-10 by default)
split="train", # The split type: 'train', 'val', or 'test'
    is_decentralized=False,  # Whether to use distributed (multi-process) training
    is_autoaugment=1,  # Whether to apply AutoAugment
    randaa=None,  # Optional randomized-augmentation policy (mutually exclusive with AutoAugment)
is_cutout=True, # Whether to apply cutout (random erasing)
erase_p=0.5, # Probability of applying random erasing
num_workers=8, # Number of workers to load data
pin_memory=True, # Use pinned memory for better GPU transfer
is_fed=False, # Whether to use federated learning
num_clusters=20, # Number of clients in federated learning
cifar10_non_iid=False, # Non-IID option for CIFAR-10 dataset
cifar100_non_iid=False # Non-IID option for CIFAR-100 dataset
):
"""Get the dataset loader"""
assert not (is_autoaugment and randaa is not None) # Autoaugment and randaa cannot be used together
# Loader settings based on multiprocessing
kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory}
assert split in ['train', 'val', 'test'] # Ensure valid split
# For CIFAR-10 dataset
if dataset == 'cifar10':
# Handle non-IID 'quantity skew' case for CIFAR-10
if cifar10_non_iid == 'quantity_skew':
non_iid = 'quantity_skew'
# If in training split
if 'train' in split:
print(f"INFO:PyTorch: Using quantity_skew CIFAR10 dataset, batch size {batch_size} and crop size is {crop_size}.")
traindir = data_dir # Set data directory
# Define data apply_transformationations for training
train_apply_transformation = apply_transformations.Compose([
apply_transformations.ToPILImage(),
apply_transformations.RandomCrop(32, padding=4),
apply_transformations.RandomHorizontalFlip(),
CIFAR10Policy(), # AutoAugment policy
apply_transformations.ToTensor(),
apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 normalization
apply_transformations.RandomErasing(p=erase_p, scale=(0.125, 0.2), ratio=(0.99, 1.0), value=0, inplace=False),
])
train_sampler = None
print('INFO:PyTorch: creating quantity_skew CIFAR10 train dataloader...')
# For federated learning, create loaders for each client
if is_fed:
                train_loader = obtain_data_loaders_train(
                    traindir,
                    nclients=num_clusters * split_factor,  # Number of clients in federated learning
                    batch_size=batch_size,
                    verbose=True,
                    transform_train=train_transform,
                    non_iid=non_iid,  # Specify non-IID type
                    split_factor=split_factor
                )
            else:
                # The quantity-skew path is only implemented for federated training
                assert is_fed, "quantity_skew requires is_fed=True"
            return train_loader, train_sampler
else:
# If in validation or test split
valdir = data_dir # Set validation data directory
            # Validation/test-time transformations
            val_transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 normalization
            ])
# Create the test loader
val_loader = obtain_data_loaders_test(
valdir,
nclients=num_clusters * split_factor, # Number of clients in federated learning
batch_size=batch_size,
verbose=True,
                transform_eval=val_transform,
non_iid=non_iid,
split_factor=1
)
return val_loader
else:
# For standard IID CIFAR-10 case
if 'train' in split:
print(f"INFO:PyTorch: Using CIFAR10 dataset, batch size {batch_size} and crop size is {crop_size}.")
traindir = data_dir # Set training data directory
# Define data apply_transformationations for training
train_apply_transformation = apply_transformations.Compose([
apply_transformations.RandomCrop(32, padding=4),
apply_transformations.RandomHorizontalFlip(),
CIFAR10Policy(), # AutoAugment policy
apply_transformations.ToTensor(),
apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 normalization
apply_transformations.RandomErasing(p=erase_p, scale=(0.125, 0.2), ratio=(0.99, 1.0), value=0, inplace=False),
])
# Create the CIFAR-10 dataset object
train_dataset = CIFAR10(
traindir, train=True, apply_transformation=train_apply_transformation, target_apply_transformation=None, download=True, split_factor=split_factor
)
train_sampler = None # No sampler by default
            # Distributed (multi-process) training setup
            if is_decentralized:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
print('INFO:PyTorch: creating CIFAR10 train dataloader...')
if is_fed:
# Federated learning setup
                # Evenly divide the training images across all federated clients
                images_per_client = train_dataset.data.shape[0] // (num_clusters * split_factor)
                print(f"Images per client: {images_per_client}")
                # Give each client an equal share; the last client absorbs the remainder
                data_split = [images_per_client for _ in range(num_clusters * split_factor - 1)]
                data_split.append(len(train_dataset) - images_per_client * (num_clusters * split_factor - 1))
                # Split the dataset per client; the fixed seed keeps the partition reproducible across runs
                traindata_split = torch.utils.data.random_split(
                    train_dataset, data_split, generator=torch.Generator().manual_seed(68)
                )
# Create data loaders for each client
train_loader = [torch.utils.data.DataLoader(
x, batch_size=batch_size, shuffle=(train_sampler is None), drop_last=True, sampler=train_sampler, **kwargs
) for x in traindata_split]
else:
# Standard data loader
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), drop_last=True, sampler=train_sampler, **kwargs
)
return train_loader, train_sampler
else:
# For validation or test split
valdir = data_dir # Set validation data directory
            # Validation/test-time transformations
            val_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 normalization
            ])
            # Create the CIFAR-10 validation dataset
            val_dataset = CIFAR10(valdir, train=False, transform=val_transform, target_transform=None,
                                  download=True, split_factor=1)
print('INFO:PyTorch: creating CIFAR10 validation dataloader...')
# Create data loader for validation
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **kwargs)
return val_loader
    # Additional dataset logic (other datasets, decentralized setups) can be added
    # in the same pattern; a minimal CIFAR-100 sketch follows below.
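    # The branch below is a minimal sketch, assuming a CIFAR-100 pipeline that
    # mirrors the IID CIFAR-10 branch above; the normalization statistics are
    # the commonly used CIFAR-100 values, and non-IID/AutoAugment handling is
    # deliberately omitted.
    elif dataset == 'cifar100':
        if 'train' in split:
            print(f"INFO:PyTorch: Using CIFAR100 dataset, batch size {batch_size}, crop size {crop_size}.")
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),  # CIFAR-100 normalization
            ])
            train_dataset = CIFAR100(data_dir, train=True, transform=train_transform,
                                     target_transform=None, download=True, split_factor=split_factor)
            train_sampler = None
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs
            )
            return train_loader, train_sampler
        else:
            val_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),  # CIFAR-100 normalization
            ])
            val_dataset = CIFAR100(data_dir, train=False, transform=val_transform,
                                   target_transform=None, download=True, split_factor=1)
            return torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **kwargs)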
else:
raise NotImplementedError(f"The DataLoader for {dataset} is not implemented.")
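
# Example usage: a minimal sketch of how this factory is typically driven.
# The data directory and cluster count are illustrative assumptions, not
# values taken from this repository.
if __name__ == "__main__":
    # Federated training: returns one DataLoader per client plus the sampler
    client_loaders, sampler = obtain_data_loader(
        data_dir="./data",
        dataset="cifar10",
        split="train",
        is_fed=True,
        num_clusters=20,
        batch_size=128,
    )
    print(f"Created {len(client_loaders)} client loaders")
    # Centralized evaluation: a single DataLoader over the held-out split
    val_loader = obtain_data_loader(
        data_dir="./data",
        dataset="cifar10",
        split="val",
        batch_size=128,
    )
    print(f"Validation batches: {len(val_loader)}")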