Example: 3CH VAE

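This example shows how to download the pretrained three-channel (3CH) VAE weights from the DuraMAT Datahub and load them into the model for inference.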

[1]:
import os
from pathlib import Path


def find_project_root(start: Path, marker: str = "pvcracks") -> Path:
    """Return the nearest directory at or above ``start`` that contains ``<marker>/vae``.

    That directory is the repository root from which ``pvcracks.vae`` can be imported.
    If no ancestor matches, fall back to ``start.parents[1]`` (the previous
    hard-coded guess) so existing layouts keep working.
    """
    for candidate in (start, *start.parents):
        if (candidate / marker / "vae").is_dir():
            return candidate
    return start.parents[1]


# Move to the repository root so the `pvcracks` package is importable.
# Searching upward (instead of always taking `parents[1]`) keeps this cell
# idempotent: re-running it no longer keeps climbing the directory tree.
project_root = find_project_root(Path.cwd())
os.chdir(project_root)   # now cwd is .../pvcracks

from pvcracks.vae.VAE_model_3CH import VAE
import requests
import torch

Select the torch device (GPU if available, otherwise CPU)

[2]:
# Pick the torch device: CUDA when a GPU is visible to torch, otherwise the CPU.
gpu_available = torch.cuda.is_available()
print(f"Are we using the GPU: {gpu_available}")
device = torch.device("cuda" if gpu_available else "cpu")
Are we using the GPU: True

Download the pretrained weights and load the 3CH VAE model

[3]:
from io import BytesIO

# Pretrained 3CH VAE weights hosted on the DuraMAT Datahub.
# Project page: https://datahub.duramat.org/dataset/pvcracks-trained-vae-model
url = "https://datahub.duramat.org/dataset/919a555d-dd97-46ad-b77c-ae7e8894e6c4/resource/e83785e1-ba34-4212-b519-c6535b3e6804/download/model_3ch_233_weights.pth"

# Latent size must match the checkpoint, otherwise load_state_dict fails on shape mismatch.
LATENT_DIM = 50

# Download the weights. A bounded timeout prevents the cell from hanging forever,
# and a failed download raises immediately. Previously it only printed a message,
# which left `model` undefined and caused a confusing NameError at `model.eval()`.
response = requests.get(url, timeout=60)
if response.status_code != 200:
    raise RuntimeError(f"Failed to download model. Status code: {response.status_code}")

model = VAE(latent_dim=LATENT_DIM)
# map_location lets weights that were saved from a GPU session load on a CPU-only machine.
state_dict = torch.load(BytesIO(response.content), map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)  # Move parameters to the selected device

# Put the model in evaluation mode for inference; the returned module is displayed.
model.eval()
Linear(in_features=50176, out_features=50, bias=True)
Linear(in_features=50176, out_features=50, bias=True)
[3]:
VAE(
  (encoder): Encoder(
    (conv): Sequential(
      (0): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (3): ReLU()
      (4): Conv2d(64, 128, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (5): ReLU()
      (6): Conv2d(128, 256, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (7): ReLU()
      (8): Conv2d(256, 512, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (9): ReLU()
      (10): Conv2d(512, 1024, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (11): ReLU()
    )
    (fc_mu): Linear(in_features=50176, out_features=50, bias=True)
    (fc_logvar): Linear(in_features=50176, out_features=50, bias=True)
  )
  (decoder): Decoder(
    (fc): Linear(in_features=50, out_features=50176, bias=True)
    (deconv): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): ReLU()
      (2): ConvTranspose2d(512, 256, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (3): ReLU()
      (4): ConvTranspose2d(256, 128, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), output_padding=(1, 1))
      (5): ReLU()
      (6): ConvTranspose2d(128, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), output_padding=(1, 1))
      (7): ReLU()
      (8): ConvTranspose2d(64, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), output_padding=(1, 1))
      (9): ReLU()
      (10): ConvTranspose2d(32, 3, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), output_padding=(1, 1))
      (11): Sigmoid()
    )
  )
)