{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0a74ca2a-28a3-462a-8ae5-0440c0e862f2",
   "metadata": {},
   "source": [
    "# Example: 3CH VAE\n",
    "This example shows how to download the pretrained weights of the 3-channel (3CH) VAE from the DuraMAT Datahub and load them into a `VAE` model ready for inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "70eb0dfe-56e1-4b65-873b-40eda763070c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from io import BytesIO\n",
    "from pathlib import Path\n",
    "\n",
    "import requests\n",
    "import torch\n",
    "\n",
    "# Assumes the kernel starts in this notebook's directory, two levels below the repository root.\n",
    "# Resolve the root only once per kernel so re-running this cell doesn't keep climbing the tree.\n",
    "if \"project_root\" not in globals():\n",
    "    project_root = Path.cwd().parents[1]\n",
    "os.chdir(project_root)  # now cwd is .../pvcracks\n",
    "\n",
    "# Imported after the chdir, presumably so the local `pvcracks` package resolves from the repository root.\n",
    "from pvcracks.vae.VAE_model_3CH import VAE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df6f264c-6ea9-4e19-bdb2-88b80b172d53",
   "metadata": {},
   "source": [
    "### Select the torch device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "df747df8-83e5-4477-8eac-ded80659da10",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Are we using the GPU: True\n"
     ]
    }
   ],
   "source": [
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Are we using the GPU: {device.type == 'cuda'}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbaba4ee-be35-4eed-8e65-37952296cc17",
   "metadata": {},
   "source": [
    "## Load the pretrained 3CH VAE model\n",
    "The weights are downloaded into memory and loaded onto the device selected above; the final `model.eval()` switches the model to evaluation mode and displays its architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "026e6287-4bc7-4a4f-a496-9d4331d35d5d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32mLinear(in_features=50176, out_features=50, bias=True)\u001b[0m\n",
      "\u001b[34mLinear(in_features=50176, out_features=50, bias=True)\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "VAE(\n",
       "  (encoder): Encoder(\n",
       "    (conv): Sequential(\n",
       "      (0): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
       "      (1): ReLU()\n",
       "      (2): Conv2d(32, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
       "      (3): ReLU()\n",
       "      (4): Conv2d(64, 128, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
       "      (5): ReLU()\n",
       "      (6): Conv2d(128, 256, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
       "      (7): ReLU()\n",
       "      (8): Conv2d(256, 512, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
       "      (9): ReLU()\n",
       "      (10): Conv2d(512, 1024, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
       "      (11): ReLU()\n",
       "    )\n",
       "    (fc_mu): Linear(in_features=50176, out_features=50, bias=True)\n",
       "    (fc_logvar): Linear(in_features=50176, out_features=50, bias=True)\n",
       "  )\n",
       "  (decoder): Decoder(\n",
       "    (fc): Linear(in_features=50, out_features=50176, bias=True)\n",
       "    (deconv): Sequential(\n",
       "      (0): ConvTranspose2d(1024, 512, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
       "      (1): ReLU()\n",
       "      (2): ConvTranspose2d(512, 256, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
       "      (3): ReLU()\n",
       "      (4): ConvTranspose2d(256, 128, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), output_padding=(1, 1))\n",
       "      (5): ReLU()\n",
       "      (6): ConvTranspose2d(128, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), output_padding=(1, 1))\n",
       "      (7): ReLU()\n",
       "      (8): ConvTranspose2d(64, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), output_padding=(1, 1))\n",
       "      (9): ReLU()\n",
       "      (10): ConvTranspose2d(32, 3, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), output_padding=(1, 1))\n",
       "      (11): Sigmoid()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pretrained weights hosted on the DuraMAT Datahub.\n",
    "# Project page: https://datahub.duramat.org/dataset/pvcracks-trained-vae-model\n",
    "WEIGHTS_URL = \"https://datahub.duramat.org/dataset/919a555d-dd97-46ad-b77c-ae7e8894e6c4/resource/e83785e1-ba34-4212-b519-c6535b3e6804/download/model_3ch_233_weights.pth\"\n",
    "LATENT_DIM = 50  # must match the checkpoint: fc_mu / fc_logvar output 50 features\n",
    "\n",
    "# Download the weights into memory. `timeout` bounds connect/read stalls (seconds), not total download time.\n",
    "response = requests.get(WEIGHTS_URL, timeout=60)\n",
    "if response.status_code != 200:\n",
    "    # Stop here: continuing would hit an undefined (or stale) `model` below.\n",
    "    raise RuntimeError(f\"Failed to download model. Status code: {response.status_code}\")\n",
    "\n",
    "model = VAE(latent_dim=LATENT_DIM)\n",
    "# map_location remaps tensors saved on a GPU so the checkpoint also loads on CPU-only machines.\n",
    "state_dict = torch.load(BytesIO(response.content), map_location=device, weights_only=True)\n",
    "model.load_state_dict(state_dict)\n",
    "model.to(device)\n",
    "\n",
    "# Evaluation mode for inference; the returned module is displayed as the cell output.\n",
    "model.eval()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pyhpc_torcha3",
   "language": "python",
   "name": "pyhpc_torcha3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}