import torch
import pickle
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from vegans.utils.layers import LayerReshape
class DataSet(Dataset):
    """Map-style dataset over indexable features and optional targets.

    Parameters
    ----------
    X : indexable
        Features; ``len(X)`` defines the dataset length.
    y : indexable, optional
        Targets aligned with ``X``. When omitted, items are features only.
    """

    def __init__(self, X, y=None):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        features = self.X[index]
        if self.y is None:
            # Unsupervised use: only the features are returned.
            return features
        return features, self.y[index]
class KLLoss:
    """Kullback-Leibler style loss computed from critic outputs.

    Parameters
    ----------
    eps : float
        Small constant added for numerical stability, both to the
        denominator of the odds ratio and inside the logarithm.
    """

    def __init__(self, eps):
        self.eps = eps

    def __call__(self, input, target):
        """ Compute the Kullback-Leibler loss for GANs.

        Evaluates ``-mean(log(input / (1 + eps - input) + eps))``.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor. Output of a critic (presumably values in (0, 1),
            e.g. after a sigmoid -- confirm against the caller).
        target : object
            Unused. Accepted so the call signature matches the other loss
            callables in this module.

        Returns
        -------
        torch.Tensor
            KL divergence
        """
        odds = input / (1 + self.eps - input)
        return -torch.mean(torch.log(odds + self.eps))
