4ec0a23e73
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
95 lines
4.9 KiB
Python
95 lines
4.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
import os
|
|
import torch
|
|
import torch.utils.data as data
|
|
|
|
# VisionDataset is a custom dataset class inheriting from PyTorch's Dataset class.
|
|
# It handles the initialization and representation of a vision-related dataset,
|
|
# including optional transformation of input data and targets.
|
|
class VisionDataset(data.Dataset):
    """Base class for vision datasets with optional input/target transformations.

    Subclasses are expected to override ``__getitem__`` and ``__len__``; the
    base implementations raise ``NotImplementedError``.

    Transformations may be given either as one joint callable
    (``apply_transformations``) or as separate input/target callables
    (``apply_transformation`` / ``target_apply_transformation``), but not both.
    Separate callables are wrapped in a ``StandardTransform`` and stored as
    ``self.apply_transformations``.

    Args:
        root: Dataset root. A ``str`` is passed through ``os.path.expanduser``
            (so a leading ``~`` is expanded); any other value, including
            ``None`` (the default), is stored unchanged.
        apply_transformations: Optional joint callable, stored as-is.
        apply_transformation: Optional callable for the input.
        target_apply_transformation: Optional callable for the target.

    Raises:
        ValueError: If ``apply_transformations`` is combined with either of the
            separate callables.
    """

    _repr_indent = 4  # spaces prefixed to each body line produced by __repr__

    def __init__(self, root=None, apply_transformations=None, apply_transformation=None,
                 target_apply_transformation=None):
        # ``root`` defaults to None because __repr__ already treats a missing
        # root as valid (it only prints "Root location" when root is not None).
        self.root = os.path.expanduser(root) if isinstance(root, str) else root

        # The joint callable and the separate callables are mutually exclusive.
        has_apply_transformations = apply_transformations is not None
        has_separate_apply_transformation = (
            apply_transformation is not None or target_apply_transformation is not None
        )
        if has_apply_transformations and has_separate_apply_transformation:
            raise ValueError("Only one of 'apply_transformations' or 'apply_transformation/target_apply_transformation' can be provided.")

        # The individual callables stay accessible as attributes (None if unset).
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation

        # Combine separate callables into one StandardTransform, whose
        # __call__ takes (input, target), so a single attribute holds them.
        if has_separate_apply_transformation:
            apply_transformations = StandardTransform(apply_transformation, target_apply_transformation)
        self.apply_transformations = apply_transformations

    def __getitem__(self, index):
        """Return the sample at ``index``; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of samples; must be implemented by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        """Return a multi-line summary of the dataset.

        The first line is ``"Dataset <ClassName>"``; following lines (indented by
        ``_repr_indent`` spaces) give the length, the root (when not None), any
        lines from ``extra_repr()``, and the repr of ``apply_transformations``
        (when set).

        Note: this calls ``self.__len__()``, so it raises ``NotImplementedError``
        for a subclass that does not implement ``__len__``.
        """
        head = f"Dataset {self.__class__.__name__}"
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()  # subclass-specific details
        if hasattr(self, "apply_transformations") and self.apply_transformations is not None:
            body.append(repr(self.apply_transformations))
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_apply_transformation_repr(self, apply_transformation, head):
        """Format ``apply_transformation``'s repr with ``head`` prefixed.

        The first repr line is prefixed with ``head``; each later line is
        indented by ``len(head)`` spaces so the block stays aligned.
        """
        lines = apply_transformation.__repr__().splitlines()
        return [f"{head}{lines[0]}"] + [f"{' ' * len(head)}{line}" for line in lines[1:]]

    def extra_repr(self):
        """Hook for subclasses: return extra text (may be multi-line) for __repr__."""
        return ""
|
|
|
|
|
|
# StandardTransform class handles the application of the apply_transformation and target_apply_transformation
|
|
# during dataset iteration or data loading.
|
|
class StandardTransform:
    """Bundle an optional input callable with an optional target callable.

    Calling an instance with ``(input, target)`` applies each configured
    callable to its matching argument (input first, then target) and returns
    the resulting pair. A callable left as ``None`` passes its argument through
    unchanged.
    """

    def __init__(self, apply_transformation=None, target_apply_transformation=None):
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation

    def __call__(self, input, target):
        """Return ``(input, target)`` after applying whichever callables are set."""
        input_fn = self.apply_transformation
        new_input = input if input_fn is None else input_fn(input)
        target_fn = self.target_apply_transformation
        new_target = target if target_fn is None else target_fn(target)
        return new_input, new_target

    def _format_apply_transformation_repr(self, apply_transformation, head):
        """Prefix the first repr line with ``head`` and pad later lines to align."""
        repr_lines = apply_transformation.__repr__().splitlines()
        padding = ' ' * len(head)
        formatted = [f"{head}{repr_lines[0]}"]
        formatted.extend(f"{padding}{line}" for line in repr_lines[1:])
        return formatted

    def __repr__(self):
        """Return the class name followed by the repr of each configured callable."""
        labelled = (
            (self.apply_transformation, "apply_transformation: "),
            (self.target_apply_transformation, "Target apply_transformation: "),
        )
        out = [self.__class__.__name__]
        for fn, label in labelled:
            if fn is not None:
                out.extend(self._format_apply_transformation_repr(fn, label))
        return '\n'.join(out)
|