# Source code for darts.models.components.glu_variants

import torch
from torch import nn

from darts.models.components.feed_forward import FeedForward

# Registry of the feed-forward variant class names defined in this module;
# each entry matches a class below so a variant can be selected by name.
GLU_FFN = ["GLU", "Bilinear", "ReGLU", "GEGLU", "SwiGLU", "ReLU", "GELU"]


# GLU Variants Improve Transformer
# Shazeer, Noam, "GLU Variants Improve Transformer", 2020. arXiv https://arxiv.org/abs/2002.05202
class GLU(nn.Module):
    """Gated Linear Unit feed-forward variant (Shazeer, 2020).

    Delegates to a gated ``FeedForward`` whose gate activation is sigmoid.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # NOTE(review): the four trailing booleans look like FeedForward's
        # gating/bias flags — confirm against feed_forward.FeedForward.
        self.ffn = FeedForward(
            d_model, d_ff, dropout, nn.Sigmoid(), True, False, False, False
        )

    def forward(self, x: torch.Tensor):
        """Apply the wrapped feed-forward network to ``x``."""
        return self.ffn(x)
class Bilinear(nn.Module):
    """Bilinear feed-forward variant (Shazeer, 2020).

    A gated ``FeedForward`` whose gate activation is the identity,
    i.e. a plain bilinear gate.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        gate_activation = nn.Identity()
        # NOTE(review): the four trailing booleans look like FeedForward's
        # gating/bias flags — confirm against feed_forward.FeedForward.
        self.ffn = FeedForward(
            d_model, d_ff, dropout, gate_activation, True, False, False, False
        )

    def forward(self, x: torch.Tensor):
        """Apply the wrapped feed-forward network to ``x``."""
        return self.ffn(x)
class ReGLU(nn.Module):
    """ReGLU feed-forward variant (Shazeer, 2020).

    Delegates to a gated ``FeedForward`` whose gate activation is ReLU.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # NOTE(review): the four trailing booleans look like FeedForward's
        # gating/bias flags — confirm against feed_forward.FeedForward.
        self.ffn = FeedForward(
            d_model, d_ff, dropout, nn.ReLU(), True, False, False, False
        )

    def forward(self, x: torch.Tensor):
        """Apply the wrapped feed-forward network to ``x``."""
        return self.ffn(x)
class GEGLU(nn.Module):
    """GEGLU feed-forward variant (Shazeer, 2020).

    Delegates to a gated ``FeedForward`` whose gate activation is GELU.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        gate_activation = nn.GELU()
        # NOTE(review): the four trailing booleans look like FeedForward's
        # gating/bias flags — confirm against feed_forward.FeedForward.
        self.ffn = FeedForward(
            d_model, d_ff, dropout, gate_activation, True, False, False, False
        )

    def forward(self, x: torch.Tensor):
        """Apply the wrapped feed-forward network to ``x``."""
        return self.ffn(x)
class SwiGLU(nn.Module):
    """SwiGLU feed-forward variant (Shazeer, 2020).

    Delegates to a gated ``FeedForward`` whose gate activation is
    SiLU (a.k.a. Swish).
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # NOTE(review): the four trailing booleans look like FeedForward's
        # gating/bias flags — confirm against feed_forward.FeedForward.
        self.ffn = FeedForward(
            d_model, d_ff, dropout, nn.SiLU(), True, False, False, False
        )

    def forward(self, x: torch.Tensor):
        """Apply the wrapped feed-forward network to ``x``."""
        return self.ffn(x)
class ReLU(nn.Module):
    """Plain (non-gated) feed-forward with ReLU activation.

    Baseline variant from Shazeer (2020): a standard ``FeedForward``
    block, no gating.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ffn = FeedForward(d_model, d_ff, dropout, nn.ReLU())

    def forward(self, x: torch.Tensor):
        """Apply the wrapped feed-forward network to ``x``."""
        return self.ffn(x)
class GELU(nn.Module):
    """Plain (non-gated) feed-forward with GELU activation.

    Baseline variant from Shazeer (2020): a standard ``FeedForward``
    block, no gating.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        activation = nn.GELU()
        self.ffn = FeedForward(d_model, d_ff, dropout, activation)

    def forward(self, x: torch.Tensor):
        """Apply the wrapped feed-forward network to ``x``."""
        return self.ffn(x)