Weisen Pan 4ec0a23e73 Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
2024-09-18 18:39:43 -07:00

85 lines
3.2 KiB
Python

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch
import os
# Importing the HOME configuration
from config import HOME
class PillDataBase(Dataset):
    """Image-classification dataset for the Pill Base collection.

    Samples are listed in 'train.txt' or 'test.txt' inside '<data_dir>/pill_base'.
    Each line of those files holds an image path and an integer label separated
    by whitespace.
    """

    # Absolute prefix baked into the image paths of the original list files. It is
    # rewritten to the local dataset directory when the lists are loaded.
    _ORIGINAL_PATH_PREFIX = '/home/tung/Tung/research/Open-Pill/FACIL/data/Pill_Base_X'

    def __init__(self, data_dir=HOME + '/dataset_hub/pill_base', train=True, apply_transformation=None, split_factor=1):
        """
        Initialize the dataset and eagerly load the sample list.

        Args:
            data_dir (str): Base directory of the dataset. '/pill_base' is appended
                to it, so the list files are read from '<data_dir>/pill_base'.
            train (bool): Load 'train.txt' if True, otherwise 'test.txt'.
            apply_transformation (callable): Transformation applied to each PIL image
                (e.g. resizing, normalization, tensor conversion). Its results are
                concatenated with torch.cat, so it must return tensors. It is required
                for indexing the dataset.
            split_factor (int): Number of times the transformation is applied to each
                image. The results are concatenated along dim 0 (presumably to obtain
                several augmented views when the transformation is random).
        """
        self.train = train
        self.apply_transformation = apply_transformation
        self.split_factor = split_factor
        self.data_dir = data_dir + '/pill_base'
        self.dataset = self._load_data()

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataset)

    def _load_data(self):
        """
        Read the sample list ('train.txt' or 'test.txt') from self.data_dir.

        Each non-blank line holds an image path and an integer label separated by
        whitespace. The hard-coded original path prefix is rewritten to
        self.data_dir so the paths point at the local copy of the dataset.

        Returns:
            list: [image_path, label] pairs, with label converted to int.

        Raises:
            OSError: If the list file cannot be opened.
            ValueError: If a non-blank line lacks a path or label, or the label
                is not an integer.
        """
        dataset = []
        txt_path = os.path.join(self.data_dir, 'train.txt' if self.train else 'test.txt')
        with open(txt_path, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                if not line:
                    # Tolerate blank lines, e.g. a trailing empty line at end of file.
                    continue
                # Split on the last run of whitespace so paths containing spaces,
                # tabs or repeated separators are still parsed correctly.
                filename, label = line.rsplit(maxsplit=1)
                # Adjust the image path to the local directory structure
                filename = filename.replace(self._ORIGINAL_PATH_PREFIX, self.data_dir)
                dataset.append([filename, int(label)])
        return dataset

    def __getitem__(self, index):
        """
        Retrieve a specific sample from the dataset at the given index.

        Args:
            index (int): Index of the image and label to retrieve.

        Returns:
            tuple: (tensor of the split_factor transformed views concatenated along
            dim 0, label as a 0-d tensor).

        Raises:
            RuntimeError: If no transformation is configured or split_factor < 1,
                because there would be no tensors to concatenate.
        """
        image_path, raw_label = self.dataset[index]
        label = torch.tensor(int(raw_label))
        if not self.apply_transformation or self.split_factor < 1:
            raise RuntimeError(
                'PillDataBase needs an apply_transformation that returns tensors and '
                'split_factor >= 1 to build a sample; got apply_transformation=%r, '
                'split_factor=%r' % (self.apply_transformation, self.split_factor)
            )
        # PIL opens files lazily and keeps the handle until the image is closed;
        # the context manager releases it once all views have been produced.
        with Image.open(image_path) as image:
            images = [self.apply_transformation(image) for _ in range(self.split_factor)]
        # Concatenate all transformed views into a single tensor
        return torch.cat(images, dim=0), label
if __name__ == "__main__":
    # Minimal smoke test: build the training split and report how many samples
    # were listed. Indexing requires an apply_transformation returning tensors.
    dataset = PillDataBase()
    print(f"Loaded {len(dataset)} training samples from {dataset.data_dir}")