import re
import json
import torch
import numpy as np
from torch import nn
from torch.nn import Module, Sequential
from vegans.utils.torchsummary import summary
[docs]class NeuralNetwork(Module):
""" Basic abstraction for single networks.
These networks form the building blocks for the generative adversarial networks.
Mainly responsible for consistency checks.
"""
def __init__(self, network, name, input_size, device, ngpu, secure):
super(NeuralNetwork, self).__init__()
self.name = name
self.input_size = input_size
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.ngpu = ngpu
self.secure = secure
if isinstance(input_size, int):
self.input_size = tuple([input_size])
elif isinstance(input_size, list):
self.input_size = tuple(input_size)
assert isinstance(network, torch.nn.Module), "`network` must be instance of nn.Module."
try:
type(network[-1])
self.input_type = "Sequential"
except TypeError:
self.input_type = "Object"
self.network = network.to(self.device)
if self.secure:
self._validate_input()
if ngpu is not None and ngpu < 0:
self.ngpu = len([torch.cuda.device(i) for i in range(torch.cuda.device_count())])
if self.ngpu is not None and self.device=="cuda":
if self.ngpu > 1:
self.network = torch.nn.DataParallel(self.network)
self.network = network.to(self.device)
self.output_size = self._get_output_shape()[1:]
[docs] def forward(self, x):
output = self.network(x)
return output
def _validate_input(self):
iterative_layers = self._get_iterative_layers(self.network, self.input_type)
for layer in iterative_layers:
if "in_features" in layer.__dict__:
first_input = layer.__dict__["in_features"]
break
elif "in_channels" in layer.__dict__:
first_input = layer.__dict__["in_channels"]
break
elif "num_features" in layer.__dict__:
first_input = layer.__dict__["num_features"]
break
else:
raise ValueError("No layer with `in_features`, `in_channels` or `num_features` found.")
if np.prod([first_input]) == np.prod(self.input_size):
pass
elif (len(self.input_size) > 1) & (self.input_size[0] == first_input):
pass
else:
raise TypeError(
"\n\tInput mismatch for **{}**:\n".format(self.name) +
"\t\tExpected (first layer 'in_features'/'in_channels'): {}. Given input_size (z_dim/x_dim (+y_dim)): {}.\n\n".format(
first_input, self.input_size) +
"\t\tONLY RELEVANT IF CONDITIONAL NETWORK IS USED:\n" +
"\t\tIf you are trying to use a conditional model please make sure you adjusted the input size\n" +
"\t\tof the first layer in this architecture for the label vector / image.\n"
"\t\tIn this case, use vegans.utils.utils.get_input_dim(in_dim, y_dim) and adjust this architecture's\n" +
"\t\tfirst layer input accordingly. See the conditional examples on github for help."
)
return True
@staticmethod
def _get_iterative_layers(network, input_type):
if input_type == "Sequential":
return network
elif input_type == "Object":
iterative_net = []
for _, layers in network.__dict__["_modules"].items():
try:
for layer in layers:
iterative_net.append(layer)
except TypeError:
iterative_net.append(layers)
return iterative_net
else:
raise NotImplemented("Network must be Sequential or Object.")
def _get_output_shape(self):
sample_input = torch.rand([2, *self.input_size]).to(self.device)
return tuple(self.network(sample_input).shape)
#########################################################################
# Utility functions
#########################################################################
[docs] def summary(self):
print(self.name)
print("-"*len(self.name))
print("Input shape: ", self.input_size)
return summary(self, input_size=self.input_size, device=self.device)
def __str__(self):
return self.name
[docs] def get_number_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
[docs]class Generator(NeuralNetwork):
def __init__(self, network, input_size, device, ngpu, secure=True):
super().__init__(network, input_size=input_size, name="Generator", device=device, ngpu=ngpu, secure=secure)
[docs]class Adversary(NeuralNetwork):
""" Implements adversary architecture.
Might either be a discriminator (output [0, 1]) or critic (output [-Inf, Inf]).
"""
def __init__(self, network, input_size, adv_type, device, ngpu, secure=True):
if secure:
try:
last_layer_type = type(NeuralNetwork._get_iterative_layers(network=network, input_type="Sequential")[-1])
except TypeError:
last_layer_type = type(NeuralNetwork._get_iterative_layers(network=network, input_type="Object")[-1])
valid_last_layer = None
valid_types = ["Discriminator", "Critic", "Autoencoder"]
if adv_type == "Discriminator":
valid_last_layer = [torch.nn.Sigmoid]
elif adv_type == "Critic":
valid_last_layer = [torch.nn.Linear, torch.nn.Identity]
else:
if adv_type not in valid_types:
raise TypeError("`adv_type` must be one of {}. Given: {}.".format(valid_types, adv_type))
self._type = adv_type
if valid_last_layer is not None:
assert last_layer_type in valid_last_layer, (
"Last layer activation function of {} needs to be one of '{}'. Given: {}."
.format(adv_type, valid_last_layer, last_layer_type)
)
super().__init__(network, input_size=input_size, name="Adversary", device=device, ngpu=ngpu, secure=secure)
[docs] def predict(self, x):
return self(x)
[docs]class Encoder(NeuralNetwork):
def __init__(self, network, input_size, device, ngpu, secure=True):
if secure:
valid_last_layer = [torch.nn.Linear, torch.nn.Identity]
try:
last_layer_type = type(NeuralNetwork._get_iterative_layers(network=network, input_type="Sequential")[-1])
except TypeError:
last_layer_type = type(NeuralNetwork._get_iterative_layers(network=network, input_type="Object")[-1])
assert last_layer_type in valid_last_layer, (
"Last layer activation function of Encoder needs to be one of '{}'.".format(valid_last_layer) +
"Given: {}.".format(last_layer_type)
)
super().__init__(network, input_size=input_size, name="Encoder", device=device, ngpu=ngpu, secure=secure)
[docs]class Decoder(NeuralNetwork):
def __init__(self, network, input_size, device, ngpu, secure=True):
super().__init__(network, input_size=input_size, name="Decoder", device=device, ngpu=ngpu, secure=secure)
[docs]class Autoencoder(nn.Module):
def __init__(self, encoder, decoder):
super(Autoencoder, self).__init__()
self.name = "Autoencoder"
self.encoder = encoder
self.decoder = decoder
[docs] def forward(self, x):
z = self.encoder(x)
return self.decoder(z)
[docs] def summary(self):
self.encoder.summary()
print("\n\n")
self.decoder.summary()
[docs] def get_number_params(self):
""" Returns the number of parameters in the model.
Returns
-------
dict
Dictionary containing the number of parameters per network.
"""
nr_params_dict = {
"Encoder": self.encoder.get_number_params(),
"Decoder": self.decoder.get_number_params()
}
return nr_params_dict