85 lines
3.2 KiB
Python
85 lines
3.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Author: Weisen Pan
|
|
|
|
from PIL import Image
|
|
from torch.utils.data import DataLoader, Dataset
|
|
import torch
|
|
import os
|
|
|
|
# Importing the HOME configuration
|
|
from config import HOME
|
|
|
|
class PillDataBase(Dataset):
    """Pill image dataset driven by a ``train.txt`` / ``test.txt`` index file.

    Every non-blank line of the index file holds an image path followed by an
    integer label. Image paths are rewritten so that the prefix stored in the
    index file points at the local data directory instead.
    """

    # Path prefix that appears inside the index files; it is replaced with
    # ``self.data_dir`` when the index is loaded.
    _INDEX_PATH_PREFIX = '/home/tung/Tung/research/Open-Pill/FACIL/data/Pill_Base_X'

    def __init__(self, data_dir=HOME + '/dataset_hub/pill_base', train=True, apply_transformation=None, split_factor=1):
        """
        Initialize the dataset and eagerly load its index file.

        Args:
            data_dir (str): Parent directory of the dataset. The files are read
                from the ``pill_base`` sub-directory of this path.
            train (bool): If True read ``train.txt``, otherwise ``test.txt``.
            apply_transformation (callable): Transformation applied to each
                opened image (e.g., resizing, normalization). It must return a
                tensor, because the results are concatenated with ``torch.cat``.
                Samples cannot be fetched while this is None.
            split_factor (int): Number of times the transformation is applied
                to each image; the results are concatenated along dimension 0.

        Raises:
            OSError: If the index file cannot be opened.
            ValueError: If an index line is malformed.
        """
        super().__init__()
        self.train = train
        self.apply_transformation = apply_transformation
        self.split_factor = split_factor
        # The dataset files live in a nested 'pill_base' directory under data_dir.
        self.data_dir = data_dir + '/pill_base'
        self.dataset = self._load_data()

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataset)

    def _load_data(self):
        """
        Read the index file (``train.txt`` or ``test.txt``) for this split.

        Blank lines are skipped. The label is taken as the last
        whitespace-separated token of a line, so image paths that contain
        spaces and tab-separated files are parsed correctly.

        Returns:
            list: ``[image_path, label]`` pairs, with ``label`` as an int.

        Raises:
            OSError: If the index file cannot be opened.
            ValueError: If a line has no label or the label is not an integer.
        """
        index_name = 'train.txt' if self.train else 'test.txt'
        txt_path = os.path.join(self.data_dir, index_name)

        dataset = []
        with open(txt_path, 'r') as file:
            for line in file:
                line = line.strip()
                if not line:
                    # Tolerate empty lines (e.g. a trailing blank line).
                    continue
                # Split from the right so only the final token is the label.
                filename, label = line.rsplit(None, 1)
                # Point the stored path at the local copy of the dataset.
                filename = filename.replace(self._INDEX_PATH_PREFIX, self.data_dir)
                dataset.append([filename, int(label)])

        return dataset

    def __getitem__(self, index):
        """
        Retrieve the sample stored at the given index.

        The image is opened once and the transformation is applied to it
        ``split_factor`` times; the resulting tensors are concatenated along
        dimension 0.

        Args:
            index (int): Index of the image and label to retrieve.

        Returns:
            tuple: ``(images, label)`` where ``images`` is the concatenated
            tensor of transformed images and ``label`` is a scalar tensor.

        Raises:
            RuntimeError: If no transformation was configured or
                ``split_factor`` is smaller than 1, since there would be
                nothing to concatenate.
        """
        # Fail early with an actionable message instead of letting
        # torch.cat() choke on an empty list further down.
        if self.apply_transformation is None or self.split_factor < 1:
            raise RuntimeError(
                "PillDataBase needs an apply_transformation that returns tensors "
                "and a split_factor >= 1 to build a sample."
            )

        image_path, label = self.dataset[index]

        # The context manager releases the underlying file handle as soon as
        # all transformed copies have been produced.
        with Image.open(image_path) as image:
            images = [self.apply_transformation(image) for _ in range(self.split_factor)]

        return torch.cat(images, dim=0), torch.tensor(int(label))
|
|
|
|
if __name__ == "__main__":
    # Smoke check when run as a script: build the training split with the
    # default data directory, which also loads its index file.
    dataset = PillDataBase()
|