use-case-and-architecture/EdgeFLite/data_collection/dataset_imagenet.py

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import warnings
from contextlib import contextmanager
import os
import shutil
import tempfile
import torch
from .folder import ImageFolder
from .utils import validate_integrity, extract_archive, verify_str_arg
# Dictionary that maps the dataset split (train/val/devkit) to its corresponding archive filename and checksum (md5 hash)
ARCHIVE_info_map = {
    'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
    'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
    'devkit': ('ILSVRC2012_devkit_t12.tar', 'fa75699e90414af021442c21a62c3abf'),
}

# File name where the information map (class info, wnid, etc.) is stored
info_map_FILE = "info_map.bin"
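
# A sketch of the expected on-disk layout under `root`, inferred from the archive
# map above and the preparation helpers below (archive names follow the official
# ILSVRC2012 release; the exact layout is an illustrative assumption):
#
#   <root>/
#       ILSVRC2012_img_train.tar       # placed manually by the user
#       ILSVRC2012_img_val.tar
#       ILSVRC2012_devkit_t12.tar
#       info_map.bin                   # written by extract_devkit_archive()
#       train/<wnid>/*.JPEG            # created by process_train_archive()
#       val/<wnid>/*.JPEG              # created by process_validation_archive()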


class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Args:
        root (str): Root directory of the ImageNet dataset.
        split (str, optional): Dataset split, either ``train`` or ``val``.
        apply_transformation (callable, optional): A function/transformation to apply to the PIL image.
        target_apply_transformation (callable, optional): A function/transformation to apply to the target.
        loader (callable, optional): Function to load an image from its path.

    Attributes:
        classes (list): List of class-name tuples.
        class_to_idx (dict): Mapping of class names to indices.
        wnids (list): List of WordNet IDs.
        wnid_to_idx (dict): Mapping of WordNet IDs to class indices.
        imgs (list): List of (image path, class index) tuples.
        targets (list): Class index of each image in the dataset.
    """

    def __init__(self, root, split='train', download=None, **kwargs):
        # The dataset is no longer publicly accessible, so download=True is
        # rejected and download=False only triggers a deprecation warning
        if download is True:
            raise RuntimeError("The dataset is no longer publicly accessible. Please download "
                               "the archives externally and place them in the root directory.")
        elif download is False:
            warnings.warn("The download flag is deprecated, as the dataset is no longer "
                          "publicly accessible.", RuntimeWarning)

        # Expand the root directory path
        root = self.root = os.path.expanduser(root)
        # Validate the dataset split (must be either 'train' or 'val')
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # Parse the dataset archives (train/val/devkit) and prepare the dataset
        self.extract_archives()
        # Load the WordNet-ID-to-class mapping from the info_map file
        wnid_to_classes = load_information_map_file(self.root)[0]

        # Initialize the ImageFolder with the split folder (train/val directory)
        super().__init__(self.divide_folder_contents, **kwargs)

        # Set class-related attributes
        self.root = root
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        # Replace WordNet IDs with human-readable class-name tuples and rebuild
        # class_to_idx so that every name in a tuple maps to its class index
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss}
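
    # Illustrative result of the remapping above, using the standard ILSVRC2012
    # metadata (the concrete names here are an assumption about the devkit contents):
    #
    #   dataset.classes[0]                  -> ('tench', 'Tinca tinca')
    #   dataset.class_to_idx['tench']       -> 0
    #   dataset.class_to_idx['Tinca tinca'] -> 0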

    def extract_archives(self):
        # If the info_map file is missing or invalid, parse the devkit archive to create it
        if not validate_integrity(os.path.join(self.root, info_map_FILE)):
            extract_devkit_archive(self.root)

        # If the split folder (train/val) does not exist, extract the corresponding archive
        if not os.path.isdir(self.divide_folder_contents):
            if self.split == 'train':
                process_train_archive(self.root)
            elif self.split == 'val':
                process_validation_archive(self.root)

    @property
    def divide_folder_contents(self):
        # Path of the folder containing the images for the current split (train/val)
        return os.path.join(self.root, self.split)

    def extra_repr(self):
        # Extra representation of the dataset object, showing the split
        return f"Split: {self.split}"

def load_information_map_file(root, file=None):
    # Load the info_map file from the root directory
    file = os.path.join(root, file or info_map_FILE)
    if validate_integrity(file):
        return torch.load(file)
    else:
        raise RuntimeError(f"The info_map file {file} is either missing or corrupted. "
                           f"Please ensure it exists in the root directory.")

def _validate_archive_file(root, file, md5):
    # Verify that the archive file is present and that its checksum matches
    if not validate_integrity(os.path.join(root, file), md5):
        raise RuntimeError(f"The archive {file} is either missing or corrupted. "
                           f"Please download it and place it in {root}.")

def extract_devkit_archive(root, file=None):
    """Extract and process the ImageNet 2012 devkit archive to generate the info_map data.

    Args:
        root (str): Root directory containing the devkit archive.
        file (str, optional): Archive filename. Defaults to 'ILSVRC2012_devkit_t12.tar'.
    """
    import scipy.io as sio

    # Parse info_map.mat from the devkit, which contains class and WordNet ID information
    # (note: the official ILSVRC2012 devkit ships this file as 'data/meta.mat')
    def read_info_map_mat_file(devkit_root):
        info_map_path = os.path.join(devkit_root, "data", "info_map.mat")
        info_map = sio.loadmat(info_map_path, squeeze_me=True)['synsets']
        # Keep only leaf synsets, i.e. entries with no child synsets
        num_children_column = list(zip(*info_map))[4]
        info_map = [info_map[idx] for idx, num_children in enumerate(num_children_column) if num_children == 0]
        idcs, wnids, classes = list(zip(*info_map))[:3]
        classes = [tuple(clss.split(', ')) for clss in classes]
        return ({idx: wnid for idx, wnid in zip(idcs, wnids)},
                {wnid: clss for wnid, clss in zip(wnids, classes)})

    # Parse the validation ground-truth file for the image class labels
    def process_val_groundtruth_txt(devkit_root):
        file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
        with open(file) as f:
            return [int(line.strip()) for line in f]

    # Context manager providing a temporary directory for archive extraction
    @contextmanager
    def get_tmp_dir():
        tmp_dir = tempfile.mkdtemp()
        try:
            yield tmp_dir
        finally:
            shutil.rmtree(tmp_dir)

    # Extract and process the devkit archive
    archive_name, md5 = ARCHIVE_info_map["devkit"]
    if file is None:
        file = archive_name
    _validate_archive_file(root, file, md5)
    with get_tmp_dir() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)
        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = read_info_map_mat_file(devkit_root)
        val_idcs = process_val_groundtruth_txt(devkit_root)
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

        # Save the mappings to the info_map file
        torch.save((wnid_to_classes, val_wnids), os.path.join(root, info_map_FILE))

def process_train_archive(root, file=None, folder="train"):
    """Extract and organize the ImageNet 2012 train dataset.

    Args:
        root (str): Root directory containing the train dataset archive.
        file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_train.tar'.
        folder (str, optional): Destination folder. Defaults to 'train'.
    """
    archive_name, md5 = ARCHIVE_info_map["train"]
    if file is None:
        file = archive_name
    _validate_archive_file(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    # The train archive contains one inner tar per class (named by WordNet ID);
    # extract each into its own folder, deleting the inner archive afterwards
    for archive in os.listdir(train_root):
        archive_path = os.path.join(train_root, archive)
        extract_archive(archive_path, os.path.splitext(archive_path)[0], remove_finished=True)
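
# After process_train_archive() runs, the train folder should look roughly like
# this (the WordNet IDs and file names below are illustrative):
#
#   train/
#       n01440764/
#           n01440764_10026.JPEG
#           ...
#       n01443537/
#           ...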

def process_validation_archive(root, file=None, wnids=None, folder="val"):
    """Extract and organize the ImageNet 2012 validation dataset.

    Args:
        root (str): Root directory containing the validation dataset archive.
        file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_val.tar'.
        wnids (list, optional): WordNet IDs of the validation images. Defaults to None
            (loaded from the info_map file).
        folder (str, optional): Destination folder. Defaults to 'val'.
    """
    archive_name, md5 = ARCHIVE_info_map["val"]
    if file is None:
        file = archive_name
    if wnids is None:
        wnids = load_information_map_file(root)[1]
    _validate_archive_file(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    # Collect the extracted image files before creating the class folders, so the
    # newly created directories do not show up in the listing
    images = sorted(os.path.join(val_root, image) for image in os.listdir(val_root))

    # Create a directory for each WordNet ID (class) and move every validation
    # image into the folder of its class
    for wnid in set(wnids):
        os.mkdir(os.path.join(val_root, wnid))
    for wnid, img_file in zip(wnids, images):
        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))
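
# A minimal end-to-end preparation sketch; the root path below is an
# illustrative assumption:
#
#   if __name__ == "__main__":
#       root = os.path.expanduser("~/data/imagenet")
#       extract_devkit_archive(root)        # writes info_map.bin
#       process_train_archive(root)         # builds train/<wnid>/ folders
#       process_validation_archive(root)    # builds val/<wnid>/ folders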