use-case-and-architecture/EdgeFLite/data_collection/directory_utils.py
Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

230 lines
9.6 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
# Import necessary modules
from .vision import VisionDataset # Import the base VisionDataset class
from PIL import Image # Import PIL for image loading and processing
import os # For interacting with the file system
import torch # PyTorch for tensor operations
# Function to check if a file has an allowed extension
def validate_file_extension(filename, extensions):
    """
    Check if a file has an allowed extension.

    The comparison is case-insensitive: both the filename and the extensions
    are lower-cased before matching.

    Args:
        filename (str): Path to the file.
        extensions (str or iterable of str): A single extension (e.g. '.jpg')
            or a collection of extensions (tuple, list, set, ...).

    Returns:
        bool: True if the filename ends with one of the given extensions.
    """
    if isinstance(extensions, str):
        suffixes = extensions.lower()
    else:
        # str.endswith only accepts a str or a tuple, so normalise lists/sets/
        # other iterables into a lower-cased tuple before matching.
        suffixes = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(suffixes)
# Function to check if a file is an image
def is_image_file(filename):
    """
    Tell whether a path looks like an image, judged purely by its extension.

    Args:
        filename (str): Path (or bare file name) to inspect.

    Returns:
        bool: True when the name ends with an entry of IMG_EXTENSIONS
        (the filename is lower-cased before the comparison).
    """
    return validate_file_extension(filename=filename, extensions=IMG_EXTENSIONS)
# Function to create a dataset of file paths and their corresponding class indices
def generate_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
    """
    Collect (file_path, class_index) pairs for every accepted file below each class folder.

    Exactly one of ``extensions`` / ``is_valid_file`` must be supplied; it decides
    which files are kept. Classes are visited in sorted name order, and folders
    (followed through symlinks) and files are visited in sorted order, so the
    result is deterministic.

    Args:
        directory (str): Root directory; a leading '~' is expanded.
        class_to_idx (dict): Maps class folder names to integer class indices.
        extensions (tuple, optional): Extensions a file must end with to be kept.
        is_valid_file (callable, optional): Predicate taking a path and returning bool.

    Returns:
        list: (file_path, class_index) tuples. Classes whose folder does not exist
        are silently skipped.

    Raises:
        ValueError: If neither or both of ``extensions`` and ``is_valid_file`` are given.
    """
    base = os.path.expanduser(directory)

    has_extensions = extensions is not None
    has_predicate = is_valid_file is not None
    # Both given, or neither given, is an error.
    if has_extensions == has_predicate:
        raise ValueError("Specify either 'extensions' or 'is_valid_file', but not both.")

    if has_extensions:
        def accept(candidate):
            return validate_file_extension(candidate, extensions)
    else:
        accept = is_valid_file

    found = []
    for name in sorted(class_to_idx):
        label = class_to_idx[name]
        class_dir = os.path.join(base, name)
        if not os.path.isdir(class_dir):
            continue
        for walk_root, _subdirs, files in sorted(os.walk(class_dir, followlinks=True)):
            candidates = (os.path.join(walk_root, file_name) for file_name in sorted(files))
            found.extend((candidate, label) for candidate in candidates if accept(candidate))
    return found
