use-case-and-architecture/EdgeFLite/data_collection/data_cutout.py

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
import torch
import numpy as np


class Cutout:
    """Applies random cutout augmentation by masking patches in an image.

    This technique randomly cuts out square patches from the image to
    augment the dataset, helping the model become invariant to occlusions.

    Args:
        n_holes (int): Number of patches to remove from the image.
        length (int): Side length (in pixels) of each square patch.
    """

    def __init__(self, n_holes, length):
        """
        Initializes the Cutout class with the number of patches to be removed
        and the size of each patch.

        Args:
            n_holes (int): Number of patches (holes) to cut out from the image.
            length (int): Size of each square patch.
        """
        self.n_holes = n_holes  # Number of holes (patches) to remove.
        self.length = length  # Side length of each square patch.

    def __call__(self, img):
        """
        Applies the cutout augmentation on the input image.

        Args:
            img (Tensor): The input image tensor with shape (C, H, W),
                where C is the number of channels, H is the height,
                and W is the width of the image.

        Returns:
            Tensor: The augmented image tensor with `n_holes` patches of size
                `length x length` cut out, filled with zeros.
        """
        # Get the height and width of the image (ignoring the channel dimension).
        height, width = img.size(1), img.size(2)

        # Create a mask initialized with ones, same height and width as the image
        # (each pixel is set to 1, representing no masking initially).
        mask = np.ones((height, width), dtype=np.float32)

        # Randomly remove `n_holes` patches from the image.
        for _ in range(self.n_holes):
            # Randomly choose the center of a patch (x_center, y_center).
            y_center = np.random.randint(height)
            x_center = np.random.randint(width)

            # Define the coordinates of the patch based on the center
            # and clip them so the patch stays within the image boundaries.
            y1 = np.clip(y_center - self.length // 2, 0, height)
            y2 = np.clip(y_center + self.length // 2, 0, height)
            x1 = np.clip(x_center - self.length // 2, 0, width)
            x2 = np.clip(x_center + self.length // 2, 0, width)

            # Set the mask to 0 for the patch (mark the patch as cut out).
            mask[y1:y2, x1:x2] = 0.0

        # Convert the mask from a NumPy array to a PyTorch tensor and broadcast
        # it across the channel dimension so it matches the image shape.
        mask_tensor = torch.from_numpy(mask).expand_as(img)

        # Multiply the input image by the mask (zero out the selected patches).
        return img * mask_tensor
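

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original project code): it
# assumes a CIFAR-sized (3, 32, 32) float tensor and shows that Cutout
# preserves the input shape while zeroing out one square patch. In a real
# training pipeline the transform would typically be composed after
# transforms.ToTensor() in a torchvision transform chain.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_img = torch.rand(3, 32, 32)  # synthetic image, values in [0, 1)
    augmented = Cutout(n_holes=1, length=16)(example_img)
    assert augmented.shape == example_img.shape

    # Count pixels zeroed across all channels, i.e. pixels inside the cutout patch.
    masked_pixels = int((augmented == 0).all(dim=0).sum().item())
    print(f"Masked {masked_pixels} of {32 * 32} pixels")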