use-case-and-architecture/EdgeFLite/data_collection/vision_utils.py
Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

95 lines
4.9 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import os
import torch
import torch.utils.data as data
# VisionDataset is a custom dataset class inheriting from PyTorch's Dataset class.
# It handles the initialization and representation of a vision-related dataset,
# including optional transformation of input data and targets.
class VisionDataset(data.Dataset):
    """Base class for vision datasets with optional input/target transformations.

    Subclasses must implement ``__getitem__`` and ``__len__``; the versions
    here only raise ``NotImplementedError``. The class itself stores the root
    location and the transformation callables, and builds a readable
    multi-line ``repr``.
    """

    _repr_indent = 4  # Number of spaces prefixed to each body line in __repr__

    def __init__(self, root, apply_transformations=None, apply_transformation=None, target_apply_transformation=None):
        """Store the root location and the configured transformations.

        Args:
            root: Dataset location. A ``str`` has "~" expanded via
                ``os.path.expanduser``; any other value is stored unchanged.
            apply_transformations: Optional joint callable stored as
                ``self.apply_transformations``.
            apply_transformation: Optional callable for the input only.
            target_apply_transformation: Optional callable for the target only.

        Raises:
            ValueError: If ``apply_transformations`` is given together with
                ``apply_transformation`` and/or ``target_apply_transformation``.
        """
        self.root = os.path.expanduser(root) if isinstance(root, str) else root

        # The joint form and the separate form are mutually exclusive.
        has_apply_transformations = apply_transformations is not None
        has_separate_apply_transformation = (
            apply_transformation is not None or target_apply_transformation is not None
        )
        if has_apply_transformations and has_separate_apply_transformation:
            raise ValueError("Only one of 'apply_transformations' or 'apply_transformation/target_apply_transformation' can be provided.")

        # The separate callables are always kept as attributes, even when None.
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation

        # Separate callables are wrapped into one joint StandardTransform so
        # that self.apply_transformations is the single joint entry point.
        if has_separate_apply_transformation:
            apply_transformations = StandardTransform(apply_transformation, target_apply_transformation)
        self.apply_transformations = apply_transformations

    def __getitem__(self, index):
        """Return the item at ``index``; must be overridden by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of datapoints; must be overridden by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        """Return a multi-line summary: class name, size, root, extras, transformations."""
        head = f"Dataset {self.__class__.__name__}"
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        # Subclass-specific lines supplied through the extra_repr() hook.
        body += self.extra_repr().splitlines()
        # hasattr guard: a subclass may not have run this class's __init__.
        if hasattr(self, "apply_transformations") and self.apply_transformations is not None:
            body.append(repr(self.apply_transformations))
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_apply_transformation_repr(self, apply_transformation, head):
        """Format a transformation's repr as lines, the first prefixed with ``head``.

        Continuation lines are padded with spaces to the width of ``head`` so
        a multi-line repr stays aligned under its first line.

        Args:
            apply_transformation: Object whose ``repr`` is formatted.
            head: Label placed in front of the first line.

        Returns:
            A non-empty list of strings.
        """
        # repr() (not obj.__repr__()) so class objects such as ``int`` work;
        # the fallback keeps an empty repr from raising IndexError below.
        lines = repr(apply_transformation).splitlines() or [""]
        padding = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

    def extra_repr(self):
        """Hook for subclasses: extra text (may be multi-line) shown in ``__repr__``."""
        return ""
# StandardTransform class handles the application of the apply_transformation and target_apply_transformation
# during dataset iteration or data loading.
class StandardTransform:
    """Pair an input transformation with a target transformation.

    Calling the instance with ``(input, target)`` applies each configured
    callable to its own argument and returns the resulting pair. A callable
    left as ``None`` means that argument is passed through unchanged.
    """

    def __init__(self, apply_transformation=None, target_apply_transformation=None):
        """Store the optional input and target transformation callables."""
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation

    def __call__(self, input, target):
        """Apply the configured transformations and return ``(input, target)``.

        Args:
            input: Value passed to ``apply_transformation`` when it is set.
            target: Value passed to ``target_apply_transformation`` when it is set.

        Returns:
            A 2-tuple of the (possibly transformed) input and target.
        """
        # ``input`` shadows the builtin, but the name is part of the call
        # interface (keyword callers), so it is kept.
        if self.apply_transformation is not None:
            input = self.apply_transformation(input)
        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)
        return input, target

    def _format_apply_transformation_repr(self, apply_transformation, head):
        """Format a transformation's repr as lines, the first prefixed with ``head``.

        Continuation lines are padded with spaces to the width of ``head`` so
        a multi-line repr stays aligned under its first line.

        Args:
            apply_transformation: Object whose ``repr`` is formatted.
            head: Label placed in front of the first line.

        Returns:
            A non-empty list of strings.
        """
        # repr() (not obj.__repr__()) so class objects such as ``int`` used
        # as a transformation work; the fallback keeps an empty repr from
        # raising IndexError below.
        lines = repr(apply_transformation).splitlines() or [""]
        padding = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

    def __repr__(self):
        """Return the class name followed by one labelled entry per set transformation."""
        body = [self.__class__.__name__]
        if self.apply_transformation is not None:
            body += self._format_apply_transformation_repr(self.apply_transformation, "apply_transformation: ")
        if self.target_apply_transformation is not None:
            body += self._format_apply_transformation_repr(self.target_apply_transformation, "Target apply_transformation: ")
        return '\n'.join(body)