# DatasetFolder class: Generic data loader for samples arranged in subdirectories by class
class DatasetFolder(VisionDataset):
    """
    A generic data loader where samples are arranged in subdirectories by class
    (``root/<class_name>/.../<file>``).

    Args:
        root (str): Root directory path.
        loader (callable): Function to load a sample from its file path.
        extensions (tuple[str], optional): Allowed file extensions. Exactly one of
            ``extensions`` and ``is_valid_file`` must be given.
        apply_transformation (callable, optional): Transformation applied to each
            loaded sample. ``__getitem__`` requires it, because its outputs are
            concatenated with ``torch.cat``.
        target_apply_transformation (callable, optional): Transformation applied to each target.
        is_valid_file (callable, optional): Predicate deciding whether a file path is kept.
        split_factor (int, optional): Number of times ``apply_transformation`` is applied
            to each sample; the results are concatenated along dimension 0.

    Attributes:
        classes (list): Sorted list of class names.
        class_to_idx (dict): Mapping of class names to class indices.
        samples (list): List of (sample_path, class_index) tuples.
        targets (list): List of class indices corresponding to each sample.

    Raises:
        RuntimeError: If no valid files are found under ``root``.
    """

    def __init__(self, root, loader, extensions=None, apply_transformation=None,
                 target_apply_transformation=None, is_valid_file=None, split_factor=1):
        super().__init__(root, apply_transformation=apply_transformation,
                         target_apply_transformation=target_apply_transformation)
        # NOTE(review): self.root / self.apply_transformation /
        # self.target_apply_transformation are presumably set by VisionDataset — confirm.
        self.classes, self.class_to_idx = self._discover_classes(self.root)
        self.samples = generate_dataset(self.root, self.class_to_idx, extensions, is_valid_file)
        if not self.samples:
            message = f"Found 0 files in subfolders of: {self.root}."
            # extensions is None when a custom is_valid_file is used; joining None
            # would raise a TypeError and hide the real problem.
            if extensions is not None:
                message += f" Supported extensions are: {','.join(extensions)}"
            raise RuntimeError(message)
        self.loader = loader  # Function to load a sample from its path
        self.extensions = extensions  # Allowed file extensions (None if is_valid_file was used)
        self.targets = [s[1] for s in self.samples]  # Class index of each sample
        self.split_factor = split_factor  # Number of transformed copies per sample

    def _discover_classes(self, dir):
        """
        Discover class subdirectories in the root directory.

        Args:
            dir (str): Root directory.

        Returns:
            tuple: (classes, class_to_idx), where classes is the sorted list of
            immediate subdirectory names of ``dir`` and class_to_idx maps each
            name to its position in that list.
        """
        # Use scandir as a context manager so its directory handle is released promptly.
        with os.scandir(dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Retrieve a sample and its target by index.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: (sample, target), where sample is the ``split_factor``
            transformed copies of the loaded file concatenated with
            ``torch.cat(..., dim=0)`` and target is the (optionally transformed)
            class index.

        Raises:
            NotImplementedError: If no ``apply_transformation`` was configured.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if not self.apply_transformation:
            # torch.cat needs tensors; an untransformed sample cannot be concatenated,
            # so fail explicitly instead of letting torch.cat raise a confusing TypeError.
            raise NotImplementedError(
                "DatasetFolder requires 'apply_transformation' to convert samples to tensors.")
        # Assumes apply_transformation returns tensors with compatible shapes — TODO confirm.
        imgs = [self.apply_transformation(sample) for _ in range(self.split_factor)]
        if self.target_apply_transformation:
            target = self.target_apply_transformation(target)
        return torch.cat(imgs, dim=0), target

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.samples)
# List of supported image file extensions
# Lower-case suffixes recognised as images; order is preserved for error messages.
IMG_EXTENSIONS = (
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
    '.tiff',
    '.webp',
)
# Function to load an image using PIL
def load_image_pil(path):
    """
    Read an image file with PIL and return it converted to RGB mode.

    The file is opened in binary mode and closed once the RGB copy is produced.

    Args:
        path (str): Path to the image file.

    Returns:
        Image: The image in RGB mode.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
# Function to load an image using accimage library with fallback to PIL
def load_accimage(path):
    """
    Load an image via accimage, using the PIL loader if accimage raises IOError.

    Note that the accimage import itself is not guarded: if the package is
    missing, ImportError propagates to the caller.

    Args:
        path (str): Path to the image file.

    Returns:
        Image: An accimage image, or a PIL RGB image from the fallback path.
    """
    import accimage  # faster decoder, imported lazily so it stays optional
    try:
        image = accimage.Image(path)
    except IOError:
        image = load_image_pil(path)
    return image
# Function to load an image using the default backend (accimage or PIL)
def basic_loader(path):
    """
    Load an image with whichever backend torchvision is configured to use.

    Args:
        path (str): Path to the image file.

    Returns:
        Image: Result of load_accimage when torchvision's backend is 'accimage',
        otherwise the result of load_image_pil.
    """
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return load_accimage(path)
    return load_image_pil(path)
# ImageFolder class: A dataset loader for images arranged in subdirectories by class
class ImageFolder(DatasetFolder):
"""
A dataset loader for images arranged in subdirectories by class.
Args:
root (str): Root directory path.
apply_transformation (callable, optional): apply_transformation applied to each image.
target_apply_transformation (callable, optional): apply_transformation applied to each target.
loader (callable, optional): Function to load an image from its path.
is_valid_file (callable, optional): Function to validate files.
Attributes:
classes (list): Sorted list of class names.
class_to_idx (dict): Mapping of class names to class indices.
imgs (list): List of (image_path, class_index) tuples.
"""
def __init__(self, root, apply_transformation=None, target_apply_transformation=None, loader=basic_loader, is_valid_file=None, split_factor=1):
super().__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation,
is_valid_file=is