import torch.nn as nn
from termcolor import colored
import torch
# Default compute device: CUDA when PyTorch reports it available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Encoder(nn.Module):
    """Convolutional encoder that outputs the parameters of a diagonal Gaussian.

    Six ``Conv2d`` + ``ReLU`` stages (kernel 7, stride 2, padding 3) turn a
    single-channel image into a 1024-channel feature map, which is flattened
    and fed to two independent linear heads producing ``mu`` and ``logvar``.

    The linear heads are sized for a 1024 x 7 x 7 feature map. Every stage
    maps a spatial size ``n`` to ``ceil(n / 2)``, so for example 400 x 400
    inputs (400 -> 200 -> 100 -> 50 -> 25 -> 13 -> 7) satisfy that
    requirement; other sizes fail in the linear layers with a shape error.

    Args:
        latent_dim: Number of latent dimensions produced by each head.
    """

    # Channel count entering the first stage and leaving every stage.
    _CHANNELS = (1, 32, 64, 128, 256, 512, 1024)
    # Side length of the square feature map the linear heads are sized for.
    _FEATURE_MAP_SIDE = 7

    def __init__(self, latent_dim):
        super().__init__()

        # Stages are created in the same order (and at the same Sequential
        # indices 0, 2, 4, ...) as a hand-written stack, so parameter names
        # such as ``conv.0.weight`` stay compatible with saved checkpoints.
        layers = []
        for in_channels, out_channels in zip(self._CHANNELS, self._CHANNELS[1:]):
            layers.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=7,
                    stride=2,
                    dilation=1,
                    padding=3,
                )
            )
            layers.append(nn.ReLU())
        self.conv = nn.Sequential(*layers)

        flat_features = self._CHANNELS[-1] * self._FEATURE_MAP_SIDE ** 2
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Input batch with one channel; the spatial size must reduce to
                7 x 7 after the six stride-2 stages.

        Returns:
            Tuple ``(mu, logvar)``, each with ``latent_dim`` features per
            batch element.
        """
        features = self.conv(x)
        # flatten() also handles non-contiguous feature maps, where view()
        # would raise.
        features = features.flatten(start_dim=1)
        return self.fc_mu(features), self.fc_logvar(features)
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands the latent vector to a 1024 x 7 x 7 feature map.
    Six ``ConvTranspose2d`` stages (kernel 7, stride 2, padding 3) then
    upsample it; each stage maps a spatial size ``n`` to
    ``2 * n - 1 + output_padding``, giving 7 -> 13 -> 25 -> 50 -> 100 -> 200
    -> 400. Hidden stages use ``ReLU``; the final single-channel stage is
    followed by a ``Sigmoid``.

    Args:
        latent_dim: Number of features in the latent vector.
    """

    # Channel count entering the first stage and leaving every stage.
    _CHANNELS = (1024, 512, 256, 128, 64, 32, 1)
    # Per-stage output_padding; the first two stages produce odd sizes
    # (13, 25), the remaining four exactly double their input.
    _OUTPUT_PADDING = (0, 0, 1, 1, 1, 1)
    # Side length of the square feature map produced by the linear layer.
    _FEATURE_MAP_SIDE = 7

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(
            latent_dim, self._CHANNELS[0] * self._FEATURE_MAP_SIDE ** 2
        )

        # Stages are created in the same order (and at the same Sequential
        # indices 0, 2, 4, ...) as a hand-written stack, so parameter names
        # such as ``deconv.0.weight`` stay compatible with saved checkpoints.
        last_stage = len(self._OUTPUT_PADDING) - 1
        layers = []
        for stage, (in_channels, out_channels) in enumerate(
            zip(self._CHANNELS, self._CHANNELS[1:])
        ):
            layers.append(
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=7,
                    stride=2,
                    dilation=1,
                    padding=3,
                    output_padding=self._OUTPUT_PADDING[stage],
                )
            )
            layers.append(nn.Sigmoid() if stage == last_stage else nn.ReLU())
        self.deconv = nn.Sequential(*layers)

    def forward(self, x):
        """Decode a batch of latent vectors.

        Args:
            x: Latent batch with ``latent_dim`` features per element.

        Returns:
            Reconstructed single-channel images, 400 x 400 per element,
            after the final sigmoid.
        """
        hidden = self.fc(x)
        feature_map = hidden.view(
            hidden.size(0),
            self._CHANNELS[0],
            self._FEATURE_MAP_SIDE,
            self._FEATURE_MAP_SIDE,
        )
        return self.deconv(feature_map)
class VAE(nn.Module):
    """Variational autoencoder combining ``Encoder`` and ``Decoder``.

    Args:
        latent_dim: Number of latent dimensions shared by both halves.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I).

        Args:
            mu: Mean of the approximate posterior.
            logvar: Log-variance of the approximate posterior, same shape
                as ``mu``.

        Returns:
            A sample with the same shape, dtype and device as the inputs.
        """
        std = torch.exp(0.5 * logvar)
        # randn_like() inherits dtype and device from ``std``. Forcing the
        # tensors onto a fixed global device instead would break whenever the
        # model's tensors live somewhere else (e.g. a CPU model on a machine
        # that also has CUDA), because ``mu`` would stay behind.
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Input batch accepted by the encoder.

        Returns:
            Tuple ``(x_recon, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar