# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Import necessary libraries
from PIL import Image  # For image handling
import os              # For file path operations
import numpy as np     # For numerical operations
import pickle          # For loading serialized data
import torch           # For PyTorch operations

# Import custom classes and functions from the current package
from .vision import VisionDataset
from .utils import validate_integrity, fetch_and_extract_archive


# CIFAR10 dataset class
class CIFAR10(VisionDataset):
    """
    CIFAR-10 dataset class that handles loading, processing, and transforming the CIFAR-10 data.

    Args:
        root (str): Directory where the dataset is stored or will be downloaded to.
        train (bool, optional): If True, load the training set. Otherwise, load the test set.
        apply_transformation (callable, optional): A function/transform that takes a PIL image and returns a transformed version.
        target_apply_transformation (callable, optional): A function/transform that takes the target and transforms it.
        download (bool, optional): If True, download the dataset if it is not found locally.
        split_factor (int, optional): Number of transformed views produced per image. Default is 1.
    """

    # Directory and URL details for downloading the CIFAR-10 dataset
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'  # MD5 checksum to verify the file's integrity

    # List of training batches with their corresponding MD5 checksums
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb']
    ]

    # List of test batches with their corresponding MD5 checksums
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e']
    ]

    # Info map to hold the label-names file, its dict key, and its checksum
    info_map = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888'
    }

    # Initialization method
    def __init__(self, root, train=True, apply_transformation=None, target_apply_transformation=None, download=False, split_factor=1):
        super(CIFAR10, self).__init__(root, apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation)
        self.train = train  # Whether to load the training set or the test set
        self.split_factor = split_factor  # Number of transformed views to produce per image

        # Download the dataset if requested
        if download:
            self.download()

        # Check that the dataset is present and uncorrupted
        if not self._validate_integrity():
            raise RuntimeError('Dataset not found or corrupted. Use download=True to download it.')

        # Load the dataset
        self.data, self.targets = self._load_data()

        # Load the label info map (to get class names)
        self._load_info_map()

    # Load dataset from the batch files
    def _load_data(self):
        data, targets = [], []  # Initialize lists to hold image data and labels
        files = self.train_list if self.train else self.test_list  # Choose train or test files

        # Load each file, deserialize it with pickle, and collect data and labels
        for file_name, _ in files:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')  # Load the batch
                data.append(entry['data'])  # Append image data
                targets.extend(entry.get('labels', entry.get('fine_labels', [])))  # Append labels

        # Reshape and format the data to (num_samples, height, width, channels)
        data = np.vstack(data).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))  # Convert to HWC format
        return data, targets
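
    # Note on the raw layout: each pickled CIFAR batch stores 'data' as a uint8 array
    # of shape (N, 3072), where every row is 1024 red, then 1024 green, then 1024 blue
    # pixel values. That is why _load_data() reshapes to (-1, 3, 32, 32) and transposes
    # to HWC before the images are handed to PIL.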

    # Load label names (info map)
    def _load_info_map(self):
        info_map_path = os.path.join(self.root, self.base_folder, self.info_map['filename'])  # Path to the info map file
        if not validate_integrity(info_map_path, self.info_map['md5']):  # Check integrity of the info map
            raise RuntimeError('Metadata file not found or corrupted. Use download=True to download it.')

        # Load the label names
        with open(info_map_path, 'rb') as info_map_file:
            info_map_data = pickle.load(info_map_file, encoding='latin1')  # Load label names
            self.classes = info_map_data[self.info_map['key']]  # Extract class labels
            self.class_to_idx = {label: idx for idx, label in enumerate(self.classes)}  # Map class names to indices

    # Get item (image and target) by index
    def __getitem__(self, index):
        """
        Get the item (image, target) at the specified index.

        Args:
            index (int): Index of the data.

        Returns:
            tuple: Transformed image(s) and the target class.
        """
        img, target = self.data[index], self.targets[index]  # Get image and target label
        img = Image.fromarray(img)  # Convert numpy array to PIL image

        # Apply the transformation multiple times based on split_factor
        imgs = [self.apply_transformation(img) for _ in range(self.split_factor)] if self.apply_transformation else None
        if imgs is None:
            raise NotImplementedError('apply_transformation must be provided.')

        # Apply the target transformation if available
        if self.target_apply_transformation:
            target = self.target_apply_transformation(target)

        return torch.cat(imgs, dim=0), target  # Return the concatenated transformed images and the target
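
    # Note (assumption): if apply_transformation returns a (3, 32, 32) tensor per view
    # (e.g. a torchvision ToTensor-style pipeline), torch.cat along dim=0 yields an
    # output of shape (3 * split_factor, 32, 32); callers presumably split it back
    # into individual views along the channel dimension.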

    # Return the number of items in the dataset
    def __len__(self):
        return len(self.data)

    # Check that the dataset files are present and valid
    def _validate_integrity(self):
        files = self.train_list + self.test_list  # All files to check
        for file_name, md5 in files:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            if not validate_integrity(file_path, md5):  # Verify integrity using MD5
                return False
        return True

    # Download the dataset if it's not available
    def download(self):
        if self._validate_integrity():
            print('Files already downloaded and verified')
        else:
            fetch_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    # Representation string to include the split type (Train/Test)
    def extra_repr(self):
        return f"Split: {'Train' if self.train else 'Test'}"


# CIFAR100 is a subclass of CIFAR10, with minor modifications
class CIFAR100(CIFAR10):
    """
    CIFAR-100 Dataset, a subclass of CIFAR10.
    """

    # Directory and URL details for downloading the CIFAR-100 dataset
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'  # MD5 checksum

    # Training and test lists with their corresponding MD5 checksums for CIFAR-100
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d']
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']
    ]

    # Info map to hold fine label names and their checksum
    info_map = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48'
    }
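

# Example usage -- a minimal sketch, not part of the original module. It assumes this
# file lives inside a torchvision-style package (hence the relative imports above) and
# that torchvision transforms are available; the import path and parameter values below
# are illustrative assumptions only.
#
#   import torchvision.transforms as transforms
#   from torch.utils.data import DataLoader
#
#   transform = transforms.Compose([
#       transforms.RandomCrop(32, padding=4),
#       transforms.RandomHorizontalFlip(),
#       transforms.ToTensor(),
#   ])
#   dataset = CIFAR10(root='./data', train=True, apply_transformation=transform,
#                     download=True, split_factor=2)
#   loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
#   images, labels = next(iter(loader))
#   # images: shape (128, 3 * split_factor, 32, 32); labels: shape (128,)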