import torch.nn as nn
from termcolor import colored
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
[docs]class Encoder(nn.Module):
def __init__(self, latent_dim):
super(Encoder, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=7, stride=2, dilation=1, padding=3), # Modified for 3CH
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=7, stride=2, dilation=1, padding=3),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=7, stride=2, dilation=1, padding=3),
nn.ReLU(),
nn.Conv2d(128, 256, kernel_size=7, stride=2, dilation=1, padding=3),
nn.ReLU(),
nn.Conv2d(256, 512, kernel_size=7, stride=2, dilation=1, padding=3),
nn.ReLU(),
nn.Conv2d(512, 1024, kernel_size=7, stride=2, dilation=1, padding=3),
nn.ReLU(),
)
self.fc_mu = nn.Linear(1024 * 7 * 7, latent_dim)
self.fc_logvar = nn.Linear(1024 * 7 * 7, latent_dim)
print(colored(self.fc_mu, 'green'))
print(colored(self.fc_logvar, 'blue'))
[docs] def forward(self, x):
x = self.conv(x)
print(colored(x.size(), 'red'))
x = x.view(x.size(0), -1)
mu = self.fc_mu(x)
logvar = self.fc_logvar(x)
return mu, logvar
[docs]class Decoder(nn.Module):
def __init__(self, latent_dim):
super(Decoder, self).__init__()
self.fc = nn.Linear(latent_dim, 1024 * 7 * 7)
self.deconv = nn.Sequential(
nn.ConvTranspose2d(1024, 512, kernel_size=7, stride=2, dilation=1, padding=3, output_padding=0),
nn.ReLU(),
nn.ConvTranspose2d(512, 256, kernel_size=7, stride=2, dilation=1, padding=3, output_padding=0),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=7, stride=2, dilation=1, padding=3, output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, kernel_size=7, stride=2, dilation=1, padding=3, output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, kernel_size=7, stride=2, dilation=1, padding=3, output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(32, 3, kernel_size=7, stride=2, dilation=1, padding=3, output_padding=1), # Modified for 3CH
nn.Sigmoid()
)
[docs] def forward(self, x):
x = self.fc(x)
x = x.view(x.size(0), 1024, 7, 7)
x_deconv = self.deconv(x)
return x_deconv
[docs]class VAE(nn.Module):
def __init__(self, latent_dim):
super(VAE, self).__init__()
self.encoder = Encoder(latent_dim)
self.decoder = Decoder(latent_dim)
[docs] def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar).to(device)
eps = torch.randn_like(std).to(device)
return mu + eps * std
[docs] def forward(self, x):
mu, logvar = self.encoder(x)
z = self.reparameterize(mu, logvar)
x_recon = self.decoder(z)
return x_recon, mu, logvar