4ec0a23e73
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
195 lines
8.5 KiB
Python
195 lines
8.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import warnings
|
|
from contextlib import contextmanager
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
import torch
|
|
from .folder import ImageFolder
|
|
from .utils import validate_integrity, extract_archive, verify_str_arg
|
|
|
|
# Archive file name and MD5 checksum for each ImageNet 2012 component,
# keyed by component name: train images, val images, and the devkit.
ARCHIVE_info_map = dict(
    train=('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
    val=('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
    devkit=('ILSVRC2012_devkit_t12.tar', 'fa75699e90414af021442c21a62c3abf'),
)

# Name of the metadata cache kept in the dataset root; extract_devkit_archive
# writes a torch-serialized (wnid_to_classes, val_wnids) tuple to it.
info_map_FILE = "info_map.bin"
|
|
|
|
class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    The archives listed in ``ARCHIVE_info_map`` must already sit in ``root``.
    On construction, the devkit metadata cache (``info_map_FILE``) and the
    ``root/<split>`` image folder are created from them if they are missing.

    Args:
        root (str): Root directory of the ImageNet Dataset (``~`` is expanded).
        split (str, optional): Dataset split, either ``train`` or ``val``.
        download (bool, optional): Deprecated. ``True`` raises
            ``RuntimeError``; ``False`` only emits a ``RuntimeWarning``.
        **kwargs: Forwarded unchanged to ``ImageFolder``. Presumably these
            include the following (confirm against ImageFolder):
            apply_transformation (callable, optional): A function/transform to
                apply to the PIL image.
            target_apply_transformation (callable, optional): A
                function/transform to apply to the target.
            loader (callable, optional): Function to load an image from its path.

    Attributes:
        classes (list): List of class name tuples.
        class_to_idx (dict): Mapping of class names to indices.
        wnids (list): List of WordNet IDs.
        wnid_to_idx (dict): Mapping of WordNet IDs to class indices.
        imgs (list): List of image path and class index tuples.
        targets (list): Class index values for each image in the dataset.
    """

    def __init__(self, root, split='train', download=None, **kwargs):
        # The archives are no longer publicly hosted, so downloading is refused.
        if download is True:
            raise RuntimeError("The dataset is no longer publicly accessible. Please download archives externally and place them in the root directory.")
        if download is False:
            warnings.warn("The download flag is deprecated, as the dataset is no longer publicly accessible.", RuntimeWarning)

        expanded_root = os.path.expanduser(root)
        self.root = expanded_root
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # Create the metadata cache and/or split folder if they are missing.
        self.extract_archives()
        cached_meta = load_information_map_file(self.root)
        wnid_to_names = cached_meta[0]

        # ImageFolder scans root/<split>. It also overwrites self.root with
        # that folder, hence the reassignment right after.
        super().__init__(self.divide_folder_contents, **kwargs)
        self.root = expanded_root

        # The folder-derived classes are WordNet IDs; keep them, then expose
        # human-readable name tuples instead.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_names[wnid] for wnid in self.wnids]
        # Every synonym in a class's name tuple maps to that class's index.
        self.class_to_idx = {
            name: index
            for index, names in enumerate(self.classes)
            for name in names
        }

    def extract_archives(self):
        """Build the metadata cache and unpack the current split if needed."""
        cache_path = os.path.join(self.root, info_map_FILE)
        if not validate_integrity(cache_path):
            extract_devkit_archive(self.root)

        if os.path.isdir(self.divide_folder_contents):
            return
        extractors = {
            'train': process_train_archive,
            'val': process_validation_archive,
        }
        extractor = extractors.get(self.split)
        if extractor is not None:
            extractor(self.root)

    @property
    def divide_folder_contents(self):
        """Path of the image folder for the current split: ``root/<split>``."""
        split_dir = os.path.join(self.root, self.split)
        return split_dir

    def extra_repr(self):
        """Extra line shown in the dataset's repr, naming the split."""
        split_name = self.split
        return f"Split: {split_name}"
|
|
|
|
def load_information_map_file(root, file=None):
    """Load the cached ImageNet metadata from ``root``.

    Args:
        root (str): Directory that holds the metadata file.
        file (str, optional): File name inside ``root``. Any falsy value
            (``None``, ``""``) falls back to ``info_map_FILE``.

    Returns:
        The object deserialized by ``torch.load``. Normally this is the
        ``(wnid_to_classes, val_wnids)`` tuple saved by ``extract_devkit_archive``.

    Raises:
        RuntimeError: If the file is missing or fails ``validate_integrity``.
    """
    path = os.path.join(root, file if file else info_map_FILE)
    if not validate_integrity(path):
        raise RuntimeError(f"The info_map file {path} is either missing or corrupted. Please ensure it exists in the root directory.")
    # NOTE: torch.load unpickles its input; only point this at files the
    # dataset code itself wrote.
    return torch.load(path)
|
|
|
|
def _validate_archive_file(root, file, md5):
    """Ensure ``root/file`` exists and matches ``md5``.

    Raises:
        RuntimeError: If the archive is absent or its checksum differs.
    """
    archive_path = os.path.join(root, file)
    if validate_integrity(archive_path, md5):
        return
    raise RuntimeError(f"The archive {file} is either missing or corrupted. Please download it and place it in {root}.")
|
|
|
|
def extract_devkit_archive(root, file=None):
    """Extract and process the ImageNet 2012 devkit archive to generate info_map information.

    The archive is unpacked into a temporary directory. Class and WordNet-ID
    information is read from ``data/meta.mat``, and validation labels from
    ``data/ILSVRC2012_validation_ground_truth.txt``. The result is saved as a
    ``(wnid_to_classes, val_wnids)`` tuple to ``root/info_map_FILE`` via
    ``torch.save``.

    Args:
        root (str): Root directory with the devkit archive.
        file (str, optional): Archive filename. Defaults to
            'ILSVRC2012_devkit_t12.tar'. The official devkit MD5 is checked
            regardless of the name.

    Raises:
        RuntimeError: If the archive is missing or fails the MD5 check.
    """
    import scipy.io as sio

    def read_info_map_mat_file(devkit_root):
        # The official devkit stores its synset table in data/meta.mat.
        metafile = os.path.join(devkit_root, "data", "meta.mat")
        synsets = sio.loadmat(metafile, squeeze_me=True)['synsets']
        # Transpose records into columns. Column 4 is num_children, and only
        # leaf synsets (no children) are kept as classes.
        nums_children = list(zip(*synsets))[4]
        leaves = [synsets[idx] for idx, num_children in enumerate(nums_children)
                  if num_children == 0]
        # zip() returns an iterator in Python 3; materialize before slicing.
        idcs, wnids, classes = list(zip(*leaves))[:3]
        classes = [tuple(clss.split(', ')) for clss in classes]
        return dict(zip(idcs, wnids)), dict(zip(wnids, classes))

    def process_val_groundtruth_txt(devkit_root):
        # One integer class id per line, one line per validation image.
        gt_path = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
        with open(gt_path) as f:
            return [int(line.strip()) for line in f]

    @contextmanager
    def get_tmp_dir():
        # Scratch directory for extraction, removed even if processing fails.
        tmp_dir = tempfile.mkdtemp()
        try:
            yield tmp_dir
        finally:
            shutil.rmtree(tmp_dir)

    default_file, md5 = ARCHIVE_info_map["devkit"]
    # Honor a caller-supplied archive name instead of discarding it.
    if file is None:
        file = default_file
    _validate_archive_file(root, file, md5)

    with get_tmp_dir() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)
        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = read_info_map_mat_file(devkit_root)
        val_idcs = process_val_groundtruth_txt(devkit_root)
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

    # Cache the mappings so later runs skip the devkit entirely.
    torch.save((wnid_to_classes, val_wnids), os.path.join(root, info_map_FILE))
|
|
|
|
def process_train_archive(root, file=None, folder="train"):
    """Extract and organize the ImageNet 2012 train dataset.

    The outer archive is unpacked into ``root/folder``. It holds one archive
    per class. Each of those is unpacked into a directory inside
    ``root/folder`` named after the archive without its extension. The
    archives are presumably deleted afterwards via ``remove_finished=True``.

    Args:
        root (str): Root directory containing the train dataset archive.
        file (str, optional): Archive filename. Defaults to
            'ILSVRC2012_img_train.tar'. The official train MD5 is checked
            regardless of the name.
        folder (str, optional): Destination folder. Defaults to 'train'.

    Raises:
        RuntimeError: If the archive is missing or fails the MD5 check.
    """
    default_file, md5 = ARCHIVE_info_map["train"]
    # Honor a caller-supplied archive name instead of discarding it.
    if file is None:
        file = default_file
    _validate_archive_file(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    # Use full paths so each class archive is unpacked inside train_root,
    # not relative to the current working directory.
    archives = [os.path.join(train_root, name) for name in os.listdir(train_root)]
    for archive in archives:
        extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
|
|
|
|
def process_validation_archive(root, file=None, wnids=None, folder="val"):
|
|
"""Extract and organize the ImageNet 2012 validation dataset.
|
|
|
|
Args:
|
|
root (str): Root directory containing the validation dataset archive.
|
|
file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_val.tar'.
|
|
wnids (list, optional): WordNet IDs for validation images. Defaults to None (loaded from info_map file).
|
|
folder (str, optional): Destination folder. Defaults to 'val'.
|
|
"""
|
|
file, md5 = ARCHIVE_info_map["val"]
|
|
if wnids is None:
|
|
wnids = load_information_map_file(root)[1]
|
|
|
|
_validate_archive_file(root, file, md5)
|
|
|
|
val_root = os.path.join(root, folder)
|
|
extract_archive(os.path.join(root, file), val_root)
|
|
|
|
# Create directories for each WordNet ID (class) and move validation images into their respective folders
|
|
for wnid in set(wnids):
|
|
os.mkdir(os.path.join(val_root, wnid))
|
|
|
|
for wnid, img in zip(wnids, sorted(os
|