Source code for vegans.utils.loading.architectures.mnist

import torch
import pickle

import numpy as np
import torch.nn as nn

from vegans.utils.utils import get_input_dim
from vegans.utils.layers import LayerReshape, LayerPrintSize


class MyGenerator(nn.Module):
    """Convolutional generator that emits images with ``x_dim[0]`` channels.

    The input is first passed through ``prepare``, then through a
    strided-convolution ``encoding`` stack, a transposed-convolution
    ``decoding`` stack and finally a sigmoid. With a 16x16 map coming out of
    ``prepare`` the layer arithmetic gives 16 -> 8 -> 4 -> 4 in ``encoding``
    and 4 -> 8 -> 16 -> 32 in ``decoding``.

    Parameters
    ----------
    x_dim : list, tuple
        Dimensions of the generated data. Only ``x_dim[0]`` (the number of
        output channels) is read here.
    gen_in_dim : list, tuple
        Dimensions of the generator input. A length-one value is treated as a
        flat vector of size ``gen_in_dim[0]``. Otherwise ``gen_in_dim[0]`` is
        taken as the channel count and ``gen_in_dim[1]`` as the spatial size.
    """

    def __init__(self, x_dim, gen_in_dim):
        super().__init__()

        if len(gen_in_dim) == 1:
            # Flat latent vector: project to 256 values and hand them to
            # LayerReshape (presumably yielding a 1x16x16 map -- confirm
            # against vegans.utils.layers).
            self.prepare = nn.Sequential(
                nn.Linear(in_features=gen_in_dim[0], out_features=256),
                LayerReshape(shape=[1, 16, 16]),
            )
            nr_channels = 1
        else:
            # Image-shaped latent: double the spatial size until it reaches 16.
            # The original read ``z_dim[1]`` here, a name that does not exist
            # in this scope; the spatial size lives in ``gen_in_dim[1]``.
            current_dim = gen_in_dim[1]
            nr_channels = gen_in_dim[0]
            upsampling = []
            while current_dim < 16:
                upsampling.append(
                    nn.ConvTranspose2d(
                        in_channels=nr_channels, out_channels=5,
                        kernel_size=4, stride=2, padding=1,
                    )
                )
                # Every upsampling layer after the first consumes 5 channels.
                nr_channels = 5
                current_dim *= 2
            # An empty Sequential (input already >= 16) acts as the identity.
            self.prepare = nn.Sequential(*upsampling)

        self.encoding = nn.Sequential(
            nn.Conv2d(in_channels=nr_channels, out_channels=64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.LeakyReLU(0.2),
        )
        self.decoding = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=32, out_channels=x_dim[0], kernel_size=3, stride=1, padding=1),
        )
        # Sigmoid keeps the generated values in (0, 1).
        self.output = nn.Sigmoid()

    def forward(self, x):
        """Run ``prepare``, ``encoding``, ``decoding`` and the sigmoid on ``x``."""
        x = self.prepare(x)
        x = self.encoding(x)
        x = self.decoding(x)
        return self.output(x)
def load_mnist_generator(x_dim, z_dim, y_dim=None):
    """ Build the MNIST generator network.

    Parameters
    ----------
    x_dim : list, tuple
        Dimensions of the real data, forwarded unchanged to ``MyGenerator``.
    z_dim : integer, list
        Dimensions of the latent space. An integer is wrapped into a
        one-element list. With more than one entry, ``z_dim[1]`` must be at
        most 16, even, and equal to ``z_dim[2]``.
    y_dim : integer, list, optional
        Dimensions of the labels. When given, the generator input dimension is
        computed by ``get_input_dim`` from ``z_dim`` and ``y_dim``.

    Returns
    -------
    torch.nn.Module
        Architecture for the generator.
    """
    if isinstance(z_dim, int):
        z_dim = [z_dim]
    if isinstance(y_dim, int):
        y_dim = (y_dim,)

    if len(z_dim) > 1:
        side = z_dim[1]
        assert (side <= 16) and (side % 2 == 0), "z_dim[1] must be smaller 16 and divisible by 2. Given: {}.".format(side)
        assert side == z_dim[2], "z_dim[1] must be equal to z_dim[2]. Given: {} and {}.".format(side, z_dim[2])

    if y_dim is None:
        gen_in_dim = z_dim
    else:
        gen_in_dim = get_input_dim(dim1=z_dim, dim2=y_dim)
    return MyGenerator(x_dim=x_dim, gen_in_dim=gen_in_dim)
