import torch
import pickle
import numpy as np
import torch.nn as nn
from vegans.utils.utils import get_input_dim
from vegans.utils.layers import LayerReshape, LayerPrintSize
class MyGenerator(nn.Module):
    """Convolutional generator that upsamples its input to `x_dim[1]` pixels.

    A flat (one entry) `gen_in_dim` is first mapped by two linear layers to a
    `(128, 8, 8)` feature map; any other input is passed through unchanged.
    Two stride-1 transposed convolutions follow, then one stride-2 transposed
    convolution (kernel 4, padding 1) per doubling of the spatial size. The
    last of them projects to `x_dim[0]` channels and is followed by a sigmoid.

    Only index 1 of `x_dim` / `gen_in_dim` is used to plan the upsampling.

    Parameters
    ----------
    x_dim : list, tuple
        Dimension of the generated output; `x_dim[0]` is used as number of
        output channels and `x_dim[1]` as target spatial size.
    gen_in_dim : list, tuple
        Dimension of the generator input (without batch axis). Either one
        entry (flat input) or channel first with the spatial size at index 1.

    Raises
    ------
    ValueError
        If `x_dim[1]` is not the starting spatial size multiplied by a power
        of two that is at least 2. Previously such inputs made the
        constructor loop forever (or skip the output projection).
    """

    def __init__(self, x_dim, gen_in_dim):
        super().__init__()

        stem_shape = (128, 8, 8)
        flat_input = len(gen_in_dim) == 1
        start_size = stem_shape[1] if flat_input else gen_in_dim[1]
        # Validate before any layer is built so that bad sizes fail fast.
        n_doublings = self._count_doublings(start_size=start_size, target_size=x_dim[1])

        if flat_input:
            self.linear_part = nn.Sequential(
                nn.Linear(in_features=gen_in_dim[0], out_features=1024),
                nn.LeakyReLU(0.1),
                nn.Linear(in_features=1024, out_features=int(np.prod(stem_shape))),
                nn.LeakyReLU(0.1),
                LayerReshape(shape=stem_shape)
            )
            gen_in_dim = stem_shape
        else:
            self.linear_part = nn.Identity()

        self.hidden_part = nn.Sequential(
            nn.ConvTranspose2d(in_channels=gen_in_dim[0], out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(0.1),
        )

        in_channels = 128
        for step in range(n_doublings):
            # Numbering starts at 3 and the (misspelled) module names are kept
            # as they are so that existing state dicts keep loading.
            idx = step + 3
            is_last = step == n_doublings - 1
            out_channels = x_dim[0] if is_last else in_channels // 2
            self.hidden_part.add_module("ConvTraspose{}".format(idx), nn.ConvTranspose2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1
                )
            )
            if not is_last:
                self.hidden_part.add_module("Batchnorm{}".format(idx), nn.BatchNorm2d(num_features=out_channels))
                self.hidden_part.add_module("LeakyRelu{}".format(idx), nn.LeakyReLU(0.1))
            in_channels = in_channels // 2

        self.output = nn.Sigmoid()

    @staticmethod
    def _count_doublings(start_size, target_size):
        """Return k such that `target_size == start_size * 2**k` with k >= 1."""
        start_size, target_size = int(start_size), int(target_size)
        message = (
            "`x_dim[1]` must be the generator's starting size multiplied by a power of two (>= 2). "
            "Given: {} and {}.".format(target_size, start_size)
        )
        if start_size <= 0 or target_size % start_size != 0:
            raise ValueError(message)
        ratio = target_size // start_size
        if ratio < 2 or (ratio & (ratio - 1)) != 0:
            raise ValueError(message)
        return ratio.bit_length() - 1

    def forward(self, x):
        """Apply the linear stem (if any), the upsampling stack and the sigmoid."""
        x = self.linear_part(x)
        x = self.hidden_part(x)
        return self.output(x)
