# Source code for pvcracks.vae.custom_dataset

# -*- coding: utf-8 -*-
"""
Created on Wed Apr  5 12:11:11 2023

@author: jlbraid
"""

import logging

import torch
from torch.utils.data import Dataset
from torchvision import transforms


logger = logging.getLogger(__name__)


class CustomDataset(Dataset):
    """Map-style dataset that serves images as tensors with one channel axis.

    Each item is read from ``data`` by index, optionally passed through
    ``transform``, and then reshaped so that all size-1 axes are removed and
    a single leading axis is added back.

    Parameters
    ----------
    data : sequence
        Indexable collection supporting ``len()``.  When no ``transform`` is
        given, every item must be convertible with ``torch.as_tensor``
        (for example a NumPy array or a tensor).
    transform : callable, optional
        Called with the raw item.  Its result must be a ``torch.Tensor``
        because ``squeeze``/``unsqueeze`` are applied to it afterwards.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of items in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return item ``idx`` as a tensor with a leading channel axis.

        With a ``transform`` the raw item is handed to it unchanged and the
        dtype of its result is kept.  Without one, the item is converted to
        a ``float32`` tensor.

        Shapes are reported through the module logger at DEBUG level rather
        than printed, so iterating the dataset does not write to stdout.
        """
        image = self.data[idx]
        # The raw item may be a non-array object when a transform handles the
        # conversion, so do not assume it exposes ``.shape``.
        logger.debug(
            "Shape of image from data: %s", getattr(image, "shape", None)
        )

        if self.transform is not None:
            image = self.transform(image)
            logger.debug("Shape of image after transformation: %s", image.shape)
        else:
            # as_tensor matches from_numpy for ndarrays (no copy) and also
            # accepts items that are already tensors.
            image = torch.as_tensor(image).float()
            logger.debug(
                "Shape of image after conversion to tensor: %s", image.shape
            )

        # Drop every size-1 axis, then add one channel axis in front; a 2-D
        # image of shape (H, W) therefore comes out as (1, H, W).
        image = image.squeeze().unsqueeze(0)
        logger.debug(
            "Shape of image after adding channel dimension: %s", image.shape
        )
        return image
class FixedRotation:
    """Transform that rotates its input by a fixed angle.

    Parameters
    ----------
    angle : number
        Rotation angle, stored as given and passed unchanged to
        ``torchvision.transforms.functional.rotate`` (which documents it as
        degrees, counter-clockwise).
    """

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, x):
        """Return ``x`` rotated by ``self.angle``."""
        return transforms.functional.rotate(x, self.angle)

    def __repr__(self):
        # Show the configured angle when the transform is printed.
        return f"{type(self).__name__}(angle={self.angle!r})"
class FixedHorizontalFlip:
    """Transform that always reverses a tensor along its last axis.

    For tensors laid out as ``(..., H, W)`` this mirrors the image left to
    right.  The layout is an assumption about the caller's data; the code
    itself only fixes the axis index ``-1``.
    """

    def __call__(self, x):
        """Return ``x`` with the order of its last axis reversed."""
        return torch.flip(x, [-1])

    def __repr__(self):
        return f"{type(self).__name__}()"
class FixedVerticalFlip:
    """Transform that always reverses a tensor along its second-to-last axis.

    For tensors laid out as ``(..., H, W)`` this mirrors the image top to
    bottom.  The layout is an assumption about the caller's data; the code
    itself only fixes the axis index ``-2``.
    """

    def __call__(self, x):
        """Return ``x`` with the order of its second-to-last axis reversed."""
        return torch.flip(x, [-2])

    def __repr__(self):
        return f"{type(self).__name__}()"