class MyAdversary(nn.Module):
    """Convolutional adversary (critic or discriminator) for image input.

    Three groups of two 3x3 convolutions, each followed by max-pooling and
    batch normalisation, feed a final 3x3 convolution with a single output
    channel. ``last_layer`` decides how that map is post-processed.

    Parameters
    ----------
    adv_in_dim : list, tuple
        Dimensions of the adversary input. Only ``adv_in_dim[0]`` (the channel
        count) is read here.
    last_layer : callable
        Zero-argument factory (e.g. ``nn.Sigmoid`` or ``nn.Identity``) whose
        result is applied to the output of the convolutional part.
    """

    def __init__(self, adv_in_dim, last_layer):
        super().__init__()

        def conv(n_in, n_out):
            # 3x3 convolution that keeps the spatial size.
            return nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, stride=1, padding=1)

        def pool():
            # Max-pooling that halves the spatial size.
            return nn.MaxPool2d(kernel_size=5, stride=2, padding=2)

        stack = [
            conv(adv_in_dim[0], 32), nn.ReLU(),
            conv(32, 64), nn.ReLU(),
            pool(), nn.BatchNorm2d(num_features=64),

            conv(64, 128), nn.ReLU(),
            conv(128, 128), nn.ReLU(),
            pool(), nn.BatchNorm2d(num_features=128),

            conv(128, 128), nn.ReLU(),
            conv(128, 64), nn.ReLU(),
            pool(), nn.BatchNorm2d(num_features=64),

            conv(64, 1),
        ]
        self.hidden_part = nn.Sequential(*stack)
        self.output = last_layer()

    def forward(self, x):
        """Apply the convolutional part and then the configured last layer."""
        return self.output(self.hidden_part(x))
def load_mnist_adversary(x_dim=(1, 32, 32), y_dim=None, adv_type="Critic"):
    """ Build the MNIST adversary network.

    Parameters
    ----------
    x_dim : list, tuple, optional
        Dimensions of the real data.
    y_dim : integer, list, optional
        Dimensions of the labels. When given, the adversary input dimension is
        computed by ``get_input_dim`` from ``x_dim`` and ``y_dim``.
    adv_type : str, optional
        "Critic" ends the network with ``nn.Identity``, "Discriminator" ends
        it with ``nn.Sigmoid``.

    Returns
    -------
    torch.nn.Module
        Architecture for the adversary.

    Raises
    ------
    ValueError
        If ``adv_type`` is neither "Discriminator" nor "Critic".
    """
    possible_types = ["Discriminator", "Critic"]
    # Reject unknown types before any dimension handling takes place.
    if adv_type not in possible_types:
        raise ValueError("'adv_type' must be one of: {}.".format(possible_types))
    last_layer = nn.Identity if adv_type == "Critic" else nn.Sigmoid

    if y_dim is None:
        adv_in_dim = x_dim
    else:
        adv_in_dim = get_input_dim(dim1=x_dim, dim2=y_dim)
    return MyAdversary(adv_in_dim=adv_in_dim, last_layer=last_layer)
class MyEncoder(nn.Module):
    """Convolutional encoder mapping images to a flat latent vector.

    A stack of 3x3 convolutions and max-pooling layers is flattened and fed
    into one linear layer with ``np.prod(z_dim)`` outputs.

    Parameters
    ----------
    enc_in_dim : list, tuple
        Dimensions of the encoder input (without the batch dimension).
        ``enc_in_dim[0]`` is used as the channel count, and the full value
        sizes the probe tensor used to infer the linear layer's input width.
    z_dim : list, tuple
        Dimensions of the latent space.
    """

    def __init__(self, enc_in_dim, z_dim):
        super().__init__()

        def conv(n_in, n_out):
            # 3x3 convolution that keeps the spatial size.
            return nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, stride=1, padding=1)

        def pool():
            # Max-pooling that halves the spatial size.
            return nn.MaxPool2d(kernel_size=5, stride=2, padding=2)

        stack = [
            conv(enc_in_dim[0], 16), nn.ReLU(),
            conv(16, 32), nn.ReLU(),
            pool(),
            conv(32, 64), nn.ReLU(),
            pool(),
            conv(64, 32), nn.ReLU(),
            pool(),
            conv(32, 16),
            nn.Flatten(),
        ]
        self.hidden_part = nn.Sequential(*stack)

        # Push a random batch of two through the stack to read off the
        # flattened feature count instead of computing it by hand.
        probe = torch.rand([2, *enc_in_dim])
        flattened_nodes = tuple(self.hidden_part(probe).shape)[1]

        self.linear = nn.Linear(in_features=flattened_nodes, out_features=np.prod(z_dim))
        self.output = nn.Identity()

    def forward(self, x):
        """Encode ``x`` via the convolutional stack and the linear layer."""
        features = self.hidden_part(x)
        latent = self.linear(features)
        return self.output(latent)
def load_mnist_encoder(x_dim, z_dim, y_dim=None):
    """ Build the MNIST encoder network.

    Parameters
    ----------
    x_dim : integer, list
        Dimensions of the real data.
    z_dim : integer, list
        Dimensions of the latent space. An integer is wrapped into a
        one-element list; anything else must have length one.
    y_dim : integer, list, optional
        Dimensions of the labels. When given, the encoder input dimension is
        computed by ``get_input_dim`` from ``x_dim`` and ``y_dim``.

    Returns
    -------
    torch.nn.Module
        Architecture for the encoder.
    """
    if isinstance(z_dim, int):
        z_dim = [z_dim]
    assert len(z_dim) == 1, "z_dim must be of length one. Given: {}.".format(z_dim)

    if y_dim is None:
        enc_in_dim = x_dim
    else:
        enc_in_dim = get_input_dim(dim1=x_dim, dim2=y_dim)
    return MyEncoder(enc_in_dim=enc_in_dim, z_dim=z_dim)