def load_celeba_generator(x_dim, z_dim, y_dim=None):
    """ Load some celeba architecture for the generator.

    Parameters
    ----------
    x_dim : list, tuple
        Indicating the number of dimensions for the real data. `x_dim[1]` is
        the spatial size the generator has to reach.
    z_dim : integer, list
        Indicating the number of dimensions for the latent space.
    y_dim : None, optional
        Indicating the number of dimensions for the labels.

    Returns
    -------
    torch.nn.Module
        Architectures for generator.

    Raises
    ------
    AssertionError
        If the sizes cannot be matched by repeatedly doubling the starting
        spatial size of the generator.
    """
    z_dim = [z_dim] if isinstance(z_dim, int) else z_dim
    y_dim = tuple([y_dim]) if isinstance(y_dim, int) else y_dim
    if len(z_dim) == 3:
        assert z_dim[1] % 2 == 0, "z_dim[1] must be divisible by 2. Given: {}.".format(z_dim[1])
        assert x_dim[1] % 2 == 0, "`x_dim[1]` must be divisible by 2. Given: {}.".format(x_dim[1])
        assert x_dim[1] % z_dim[1] == 0, "`x_dim[1]` must be divisible by `z_dim[1]`. Given: {} and {}.".format(x_dim[1], z_dim[1])
        assert (x_dim[1] / z_dim[1]) % 2 == 0, "`x_dim[1]/z_dim[1]` must be divisible by 2. Given: {} and {}.".format(x_dim[1], z_dim[1])
        assert z_dim[1] == z_dim[2], "`z_dim[1]` must be equal to `z_dim[2]`. Given: {} and {}.".format(z_dim[1], z_dim[2])

    gen_in_dim = get_input_dim(dim1=z_dim, dim2=y_dim) if y_dim is not None else z_dim

    # The generator doubles its spatial size until it equals `x_dim[1]`, so an
    # even ratio is not enough (e.g. 6 would never be hit): it has to be a
    # power of two. A flat input starts from the generator's 8x8 feature map.
    start_size = 8 if len(gen_in_dim) == 1 else gen_in_dim[1]
    ratio, remainder = divmod(x_dim[1], start_size)
    assert remainder == 0 and ratio >= 2 and (ratio & (ratio - 1)) == 0, (
        "`x_dim[1]` must be the generator's starting size multiplied by a power of two (>= 2). "
        "Given: {} and {}.".format(x_dim[1], start_size)
    )
    return MyGenerator(x_dim=x_dim, gen_in_dim=gen_in_dim)
class MyAdversary(nn.Module):
    """Convolutional adversary producing one value per sample.

    Three groups of two 3x3 convolutions, each followed by max pooling with
    stride 2 and batch normalisation, are flattened and fed through two
    linear layers (1024 hidden units, 1 output). The result is passed through
    an instance of `last_layer_activation`.

    Parameters
    ----------
    adv_in_dim : list, tuple
        Dimension of the adversary input (without batch axis), channel first.
    last_layer_activation : callable
        Called without arguments to create the final activation module,
        e.g. `nn.Sigmoid` or `nn.Identity`.
    """

    def __init__(self, adv_in_dim, last_layer_activation):
        super().__init__()
        self.hidden_part = nn.Sequential(
            nn.Conv2d(in_channels=adv_in_dim[0], out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(num_features=32),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=256),

            nn.Flatten()
        )

        # Infer the flattened feature count with a dummy batch. The probe runs
        # in eval mode and without autograd: in training mode it would already
        # update the BatchNorm running statistics with random noise. The random
        # draw itself is kept so seeded weight initialisation is unchanged.
        was_training = self.hidden_part.training
        self.hidden_part.eval()
        with torch.no_grad():
            flat_features = self.hidden_part(torch.randn(size=(2, *adv_in_dim))).shape[1]
        self.hidden_part.train(was_training)

        self.linear_part = nn.Sequential(
            nn.Linear(in_features=flat_features, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=1),
        )
        self.output = last_layer_activation()

    def forward(self, x):
        """Apply the convolutional stack, the linear head and the activation."""
        x = self.hidden_part(x)
        x = self.linear_part(x)
        return self.output(x)
