import os
from pathlib import Path
import cv2 as cv
import numpy as np
import torch
import torchvision.transforms.functional as F
from imutils.paths import list_files
from PIL import Image
from torchvision import transforms
from torchvision.datasets.vision import VisionDataset
class SolarDataset(VisionDataset):
    """A dataset for our solar panel images and masks.

    Images and masks are paired purely by position: both folders are listed,
    sorted, and then shuffled with the same permutation. A matching image and
    mask are therefore expected to share the same base file name.

    Note:
        Construction calls `np.random.seed(random_seed)`, which reseeds
        NumPy's *global* random state as a side effect.
    """

    def __init__(
        self, root, image_folder, mask_folder, transforms, mode="train", random_seed=42
    ):
        """Set up the dataset by collecting and shuffling image and mask paths.

        Args:
            root (str or Path): Root directory that contains the data folders.
            image_folder (str or Path): Subdirectory with image files.
            mask_folder (str or Path): Subdirectory with mask files stored as numpy arrays.
            transforms (callable): Callable applied to `(image, mask)` during `__getitem__`.
            mode (str, optional): Dataset split indicator; kept for backward
                compatibility and not used by this class.
            random_seed (int, optional): Seed used when shuffling to keep pairs aligned.

        Raises:
            OSError: If the image or mask folder does not exist.
        """
        super().__init__(root, transforms)

        self.image_path = Path(self.root) / image_folder
        self.mask_path = Path(self.root) / mask_folder

        if not os.path.exists(self.image_path):
            raise OSError(f"{self.image_path} not found.")
        if not os.path.exists(self.mask_path):
            raise OSError(f"{self.mask_path} not found.")

        self.image_list = np.array(self._list_sorted_files(self.image_path))
        self.mask_list = np.array(self._list_sorted_files(self.mask_path))

        # One permutation is applied to both arrays so that the image at
        # position i still lines up with the mask at position i.
        np.random.seed(random_seed)
        index = np.arange(len(self.image_list))
        np.random.shuffle(index)
        self.image_list = self.image_list[index]
        self.mask_list = self.mask_list[index]

    @staticmethod
    def _list_sorted_files(folder):
        """Return the sorted file paths under `folder`, skipping Jupyter checkpoint files."""
        return sorted(
            path for path in list_files(folder) if ".ipynb_checkpoints" not in path
        )

    @staticmethod
    def _base_name(path):
        """Return the file name of `path` without directory and extension."""
        return os.path.splitext(os.path.split(path)[-1])[0]

    def __len__(self):
        """Return the total number of samples available."""
        return len(self.image_list)

    def get_img_path(self, index):
        """Return the path to the image file at `index`.

        Args:
            index (int): Dataset index referencing the desired image.

        Returns:
            str: Path to the image file, as produced by listing `image_path`.
        """
        return self.image_list[index]

    def __get_mask_path__(self, index):
        """Return the path to the mask file at `index`.

        Args:
            index (int): Dataset index referencing the desired mask.

        Returns:
            str: Path to the mask file, as produced by listing `mask_path`.
        """
        return self.mask_list[index]

    def __getname__(self, index):
        """Return the name of the image and mask at the given index.

        Args:
            index (int): The index of the image and mask.

        Returns:
            str: The shared base name (no directory, no extension) of the pair.

        Raises:
            IndexError: If the filenames of the image and mask do not match,
                or if `index` is out of range.
        """
        image_name = self._base_name(self.image_list[index])
        mask_name = self._base_name(self.mask_list[index])
        if image_name != mask_name:
            # Raise (not return) the error: a returned exception object is
            # truthy and would be mistaken for a valid name by callers.
            raise IndexError("Image and mask names do not match.")
        return image_name

    def __getraw__(self, index) -> tuple[Image.Image, np.ndarray]:
        """Load the raw PIL image and numpy mask for the given index.

        Args:
            index (int): Dataset index referencing the desired sample.

        Returns:
            tuple: `(PIL.Image.Image, numpy.ndarray)` for the image and mask.

        Raises:
            ValueError: If the filenames of the image and mask do not match.
            IndexError: If `index` is out of range.
        """
        # Index first so an out-of-range index surfaces as IndexError (which
        # also terminates plain `for sample in dataset` iteration), and only
        # a genuine name mismatch is reported as ValueError.
        image_file = self.image_list[index]
        mask_file = self.mask_list[index]
        if self._base_name(image_file) != self._base_name(mask_file):
            raise ValueError(
                "{}: Image doesn't match with mask".format(
                    os.path.split(image_file)[-1]
                )
            )

        image = Image.open(image_file)
        # SECURITY NOTE: allow_pickle=True lets np.load execute pickled
        # objects; only load mask files from trusted sources.
        mask = np.load(mask_file, allow_pickle=True)
        return image, mask

    def __getitem__(self, index):
        """Load and transform the sample identified by `index`.

        Args:
            index (int): Dataset index referencing the desired sample.

        Returns:
            tuple: Transformed `(image, mask)` pair, as returned by `self.transforms`.
        """
        image, mask = self.__getraw__(index)
        image, mask = self.transforms(image, mask)
        return image, mask