class WassersteinLoss:
    """Wasserstein (earthmover) loss for critic outputs."""

    def __call__(self, input, target):
        """ Compute the Wasserstein loss / divergence.

        Also known as earthmover distance. Computes ``mean(target * input)``
        after mapping every ``0`` label to ``-1``.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor. Output of a critic.
        target : torch.Tensor
            Label, either 1 or -1. Zeros are translated to -1. The caller's
            tensor is NOT modified.

        Returns
        -------
        torch.Tensor
            Wasserstein divergence

        Raises
        ------
        AssertionError
            If ``target`` contains more than two distinct values.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if torch.unique(target).shape[0] > 2:
            raise AssertionError("Only two different values for target allowed.")
        # Remap labels on a copy: writing into `target` in place would silently
        # corrupt a label tensor the caller reuses across iterations.
        signed_target = target.clone()
        signed_target[signed_target == 0] = -1
        return torch.mean(signed_target * input)
class NormalNegativeLogLikelihood:
    """Gaussian negative log-likelihood (without the constant log(2*pi) term)."""

    def __call__(self, x, mu, variance, eps=1e-6):
        """Per-element Gaussian NLL, summed over dim 1 and averaged over the rest.

        Parameters
        ----------
        x, mu, variance : torch.Tensor
            Observations, means and variances; must broadcast together and have
            at least two dimensions (the result is summed over ``axis=1``).
        eps : float, optional
            Stabiliser added to the variance terms before division / log.

        Returns
        -------
        torch.Tensor
            Scalar ``mean(sum_axis1((x - mu)**2 / (2*var + eps) + 0.5*log(var + eps)))``.
        """
        inverse_scale = 1 / (2 * variance + eps)
        quadratic_term = inverse_scale * (x - mu) ** 2
        log_term = 0.5 * torch.log(variance + eps)
        per_element = quadratic_term + log_term
        return per_element.sum(axis=1).mean()
def _broadcast_to_spatial(tensor, spatial_shape):
    """Turn a [batch, dim] tensor into [batch, dim, *spatial_shape] by repetition."""
    batch_size, feature_dim = tensor.shape
    column = torch.reshape(tensor, shape=(batch_size, feature_dim, 1, 1))
    # `expand` yields the same values as `tile` without materialising a copy;
    # `torch.cat` below produces the contiguous result anyway.
    return column.expand(-1, -1, *spatial_shape)


def concatenate(tensor1, tensor2):
    """ Concatenates two 2D or 4D tensors along dimension 1.

    Tensors of identical shape, or of equal rank (2D/2D, 4D/4D), are
    concatenated directly. When one tensor is 2D ([batch, features]) and the
    other 4D ([batch, channels, height, width]), the 2D tensor is repeated over
    the spatial dimensions of the 4D one so it acts as extra channels.

    Parameters
    ----------
    tensor1 : torch.Tensor
        2D or 4D tensor.
    tensor2 : torch.Tensor
        2D or 4D tensor.

    Returns
    -------
    torch.Tensor
        Concatenation of tensor1 and tensor2 (``tensor1`` first), cast to float.

    Raises
    ------
    AssertionError
        If the tensors differ in dimension 0, or if they have different shapes
        and are not a combination of 2D and 4D tensors.
    """
    # Explicit raises (not `assert`) so validation survives `python -O`;
    # AssertionError is kept for callers that already catch it.
    if tensor1.shape[0] != tensor2.shape[0]:
        raise AssertionError(
            "Tensors to concatenate must have same dim 0. Tensor1: {}. Tensor2: {}.".format(tensor1.shape[0], tensor2.shape[0])
        )
    ranks = (len(tensor1.shape), len(tensor2.shape))
    if tensor1.shape == tensor2.shape or ranks in ((2, 2), (4, 4)):
        pass
    elif ranks == (4, 2):
        tensor2 = _broadcast_to_spatial(tensor2, tensor1.shape[2:])
    elif ranks == (2, 4):
        tensor1 = _broadcast_to_spatial(tensor1, tensor2.shape[2:])
    else:
        raise AssertionError("tensor1 and tensor2 must have 2 or 4 dimensions. Given: {} and {}.".format(tensor1.shape, tensor2.shape))
    return torch.cat((tensor1, tensor2), axis=1).float()
def plot_losses(losses, show=True, share=False):
    """
    Plots losses for generator and discriminator on a common plot.

    Parameters
    ----------
    losses : dict
        Dictionary containing the losses for some networks. The structure of the dictionary is:
        ```
        {
            mode1: {loss_type1_1: losses1_1, loss_type1_2: losses1_2, ...},
            mode2: {loss_type2_1: losses2_1, loss_type2_2: losses2_2, ...},
            ...
        }
        ```
        where `mode` is probably one of "Train" or "Test", loss_type might be "Generator", "Adversary", "Encoder", ...
        and losses are lists of loss values collected during training.
    show : bool, optional
        If True, `plt.show` is called to visualise the images directly.
    share : bool, optional
        If True, all curves are drawn on one axis (labelled ``mode + loss_type``).
        Otherwise each loss type gets its own panel, with one curve per mode.

    Returns
    -------
    plt.figure, plt.axis
        Created figure and an axis object. With ``share=False`` the returned
        axis is the last panel that was drawn on.
    """
    if share:
        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        for mode, loss_dict in losses.items():
            for loss_type, loss in loss_dict.items():
                ax.plot(loss, lw=2, label=mode+loss_type)
        ax.set_xlabel('Iterations')
        ax.legend()
    else:
        # Size the grid from the mode with the most loss types (no "Train" key
        # required) and round the column count up, so that e.g. 5 loss types
        # get a 2x3 grid instead of a 2x2 grid that silently drops one plot.
        n = max((len(loss_dict) for loss_dict in losses.values()), default=0)
        n = max(n, 1)  # avoid a zero-sized grid / division by zero
        nrows = int(np.sqrt(n))
        ncols = int(np.ceil(n / nrows))
        fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 9))
        axs = np.ravel(axs)
        ax = axs[0]
        for mode, loss_dict in losses.items():
            for ax, (loss_type, loss) in zip(axs, loss_dict.items()):
                ax.plot(loss, lw=2, label=mode)
                ax.set_xlabel('Iterations')
                ax.set_title(loss_type)
                ax.set_facecolor("#ecffe7")
                ax.legend()
        # Hide the spare panels created by rounding the grid up.
        for spare_ax in axs[n:]:
            spare_ax.axis("off")
    fig.tight_layout()
    if show:
        plt.show()
    return fig, ax
def plot_images(images, labels=None, show=True, n=None):
    """ Plot a number of input images with optional label

    Parameters
    ----------
    images : np.array
        Must be of shape [nr_samples, height, width], [nr_samples, height, width, 3],
        [nr_samples, 3, height, width] (converted to channels-last) or
        [nr_samples, 1, height, width] (squeezed to [nr_samples, height, width]).
    labels : np.array, optional
        Array of labels used in the title.
    show : bool, optional
        If True, `plt.show` is called to visualise the images directly.
    n : None, optional
        Number of images to be drawn. Capped at 36 and at the number of
        available images. Defaults to all images (up to 36).

    Returns
    -------
    plt.figure, plt.axis
        Created figure and axis objects.

    Raises
    ------
    ValueError
        If no image would be drawn (``n`` < 1 or ``images`` is empty).
    """
    if len(images.shape) == 4 and images.shape[1] == 3:
        images = invert_channel_order(images=images)
    elif len(images.shape) == 4 and images.shape[1] == 1:
        images = images.reshape((-1, images.shape[2], images.shape[3]))
    if n is None:
        n = images.shape[0]
    requested = n
    # Never draw more than 36 images, nor more images than were provided.
    n = min(n, 36, images.shape[0])
    if n < 1:
        raise ValueError("Nothing to plot: n={} with {} images given.".format(requested, images.shape[0]))
    nrows = int(np.sqrt(n))
    # Round up so all n images fit; floor division silently dropped images
    # whenever n was not nrows * k (e.g. n=5 only drew 4).
    ncols = int(np.ceil(n / nrows))
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8, 8))
    axs = np.ravel(axs)
    for i, ax in enumerate(axs):
        if i < n:
            ax.imshow(images[i])
            ax.axis("off")
            if labels is not None:
                ax.set_title("Label: {}".format(labels[i]))
        else:
            # Spare panel created by rounding the grid up.
            ax.axis("off")
    fig.tight_layout()
    if show:
        plt.show()
    return fig, axs
def create_gif(source_path, target_path=None):
    """Assemble the ``.png`` files of a directory into an animated GIF.

    Frames are taken in sorted file-name order.

    Parameters
    ----------
    source_path : string
        Path pointing to the source directory with .png files.
    target_path : string, optional
        Name of the created GIF. Defaults to ``movie.gif`` inside
        ``source_path``.
    """
    import os
    import imageio

    if not source_path.endswith("/"):
        source_path = source_path + "/"
    png_names = [name for name in sorted(os.listdir(source_path)) if name.endswith('.png')]
    frames = [imageio.imread(os.path.join(source_path, name)) for name in png_names]
    if target_path is None:
        target_path = source_path + "movie.gif"
    imageio.mimsave(target_path, frames)
def invert_channel_order(images):
    """Swap a batch of 3-channel images between channels-first and channels-last.

    ``[batch, 3, height, width]`` becomes ``[batch, height, width, 3]``;
    otherwise ``[batch, height, width, 3]`` becomes ``[batch, 3, height, width]``.
    If both dimension 1 and dimension 3 equal 3, the input is treated as
    channels-first.

    Parameters
    ----------
    images : np.ndarray or array-like
        4D batch of images with 3 channels at position 1 or 3.

    Returns
    -------
    np.ndarray
        A newly allocated array with the channel axis moved. It never aliases
        the input.

    Raises
    ------
    AssertionError
        If ``images`` is not 4D or has no 3-sized channel axis at position 1 or 3.
    """
    # Explicit raises (not `assert`) so validation survives `python -O`.
    if len(images.shape) != 4:
        raise AssertionError("`images` must be of shape [batch_size, nr_channels, height, width]. Given: {}.".format(images.shape))
    if not (images.shape[1] == 3 or images.shape[3] == 3):
        raise AssertionError(
            "`images` must have 3 colour channels at second or fourth shape position. Given: {}.".format(images.shape)
        )
    array = np.asarray(images)
    if array.shape[1] == 3:
        moved = np.moveaxis(array, 1, -1)  # channels-first -> channels-last
    else:
        moved = np.moveaxis(array, -1, 1)  # channels-last -> channels-first
    # moveaxis returns a view; copy so callers get an independent array (as the
    # old per-image loop did) and an empty batch keeps its 4D shape.
    return moved.copy()