4ec0a23e73
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
230 lines
9.6 KiB
Python
230 lines
9.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
# Import necessary modules
|
|
from .vision import VisionDataset # Import the base VisionDataset class
|
|
from PIL import Image # Import PIL for image loading and processing
|
|
import os # For interacting with the file system
|
|
import torch # PyTorch for tensor operations
|
|
|
|
# Function to check if a file has an allowed extension
|
|
def validate_file_extension(filename, extensions):
    """
    Check if a file has an allowed extension.

    The filename is lowercased before matching, so the check is
    case-insensitive with respect to the filename. The extensions themselves
    are compared as given and should therefore be lowercase.

    Args:
        filename (str): Path to the file.
        extensions (str or iterable of str): Extension(s) to consider (in
            lowercase). A single string, a tuple, or any other iterable such
            as a list or set is accepted.

    Returns:
        bool: True if the filename ends with one of the given extensions.
    """
    # str.endswith accepts only a str or a tuple of str, so any other
    # iterable (list, set, generator) is normalised to a tuple first.
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)
    return filename.lower().endswith(suffixes)
|
|
|
|
# Function to check if a file is an image
|
|
def is_image_file(filename):
    """
    Tell whether a path names an image, judged by its extension only.

    Args:
        filename (str): Path to the file.

    Returns:
        bool: True if the filename ends with one of IMG_EXTENSIONS.
    """
    has_image_suffix = validate_file_extension(filename, IMG_EXTENSIONS)
    return has_image_suffix
|
|
|
|
# Function to create a dataset of file paths and their corresponding class indices
|
|
def generate_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
    """
    Collect (file_path, class_index) pairs for every accepted file under the
    class subdirectories of a root directory.

    Classes are visited in sorted name order; within a class, directories and
    file names are visited in sorted order, and symbolic links are followed.
    Class names that do not correspond to a directory under the root are
    skipped.

    Args:
        directory (str): Root directory ('~' is expanded).
        class_to_idx (dict): Mapping of class names to class indices.
        extensions (tuple, optional): Allowed file extensions.
        is_valid_file (callable, optional): Predicate taking a file path and
            returning True for files that should be included.

    Returns:
        list: A list of (file_path, class_index) tuples.

    Raises:
        ValueError: If neither or both of 'extensions' and 'is_valid_file'
            are given.
    """
    root_dir = os.path.expanduser(directory)

    # Exactly one of the two filters must be supplied.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Specify either 'extensions' or 'is_valid_file', but not both.")

    if is_valid_file is None:
        def accept(candidate):
            return validate_file_extension(candidate, extensions)
    else:
        accept = is_valid_file

    samples = []
    for class_name in sorted(class_to_idx):
        class_dir = os.path.join(root_dir, class_name)
        if os.path.isdir(class_dir):
            label = class_to_idx[class_name]
            for folder, _, file_names in sorted(os.walk(class_dir, followlinks=True)):
                candidates = (os.path.join(folder, name) for name in sorted(file_names))
                samples.extend((found, label) for found in candidates if accept(found))

    return samples