class Compose:
    """Chain several paired `(image, target)` transforms into one callable."""

    def __init__(self, transforms):
        """Store a sequence of paired image/mask transforms.

        Args:
            transforms (Iterable[callable]): Transform callables accepting
                `(image, target)` and returning a `(image, target)` pair.
        """
        # Materialise once: a one-shot iterable (e.g. a generator) would be
        # exhausted after the first call and silently skip every transform
        # for all later samples.
        self.transforms = list(transforms)

    def __call__(self, image, target):
        """Sequentially apply all stored transforms to the `(image, target)` pair.

        Args:
            image: Input image passed to each transform.
            target: Segmentation mask passed to each transform.

        Returns:
            tuple: The transformed `(image, target)` pair. With no transforms
            stored, the inputs are returned unchanged.
        """
        for t in self.transforms:
            image, target = t(image, target)
        return image, target
# NOTE: Modified to accept numpy array input for the mask.
class FixResize:
    """Resize PIL images and numpy masks to a fixed square size. This is for single-channel masks."""

    # UNet requires input size to be multiple of 16
    def __init__(self, size):
        """Store the target square size for resizing operations.

        Args:
            size (int): Desired height and width after resizing. Should be a
                multiple of 16 (not validated here).
        """
        self.size = size

    def __call__(self, image, target):
        """Resize inputs while respecting their data types.

        Args:
            image (PIL.Image.Image): Input image to be resized with bilinear interpolation.
            target (numpy.ndarray): Segmentation mask resized with nearest neighbor.

        Returns:
            tuple: `(image, target)` resized to `(size, size)`.
        """
        output_size = (self.size, self.size)
        image = F.resize(
            image,
            output_size,
            interpolation=transforms.InterpolationMode.BILINEAR,
        )
        # Nearest neighbour keeps mask values as existing labels instead of
        # blending them; the named flag replaces the former magic value 0.
        target = cv.resize(target, output_size, interpolation=cv.INTER_NEAREST)
        return image, target
class ChanneledFixResize:
    """Resize images and multi-channel masks to a common square shape."""

    def __init__(self, size):
        """Store the target square size for per-channel resizing.

        Args:
            size (int): Desired height and width after resizing.
        """
        self.size = size

    def __call__(self, image, target):
        """Resize an image and handle 2D or 3D numpy masks appropriately.

        Args:
            image (PIL.Image.Image): Image to be resized with bilinear interpolation.
            target (numpy.ndarray): Mask array that may be single-channel
                (2D) or multi-channel with channels on the first axis.

        Returns:
            tuple: `(image, target)` resized to `(size, size)`; a 3D target
            keeps its leading channel axis.
        """
        output_size = (self.size, self.size)

        # The image is assumed to be a PIL image and goes through torchvision.
        image = F.resize(
            image,
            output_size,
            interpolation=transforms.InterpolationMode.BILINEAR,
        )

        if isinstance(target, np.ndarray) and target.ndim == 3:
            # Channels-first mask: resize each channel on its own (nearest
            # neighbour, as appropriate for segmentation labels) and restack
            # along axis 0.
            target_resized = np.stack(
                [
                    cv.resize(channel, output_size, interpolation=cv.INTER_NEAREST)
                    for channel in target
                ],
                axis=0,
            )
        else:
            # Otherwise the target is assumed to be a 2D numpy array.
            target_resized = cv.resize(
                target, output_size, interpolation=cv.INTER_NEAREST
            )
        return image, target_resized
class ToTensor:
    """Transform the image to tensor. Scale the image to [0,1] float32.

    Transform the mask to tensor.
    """

    def __call__(self, image, target):
        """Convert the image and mask to PyTorch tensors.

        Args:
            image (PIL.Image.Image): Input image converted with torchvision's
                `to_tensor` (the function behind `transforms.ToTensor`).
            target (numpy.ndarray or PIL.Image.Image): Segmentation mask to convert.

        Returns:
            tuple: `(torch.Tensor, torch.Tensor)`; the mask tensor has dtype
            `torch.int64`.
        """
        # Functional form: same conversion as transforms.ToTensor()(image)
        # without building a transform object on every call.
        image = F.to_tensor(image)
        target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target
class PILToTensor:
    """Transform the image to tensor. Keep raw type."""

    def __call__(self, image, target):
        """Convert the image to a tensor without scaling pixel intensities.

        Args:
            image (PIL.Image.Image): Input image passed to torchvision's
                `pil_to_tensor`.
            target (numpy.ndarray or PIL.Image.Image): Segmentation mask to convert.

        Returns:
            tuple: `(torch.Tensor, torch.Tensor)`; the mask tensor has dtype
            `torch.int64`.
        """
        image = F.pil_to_tensor(image)
        target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target
class Normalize:
    """Normalize the image tensor with per-channel mean and std; pass the target through."""

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Store channel statistics used to normalize image tensors.

        Args:
            mean (tuple): Per-channel mean values.
            std (tuple): Per-channel standard deviations.
        """
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        """Normalize the image tensor and leave the target unchanged.

        Args:
            image (torch.Tensor): Image tensor passed to torchvision's
                functional `normalize`.
            target: Segmentation mask, returned as-is.

        Returns:
            tuple: `(normalized_image, target)`.
        """
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target