def load_celeba_adversary(x_dim, y_dim=None, adv_type="Critic"):
    """ Build the celeba adversary with the activation matching `adv_type`.

    Parameters
    ----------
    x_dim : list, tuple
        Indicating the number of dimensions for the real data.
    y_dim : integer, list, optional
        Indicating the number of dimensions for the labels. If given, the
        adversary input dimension is derived from `x_dim` and `y_dim` via
        `get_input_dim`, otherwise `x_dim` is used directly.
    adv_type : str, optional
        "Critic" (identity as last activation) or "Discriminator" (sigmoid as
        last activation).

    Returns
    -------
    torch.nn.Module
        Architectures for adversary.

    Raises
    ------
    ValueError
        If `adv_type` is neither "Critic" nor "Discriminator".
    """
    possible_types = ["Discriminator", "Critic"]
    activation_by_type = (("Critic", nn.Identity), ("Discriminator", nn.Sigmoid))

    last_layer_activation = next(
        (activation for type_name, activation in activation_by_type if adv_type == type_name),
        None,
    )
    if last_layer_activation is None:
        raise ValueError("'adv_type' must be one of: {}.".format(possible_types))

    if y_dim is None:
        adv_in_dim = x_dim
    else:
        adv_in_dim = get_input_dim(dim1=x_dim, dim2=y_dim)
    return MyAdversary(adv_in_dim=adv_in_dim, last_layer_activation=last_layer_activation)
class MyEncoder(nn.Module):
    """Convolutional encoder mapping an input to an output of shape `z_dim`.

    Five 3x3 convolutions interleaved with three stride-2 max pooling layers
    are flattened, projected by one linear layer to `np.prod(z_dim)` values
    and reshaped with `LayerReshape(shape=z_dim)`.

    Parameters
    ----------
    enc_in_dim : list, tuple
        Dimension of the encoder input (without batch axis), channel first.
    z_dim : integer, list
        Dimension of the encoder output.
    """

    def __init__(self, enc_in_dim, z_dim):
        super().__init__()

        def conv3x3(n_in, n_out):
            return nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, stride=1, padding=1)

        def pool():
            return nn.MaxPool2d(kernel_size=5, stride=2, padding=2)

        layers = [
            conv3x3(enc_in_dim[0], 16),
            nn.ReLU(),
            conv3x3(16, 32),
            nn.ReLU(),
            pool(),
            conv3x3(32, 64),
            nn.ReLU(),
            pool(),
            conv3x3(64, 32),
            nn.ReLU(),
            pool(),
            conv3x3(32, 16),
            nn.Flatten(),
        ]
        self.hidden_part = nn.Sequential(*layers)

        # Push a dummy batch through the stack to learn the flattened width.
        probe = torch.rand([2, *enc_in_dim])
        n_flat = self.hidden_part(probe).shape[1]

        self.linear = nn.Linear(in_features=n_flat, out_features=np.prod(z_dim))
        self.reshape = LayerReshape(shape=z_dim)
        self.output = nn.Identity()

    def forward(self, x):
        """Encode `x`: convolutions, linear projection, reshape, identity."""
        flat = self.hidden_part(x)
        projected = self.linear(flat)
        return self.output(self.reshape(projected))
def load_celeba_encoder(x_dim, z_dim, y_dim=None):
    """ Build the celeba encoder.

    Parameters
    ----------
    x_dim : integer, list
        Indicating the number of dimensions for the real data.
    z_dim : integer, list
        Indicating the number of dimensions for the latent space.
    y_dim : integer, list, optional
        Indicating the number of dimensions for the labels. If given, the
        encoder input dimension is derived from `x_dim` and `y_dim` via
        `get_input_dim`, otherwise `x_dim` is used directly.

    Returns
    -------
    torch.nn.Module
        Architectures for encoder.
    """
    if y_dim is None:
        enc_in_dim = x_dim
    else:
        enc_in_dim = get_input_dim(dim1=x_dim, dim2=y_dim)
    return MyEncoder(enc_in_dim=enc_in_dim, z_dim=z_dim)