class MyDecoder(nn.Module):
    """Decoder turning a latent vector into an image with ``x_dim[0]`` channels.

    A linear layer produces 64 values that are handed to ``LayerReshape``
    with shape ``[1, 8, 8]``. Two stride-2 transposed convolutions and a
    final 3x3 transposed convolution follow, and a sigmoid is applied last.

    Parameters
    ----------
    x_dim : list, tuple
        Dimensions of the generated data. Only ``x_dim[0]`` (the number of
        output channels) is read here.
    dec_in_dim : integer, list
        Dimensions of the decoder input. Its product is the number of input
        features of the first linear layer.
    """

    def __init__(self, x_dim, dec_in_dim):
        super().__init__()
        seed_shape = (1, 8, 8)

        def upsample(n_in, n_out):
            # Stride-2 transposed convolution plus normalisation and activation.
            return [
                nn.ConvTranspose2d(in_channels=n_in, out_channels=n_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(num_features=n_out),
                nn.LeakyReLU(0.2),
            ]

        stack = [
            nn.Linear(in_features=np.prod(dec_in_dim), out_features=np.prod(seed_shape)),
            LayerReshape(shape=list(seed_shape)),
            *upsample(1, 64),
            *upsample(64, 32),
            nn.ConvTranspose2d(in_channels=32, out_channels=x_dim[0], kernel_size=3, stride=1, padding=1),
        ]
        self.hidden_part = nn.Sequential(*stack)
        self.output = nn.Sigmoid()

    def forward(self, x):
        """Decode ``x`` and apply the sigmoid to the result."""
        return self.output(self.hidden_part(x))
def load_mnist_decoder(x_dim, z_dim, y_dim=None):
    """ Build the MNIST decoder network.

    Parameters
    ----------
    x_dim : list, tuple
        Dimensions of the real data, forwarded unchanged to ``MyDecoder``.
    z_dim : integer, list
        Dimensions of the latent space.
    y_dim : integer, list, optional
        Dimensions of the labels. When given, the decoder input dimension is
        computed by ``get_input_dim`` from ``z_dim`` and ``y_dim``.

    Returns
    -------
    torch.nn.Module
        Architecture for the decoder.
    """
    if y_dim is None:
        dec_in_dim = z_dim
    else:
        dec_in_dim = get_input_dim(dim1=z_dim, dim2=y_dim)
    return MyDecoder(x_dim=x_dim, dec_in_dim=dec_in_dim)
class MyAutoEncoder(nn.Module):
    """Convolutional auto-encoder with a single-channel sigmoid output.

    ``encoding`` applies three stride-2 convolutions and one stride-1
    convolution. ``decoding`` mirrors it with three stride-2 transposed
    convolutions and a final 3x3 transposed convolution with one output
    channel.

    Parameters
    ----------
    ae_in_dim : list, tuple
        Dimensions of the auto-encoder input. Only ``ae_in_dim[0]`` (the
        channel count) is read here.
    """

    def __init__(self, ae_in_dim):
        super().__init__()

        def down(n_in, n_out, kernel, stride, pad):
            # Convolution followed by batch normalisation and LeakyReLU.
            return [
                nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=kernel, stride=stride, padding=pad),
                nn.BatchNorm2d(num_features=n_out),
                nn.LeakyReLU(0.2),
            ]

        def up(n_in, n_out):
            # Stride-2 transposed convolution plus normalisation and activation.
            return [
                nn.ConvTranspose2d(in_channels=n_in, out_channels=n_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(num_features=n_out),
                nn.LeakyReLU(0.2),
            ]

        encoder_layers = [
            # The first stage has no batch normalisation.
            nn.Conv2d(in_channels=ae_in_dim[0], out_channels=32, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2),
            *down(32, 64, 5, 2, 2),
            *down(64, 128, 5, 2, 2),
            *down(128, 256, 3, 1, 1),
        ]
        self.encoding = nn.Sequential(*encoder_layers)

        decoder_layers = [
            *up(256, 128),
            *up(128, 64),
            *up(64, 32),
            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1),
        ]
        self.decoding = nn.Sequential(*decoder_layers)
        self.output = nn.Sigmoid()

    def forward(self, x):
        """Encode and decode ``x``, then apply the sigmoid."""
        code = self.encoding(x)
        reconstruction = self.decoding(code)
        return self.output(reconstruction)
def load_mnist_autoencoder(x_dim=(1, 32, 32), y_dim=None):
    """ Build the MNIST auto-encoder network.

    Parameters
    ----------
    x_dim : integer, list
        Dimensions of the real data.
    y_dim : integer, list, optional
        Dimensions of the labels. When given, the auto-encoder input dimension
        is computed by ``get_input_dim`` from ``x_dim`` and ``y_dim``.

    Returns
    -------
    torch.nn.Module
        Architecture for the auto-encoder.
    """
    if y_dim is None:
        ae_in_dim = x_dim
    else:
        ae_in_dim = get_input_dim(dim1=x_dim, dim2=y_dim)
    return MyAutoEncoder(ae_in_dim=ae_in_dim)