|
|
|
|
# DatasetFolder class: Generic data loader for samples arranged in subdirectories by class
|
|
class DatasetFolder(VisionDataset):
    """
    A generic data loader where samples are arranged in subdirectories by class.

    Each immediate subdirectory of ``root`` is treated as one class, and every
    accepted file found beneath it becomes one sample of that class.

    Args:
        root (str): Root directory path.
        loader (callable): Function to load a sample from its file path.
        extensions (tuple[str], optional): Allowed file extensions. Exactly
            one of ``extensions`` and ``is_valid_file`` must be given.
        apply_transformation (callable, optional): Transformation applied to
            each loaded sample. Required for item access.
        target_apply_transformation (callable, optional): Transformation
            applied to each target.
        is_valid_file (callable, optional): Function to validate files.
        split_factor (int, optional): Number of times the sample
            transformation is applied to each loaded sample; the results are
            concatenated along dimension 0.

    Attributes:
        classes (list): Sorted list of class names.
        class_to_idx (dict): Mapping of class names to class indices.
        samples (list): List of (sample_path, class_index) tuples.
        targets (list): List of class indices corresponding to each sample.

    Raises:
        RuntimeError: If no accepted file is found under ``root``.
    """

    def __init__(self, root, loader, extensions=None, apply_transformation=None,
                 target_apply_transformation=None, is_valid_file=None, split_factor=1):
        super().__init__(root, apply_transformation=apply_transformation,
                         target_apply_transformation=target_apply_transformation)
        self.classes, self.class_to_idx = self._discover_classes(self.root)
        self.samples = generate_dataset(self.root, self.class_to_idx, extensions, is_valid_file)

        if not self.samples:
            message = f"Found 0 files in subfolders of: {self.root}. "
            # 'extensions' is None when files are filtered by 'is_valid_file';
            # joining None would mask this error with a TypeError.
            if extensions is not None:
                message += f"Supported extensions are: {','.join(extensions)}"
            raise RuntimeError(message)

        self.loader = loader
        self.extensions = extensions
        self.targets = [class_index for _, class_index in self.samples]
        self.split_factor = split_factor

    def _discover_classes(self, dir):
        """
        Discover class subdirectories in the root directory.

        Args:
            dir (str): Root directory.

        Returns:
            tuple: (classes, class_to_idx) where classes is the sorted list of
                subdirectory names of 'dir', and class_to_idx maps each class
                name to its position in that list.
        """
        # The context manager closes the scandir iterator deterministically.
        with os.scandir(dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {name: position for position, name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Retrieve a sample and its target by index.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: (sample, target), where sample is the concatenation along
                dimension 0 of 'split_factor' transformed versions of the
                loaded file, and target is the (optionally transformed)
                class index.

        Raises:
            NotImplementedError: If no sample transformation was configured.
        """
        path, target = self.samples[index]

        # Without a transformation there is nothing to concatenate, so fail
        # explicitly instead of handing a non-sequence to torch.cat.
        if not self.apply_transformation:
            raise NotImplementedError(
                "DatasetFolder requires 'apply_transformation' to build samples."
            )

        sample = self.loader(path)

        # Each call may yield a different result if the transformation is
        # randomised; outputs are assumed to be tensors (they go to torch.cat).
        imgs = [self.apply_transformation(sample) for _ in range(self.split_factor)]

        if self.target_apply_transformation:
            target = self.target_apply_transformation(target)

        return torch.cat(imgs, dim=0), target

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)
|
|
|
|
# List of supported image file extensions
|
|
# Lowercase suffixes: filenames are lowercased before being matched against them.
IMG_EXTENSIONS = (
    '.jpg', '.jpeg', '.png', '.ppm', '.bmp',
    '.pgm', '.tif', '.tiff', '.webp',
)
|
|
|
|
# Function to load an image using PIL
|
|
def load_image_pil(path):
    """
    Open an image file with PIL and return it converted to RGB.

    Args:
        path (str): Path to the image.

    Returns:
        Image: RGB image.
    """
    with open(path, 'rb') as stream:
        # Convert before leaving the 'with' block so the file is still open
        # while the image data is being decoded.
        return Image.open(stream).convert('RGB')
|
|
|
|
# Function to load an image using accimage library with fallback to PIL
|
|
def load_accimage(path):
    """
    Load an image with accimage, using the PIL loader if accimage raises IOError.

    Args:
        path (str): Path to the image.

    Returns:
        Image: Image loaded with accimage or PIL.
    """
    import accimage  # imported lazily so the module loads without accimage installed

    try:
        image = accimage.Image(path)
    except IOError:
        # accimage could not read the file; let PIL try instead.
        image = load_image_pil(path)
    return image
|
|
|
|
# Function to load an image using the default backend (accimage or PIL)
|
|
def basic_loader(path):
    """
    Load an image with whichever backend torchvision is configured to use.

    Args:
        path (str): Path to the image.

    Returns:
        Image: Loaded image (accimage when that backend is selected, PIL otherwise).
    """
    from torchvision import get_image_backend

    if get_image_backend() == 'accimage':
        return load_accimage(path)
    return load_image_pil(path)
|
|
|
|
# ImageFolder class: A dataset loader for images arranged in subdirectories by class
|
|
class ImageFolder(DatasetFolder):
|
|
"""
|
|
A dataset loader for images arranged in subdirectories by class.
|
|
|
|
Args:
|
|
root (str): Root directory path.
|
|
apply_transformation (callable, optional): apply_transformation applied to each image.
|
|
target_apply_transformation (callable, optional): apply_transformation applied to each target.
|
|
loader (callable, optional): Function to load an image from its path.
|
|
is_valid_file (callable, optional): Function to validate files.
|
|
|
|
Attributes:
|
|
classes (list): Sorted list of class names.
|
|
class_to_idx (dict): Mapping of class names to class indices.
|
|
imgs (list): List of (image_path, class_index) tuples.
|
|
"""
|
|
|
|
def __init__(self, root, apply_transformation=None, target_apply_transformation=None, loader=basic_loader, is_valid_file=None, split_factor=1):
|
|
super().__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
|
|
apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation,
|
|
is_valid_file=is
|