class MyDecoder(nn.Module):
    """Convolutional decoder that upsamples a latent input to `x_dim[1]` pixels.

    The input is flattened and mapped by one linear layer to a `(1, 8, 8)`
    feature map. One stride-2 transposed convolution (kernel 4, padding 1) is
    then added per doubling of the spatial size; the channel count doubles at
    every intermediate step, and the last step projects to `x_dim[0]` channels
    and is followed by a sigmoid.

    Only `x_dim[1]` is used to plan the upsampling.

    Parameters
    ----------
    x_dim : list, tuple
        Dimension of the decoded output; `x_dim[0]` is used as number of
        output channels and `x_dim[1]` as target spatial size.
    dec_in_dim : integer, list
        Dimension of the decoder input (without batch axis).

    Raises
    ------
    ValueError
        If `x_dim[1]` is not 8 multiplied by a power of two that is at least
        2. Previously such inputs made the constructor loop forever (or skip
        the output projection).
    """

    def __init__(self, x_dim, dec_in_dim):
        super().__init__()

        stem_shape = [1, 8, 8]
        # Validate before any layer is built so that bad sizes fail fast.
        n_doublings = self._count_doublings(start_size=stem_shape[1], target_size=x_dim[1])

        self.hidden_part = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=int(np.prod(dec_in_dim)), out_features=int(np.prod(stem_shape))),
            LayerReshape(shape=stem_shape),
        )

        in_channels = 1
        for step in range(n_doublings):
            # Numbering starts at 2 and the (misspelled) module names are kept
            # as they are so that existing state dicts keep loading.
            idx = step + 2
            is_last = step == n_doublings - 1
            out_channels = x_dim[0] if is_last else in_channels * 2
            self.hidden_part.add_module("ConvTraspose{}".format(idx), nn.ConvTranspose2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1
                )
            )
            if not is_last:
                self.hidden_part.add_module("Batchnorm{}".format(idx), nn.BatchNorm2d(num_features=out_channels))
                self.hidden_part.add_module("LeakyRelu{}".format(idx), nn.LeakyReLU(0.1))
            in_channels = in_channels * 2

        self.output = nn.Sigmoid()

    @staticmethod
    def _count_doublings(start_size, target_size):
        """Return k such that `target_size == start_size * 2**k` with k >= 1."""
        start_size, target_size = int(start_size), int(target_size)
        message = (
            "`x_dim[1]` must be {} multiplied by a power of two (>= 2). "
            "Given: {}.".format(start_size, target_size)
        )
        if start_size <= 0 or target_size % start_size != 0:
            raise ValueError(message)
        ratio = target_size // start_size
        if ratio < 2 or (ratio & (ratio - 1)) != 0:
            raise ValueError(message)
        return ratio.bit_length() - 1

    def forward(self, x):
        """Apply the linear stem, the upsampling stack and the sigmoid."""
        x = self.hidden_part(x)
        return self.output(x)
def load_celeba_decoder(x_dim, z_dim, y_dim=None):
    """ Load some celeba architecture for the decoder.

    Parameters
    ----------
    x_dim : list, tuple
        Indicating the number of dimensions for the real data. `x_dim[1]` is
        the spatial size the decoder has to reach.
    z_dim : integer, list
        Indicating the number of dimensions for the latent space.
    y_dim : integer, list, optional
        Indicating the number of dimensions for the labels.

    Returns
    -------
    torch.nn.Module
        Architectures for decoder.

    Raises
    ------
    AssertionError
        If `x_dim[1]` cannot be reached by repeatedly doubling the decoder's
        8x8 starting feature map.
    """
    assert x_dim[1] % 2 == 0, "`x_dim[1]` must be divisible by 2. Given: {}.".format(x_dim[1])
    assert x_dim[1] % 8 == 0, "`x_dim[1]` must be divisible by 8. Given: {}.".format(x_dim[1])
    assert (x_dim[1] / 8) % 2 == 0, "`x_dim[1]/8` must be divisible by 2. Given: {}.".format(x_dim[1])
    # The decoder doubles its spatial size until it equals `x_dim[1]`, so an
    # even ratio is not enough (e.g. 48 / 8 = 6 would never be hit).
    ratio = x_dim[1] // 8
    assert ratio >= 2 and (ratio & (ratio - 1)) == 0, (
        "`x_dim[1]/8` must be a power of 2 (>= 2). Given: {}.".format(x_dim[1])
    )

    dec_in_dim = get_input_dim(dim1=z_dim, dim2=y_dim) if y_dim is not None else z_dim
    return MyDecoder(x_dim=x_dim, dec_in_dim=dec_in_dim)