use-case-and-architecture/EdgeFLite/data_collection/dataset_cifar.py
# -*- coding: utf-8 -*-
# @Author: Weisen Pan
# Import necessary libraries
from PIL import Image # For image handling
import os # For file path operations
import numpy as np # For numerical operations
import pickle # For loading serialized data
import torch # For PyTorch operations
# Import custom classes and functions from the current package
from .vision import VisionDataset
from .utils import validate_integrity, fetch_and_extract_archive
# CIFAR10 dataset class
class CIFAR10(VisionDataset):
"""
CIFAR10 Dataset class that handles the CIFAR-10 dataset loading, processing, and apply_transformationations.
Args:
root (str): Directory where the dataset is stored or will be downloaded to.
train (bool, optional): If True, load the training set. Otherwise, load the test set.
apply_transformation (callable, optional): A function/apply_transformation that takes a PIL image and returns a apply_transformationed version.
target_apply_transformation (callable, optional): A function/apply_transformation that takes the target and apply_transformations it.
download (bool, optional): If True, download the dataset if it's not found locally.
split_factor (int, optional): Number of apply_transformationations applied to each image. Default is 1.
"""
    # Directory and URL details for downloading the CIFAR-10 dataset
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'  # MD5 checksum to verify the file's integrity
    # List of training batches with their corresponding MD5 checksums
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb']
    ]
    # List of test batches with their corresponding MD5 checksums
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e']
    ]
    # Info map holding the label-names file ('batches.meta' in the extracted archive) and its checksum
    info_map = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888'
    }
    # Initialization method
    def __init__(self, root, train=True, apply_transformation=None, target_apply_transformation=None, download=False, split_factor=1):
        super(CIFAR10, self).__init__(root, apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation)
        self.train = train  # Whether to load the training set or test set
        self.split_factor = split_factor  # Number of times the transform is applied to each image
        # Download dataset if necessary
        if download:
            self.download()
        # Check if the dataset is already downloaded and valid
        if not self._validate_integrity():
            raise RuntimeError('Dataset not found or corrupted. Use download=True to download it.')
        # Load the dataset
        self.data, self.targets = self._load_data()
        # Load the label info map (to get class names)
        self._load_info_map()
    # Load dataset from the files
    def _load_data(self):
        data, targets = [], []  # Initialize lists to hold data and labels
        files = self.train_list if self.train else self.test_list  # Choose train or test files
        # Load each file, deserialize with pickle, and append data and labels
        for file_name, _ in files:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')  # Load file
                data.append(entry['data'])  # Append image data
                targets.extend(entry.get('labels', entry.get('fine_labels', [])))  # Append labels
        # Reshape and format the data to (num_samples, height, width, channels)
        data = np.vstack(data).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))  # Convert to HWC format
        return data, targets
    # Load label names (info map)
    def _load_info_map(self):
        info_map_path = os.path.join(self.root, self.base_folder, self.info_map['filename'])  # Path to the label-names file
        if not validate_integrity(info_map_path, self.info_map['md5']):  # Check integrity of the label-names file
            raise RuntimeError('Metadata file not found or corrupted. Use download=True to download it.')
        # Load the label names
        with open(info_map_path, 'rb') as info_map_file:
            info_map_data = pickle.load(info_map_file, encoding='latin1')  # Load label names
        self.classes = info_map_data[self.info_map['key']]  # Extract class labels
        self.class_to_idx = {label: idx for idx, label in enumerate(self.classes)}  # Map class names to indices
    # Get item (image and target) by index
    def __getitem__(self, index):
        """
        Get the item (image, target) at the specified index.

        Args:
            index (int): Index of the data.

        Returns:
            tuple: Concatenated transformed image(s) and the target class.
        """
        img, target = self.data[index], self.targets[index]  # Get image and target label
        img = Image.fromarray(img)  # Convert numpy array to PIL image
        # Apply the transform split_factor times to produce multiple views of the same image
        imgs = [self.apply_transformation(img) for _ in range(self.split_factor)] if self.apply_transformation else None
        if imgs is None:
            raise NotImplementedError('apply_transformation must be provided.')
        # Apply the target transformation if available
        if self.target_apply_transformation:
            target = self.target_apply_transformation(target)
        return torch.cat(imgs, dim=0), target  # Return the views concatenated along the channel dimension, plus the target
    # Return the number of items in the dataset
    def __len__(self):
        return len(self.data)
    # Check if the dataset files are valid and downloaded
    def _validate_integrity(self):
        files = self.train_list + self.test_list  # All files to check
        for file_name, md5 in files:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            if not validate_integrity(file_path, md5):  # Verify integrity using MD5
                return False
        return True
    # Download the dataset if it's not available
    def download(self):
        if self._validate_integrity():
            print('Files already downloaded and verified')
        else:
            fetch_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
    # Representation string to include the split type (Train/Test)
    def extra_repr(self):
        return f"Split: {'Train' if self.train else 'Test'}"
# CIFAR100 is a subclass of CIFAR10, with minor modifications
class CIFAR100(CIFAR10):
    """
    CIFAR100 Dataset, a subclass of CIFAR10.
    """
    # Directory and URL details for downloading the CIFAR-100 dataset
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'  # MD5 checksum
    # Training and test lists with their corresponding MD5 checksums for CIFAR-100
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d']
    ]
    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']
    ]
    # Info map holding the fine-label-names file ('meta' in the extracted archive) and its checksum
    info_map = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48'
    }
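# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes torchvision is installed to build the `apply_transformation`
# callable, that './data' is an acceptable download location, and that this
# file is executed as part of its package (e.g. via `python -m ...`) so the
# relative imports above resolve; all of these are assumptions made only for
# this example.
if __name__ == '__main__':
    from torchvision import transforms

    # Any callable that maps a PIL image to a 3x32x32 tensor works as the transform.
    basic_transform = transforms.Compose([transforms.ToTensor()])

    # With split_factor=2 the transform runs twice per sample and the two views
    # are concatenated along the channel dimension, giving a 6x32x32 tensor.
    dataset = CIFAR10(root='./data', train=True,
                      apply_transformation=basic_transform,
                      download=True, split_factor=2)

    image_stack, label = dataset[0]
    print(image_stack.shape)       # expected: torch.Size([6, 32, 32])
    print(dataset.classes[label])  # class name, e.g. 'frog'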