Source code for darts.models.forecasting.tft_model

"""
Temporal Fusion Transformer (TFT)
-------
"""

from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import LSTM as _LSTM

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.components import glu_variants, layer_norm_variants
from darts.models.components.glu_variants import GLU_FFN
from darts.models.forecasting.pl_forecasting_module import (
    PLMixedCovariatesModule,
    io_processor,
)
from darts.models.forecasting.tft_submodels import (
    _GateAddNorm,
    _GatedResidualNetwork,
    _InterpretableMultiHeadAttention,
    _MultiEmbedding,
    _VariableSelectionNetwork,
    get_embedding_size,
)
from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel
from darts.utils.data import (
    MixedCovariatesSequentialDataset,
    MixedCovariatesTrainingDataset,
    TrainingDataset,
)
from darts.utils.likelihood_models import Likelihood, QuantileRegression

# module-level logger obtained through darts' logging helper
logger = get_logger(__name__)

# Type alias for a tuple of exactly six tensors.
# NOTE(review): the meaning and order of the six entries are not established in this file;
# presumably they mirror one training sample of `MixedCovariatesTrainingDataset` - confirm.
MixedCovariatesTrainTensorType = Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]


class _TFTModule(PLMixedCovariatesModule):
    def __init__(
        self,
        output_dim: Tuple[int, int],
        variables_meta: Dict[str, Dict[str, List[str]]],
        num_static_components: int,
        hidden_size: Union[int, List[int]],
        lstm_layers: int,
        num_attention_heads: int,
        full_attention: bool,
        feed_forward: str,
        hidden_continuous_size: int,
        categorical_embedding_sizes: Dict[str, Tuple[int, int]],
        dropout: float,
        add_relative_index: bool,
        norm_type: Union[str, nn.Module],
        **kwargs,
    ):
        """PyTorch module implementing the TFT architecture from `this paper <https://arxiv.org/pdf/1912.09363.pdf>`_.
        The implementation is built upon `pytorch-forecasting's TemporalFusionTransformer
        <https://pytorch-forecasting.readthedocs.io/en/latest/models.html>`_.

        The constructor only builds the sub-modules (embeddings, variable selection networks, static context
        encoders, LSTM encoder/decoder, attention, feed-forward block and output layer); the data flow between
        them is defined in `_TFTModule.forward()`.

        Parameters
        ----------
        output_dim : Tuple[int, int]
            shape of output given by (n_targets, loss_size). (loss_size corresponds to nr_params in other models).
        variables_meta : Dict[str, Dict[str, List[str]]]
            dict containing variable encoder, decoder variable names for mapping tensors in `_TFTModule.forward()`.
            The variable name lists are read from its ``"model_config"`` entry (see the properties of this class).
        num_static_components
            the number of static components (not variables) of the input target series. This is either equal to the
            number of target components or 1. It is used as the input width of the linear pre-scaler of every
            numeric static variable.
        hidden_size : int
            hidden state size of the TFT. It is the main hyper-parameter and common across the internal TFT
            architecture.
        lstm_layers : int
            number of layers for the Long Short Term Memory (LSTM) Encoder and Decoder (1 is a good default).
        num_attention_heads : int
            number of attention heads (4 is a good default)
        full_attention : bool
            If `True`, applies multi-head attention query on past (encoder) and future (decoder) parts. Otherwise,
            only queries on future part.
        feed_forward
            Set the feedforward network block. Either ``"GatedResidualNetwork"`` or the name of one of the GLU
            variants listed in ``GLU_FFN``. Any other value is rejected.
        hidden_continuous_size : int
            default for hidden size for processing continuous variables.
        categorical_embedding_sizes : dict
            A dictionary containing embedding sizes for categorical static covariates. The keys are the column names
            of the categorical static covariates. The values are tuples of integers with
            `(number of unique categories, embedding size)`. For example `{"some_column": (64, 8)}`.
            Note that `TorchForecastingModels` can only handle numeric data. Consider transforming/encoding your data
            with `darts.dataprocessing.transformers.static_covariates_transformer.StaticCovariatesTransformer`.
        dropout : float
            Fraction of neurons affected by Dropout.
        add_relative_index : bool
            Whether to add positional values to future covariates.
            This allows to use the TFTModel without having to pass future_covariates to `fit()` and `predict()`.
            It gives a value to the position of each step from input and output chunk. The values are scaled by
            `input_chunk_length - 1` (see `get_relative_index()`).
        norm_type: str | nn.Module
            The type of LayerNorm variant to use. A string is resolved as an attribute of
            `darts.models.components.layer_norm_variants`; anything else is used as given.
        **kwargs
            all parameters required for :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
            base class. Note that `likelihood` is not an explicit parameter of this module; if supplied, it is
            forwarded to the base class through these keyword arguments.

        Raises
        ------
        AttributeError
            Raised through `raise_log` if `norm_type` is a string that does not name an attribute of
            `layer_norm_variants`.
        """

        super().__init__(**kwargs)

        # keep all hyper-parameters on the instance; several are read again in `forward()`
        self.n_targets, self.loss_size = output_dim
        self.variables_meta = variables_meta
        self.num_static_components = num_static_components
        self.hidden_size = hidden_size
        self.hidden_continuous_size = hidden_continuous_size
        self.categorical_embedding_sizes = categorical_embedding_sizes
        self.lstm_layers = lstm_layers
        self.num_attention_heads = num_attention_heads
        self.full_attention = full_attention
        self.feed_forward = feed_forward
        self.dropout = dropout
        self.add_relative_index = add_relative_index

        # resolve the normalization layer: look strings up in `layer_norm_variants`, use anything else as given
        if isinstance(norm_type, str):
            try:
                self.layer_norm = getattr(layer_norm_variants, norm_type)
            except AttributeError:
                raise_log(
                    AttributeError("please provide a valid layer norm type"),
                )
        else:
            self.layer_norm = norm_type

        # initialize last batch size to check if new mask needs to be generated
        # (-1 never matches a real batch size, so the first `forward()` call always builds the mask)
        self.batch_size_last = -1
        self.attention_mask = None
        self.relative_index = None

        # general information on variable name endings:
        # _vsn: VariableSelectionNetwork
        # _grn: GatedResidualNetwork
        # _glu: GatedLinearUnit
        # _gan: GateAddNorm
        # _attn: Attention

        # # processing inputs
        # embeddings for the categorical static covariates
        self.input_embeddings = _MultiEmbedding(
            embedding_sizes=categorical_embedding_sizes,
            variable_names=self.categorical_static_variables,
        )

        # continuous variable processing: one linear pre-scaler per real-valued variable, mapping it to
        # `hidden_continuous_size`. Numeric static variables have `num_static_components` input features,
        # every other real variable has a single one.
        # NOTE(review): this is a plain dict, so the layers are not registered as sub-modules here;
        # presumably `_VariableSelectionNetwork` registers the ones it uses - confirm.
        self.prescalers_linear = {
            name: nn.Linear(
                (
                    1
                    if name not in self.numeric_static_variables
                    else self.num_static_components
                ),
                self.hidden_continuous_size,
            )
            for name in self.reals
        }

        # static (categorical and numerical) variables:
        # categorical ones enter with their embedding size, numeric ones with `hidden_continuous_size`
        static_input_sizes = {
            name: self.input_embeddings.output_size[name]
            for name in self.categorical_static_variables
        }
        static_input_sizes.update(
            {
                name: self.hidden_continuous_size
                for name in self.numeric_static_variables
            }
        )

        self.static_covariates_vsn = _VariableSelectionNetwork(
            input_sizes=static_input_sizes,
            hidden_size=self.hidden_size,
            input_embedding_flags={
                name: True for name in self.categorical_static_variables
            },
            dropout=self.dropout,
            prescalers=self.prescalers_linear,
            single_variable_grns={},
            context_size=None,  # no context for static variables
            layer_norm=self.layer_norm,
        )

        # variable selection for encoder and decoder
        # (all time-varying inputs are treated as continuous variables of size `hidden_continuous_size`)
        encoder_input_sizes = {
            name: self.hidden_continuous_size for name in self.encoder_variables
        }

        decoder_input_sizes = {
            name: self.hidden_continuous_size for name in self.decoder_variables
        }

        self.encoder_vsn = _VariableSelectionNetwork(
            input_sizes=encoder_input_sizes,
            hidden_size=self.hidden_size,
            input_embedding_flags={},  # this would be required for non-static categorical inputs
            dropout=self.dropout,
            context_size=self.hidden_size,
            prescalers=self.prescalers_linear,
            single_variable_grns={},
            layer_norm=self.layer_norm,
        )

        self.decoder_vsn = _VariableSelectionNetwork(
            input_sizes=decoder_input_sizes,
            hidden_size=self.hidden_size,
            input_embedding_flags={},  # this would be required for non-static categorical inputs
            dropout=self.dropout,
            context_size=self.hidden_size,
            prescalers=self.prescalers_linear,
            single_variable_grns={},
            layer_norm=self.layer_norm,
        )

        # static encoders: four GRNs that all consume the static embedding in `forward()`
        # for variable selection
        self.static_context_grn = _GatedResidualNetwork(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            output_size=self.hidden_size,
            dropout=self.dropout,
            layer_norm=self.layer_norm,
        )

        # for hidden state of the lstm
        self.static_context_hidden_encoder_grn = _GatedResidualNetwork(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            output_size=self.hidden_size,
            dropout=self.dropout,
            layer_norm=self.layer_norm,
        )

        # for cell state of the lstm
        self.static_context_cell_encoder_grn = _GatedResidualNetwork(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            output_size=self.hidden_size,
            dropout=self.dropout,
            layer_norm=self.layer_norm,
        )

        # for post lstm static enrichment
        self.static_context_enrichment = _GatedResidualNetwork(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            output_size=self.hidden_size,
            dropout=self.dropout,
            layer_norm=self.layer_norm,
        )

        # lstm encoder (history) and decoder (future) for local processing
        # dropout is only passed when there is more than one layer (torch's LSTM applies it between layers)
        self.lstm_encoder = _LSTM(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            num_layers=self.lstm_layers,
            dropout=self.dropout if self.lstm_layers > 1 else 0,
            batch_first=True,
        )

        self.lstm_decoder = _LSTM(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            num_layers=self.lstm_layers,
            dropout=self.dropout if self.lstm_layers > 1 else 0,
            batch_first=True,
        )

        # post lstm GateAddNorm
        self.post_lstm_gan = _GateAddNorm(
            input_size=self.hidden_size, dropout=dropout, layer_norm=self.layer_norm
        )

        # static enrichment and processing past LSTM
        # (the only GRN here that takes a context input)
        self.static_enrichment_grn = _GatedResidualNetwork(
            input_size=self.hidden_size,
            hidden_size=self.hidden_size,
            output_size=self.hidden_size,
            dropout=self.dropout,
            context_size=self.hidden_size,
            layer_norm=self.layer_norm,
        )

        # attention for long-range processing
        self.multihead_attn = _InterpretableMultiHeadAttention(
            d_model=self.hidden_size,
            n_head=self.num_attention_heads,
            dropout=self.dropout,
        )
        self.post_attn_gan = _GateAddNorm(
            self.hidden_size, dropout=self.dropout, layer_norm=self.layer_norm
        )

        # position-wise feed-forward block: the original GRN or one of the GLU variants
        if self.feed_forward == "GatedResidualNetwork":
            self.feed_forward_block = _GatedResidualNetwork(
                self.hidden_size,
                self.hidden_size,
                self.hidden_size,
                dropout=self.dropout,
                layer_norm=self.layer_norm,
            )
        else:
            # reject unknown names before looking them up in `glu_variants`
            raise_if_not(
                self.feed_forward in GLU_FFN,
                f"'{self.feed_forward}' is not in {GLU_FFN + ['GatedResidualNetwork']}",
            )
            # use glu variant feedforward layers
            # 4 is a commonly used feedforward multiplier
            self.feed_forward_block = getattr(glu_variants, self.feed_forward)(
                d_model=self.hidden_size, d_ff=self.hidden_size * 4, dropout=dropout
            )

        # output processing -> no dropout at this late stage
        self.pre_output_gan = _GateAddNorm(
            self.hidden_size, dropout=None, layer_norm=self.layer_norm
        )

        # one linear projection producing all `n_targets * loss_size` values per time step
        self.output_layer = nn.Linear(self.hidden_size, self.n_targets * self.loss_size)

        # placeholders for the attention / variable selection weights stored by the last `forward()` call
        self._attn_out_weights = None
        self._static_covariate_var = None
        self._encoder_sparse_weights = None
        self._decoder_sparse_weights = None

    @property
    def reals(self) -> List[str]:
        """
        List of all continuous variables in model
        """
        return self.variables_meta["model_config"]["reals_input"]

    @property
    def static_variables(self) -> List[str]:
        """
        List of all static variables in model
        """
        return self.variables_meta["model_config"]["static_input"]

    @property
    def numeric_static_variables(self) -> List[str]:
        """
        List of numeric static variables in model
        """
        return self.variables_meta["model_config"]["static_input_numeric"]

    @property
    def categorical_static_variables(self) -> List[str]:
        """
        List of categorical static variables in model
        """
        return self.variables_meta["model_config"]["static_input_categorical"]

    @property
    def encoder_variables(self) -> List[str]:
        """
        List of all encoder variables in model (excluding static variables)
        """
        return self.variables_meta["model_config"]["time_varying_encoder_input"]

    @property
    def decoder_variables(self) -> List[str]:
        """
        List of all decoder variables in model (excluding static variables)
        """
        return self.variables_meta["model_config"]["time_varying_decoder_input"]

    @staticmethod
    def expand_static_context(context: torch.Tensor, time_steps: int) -> torch.Tensor:
        """
        add time dimension to static context
        """
        return context[:, None].expand(-1, time_steps, -1)

    @staticmethod
    def get_relative_index(
        encoder_length: int,
        decoder_length: int,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Returns scaled time index relative to prediction point.
        """
        index = torch.arange(
            encoder_length + decoder_length, dtype=dtype, device=device
        )
        prediction_index = encoder_length - 1
        index[:encoder_length] = index[:encoder_length] / prediction_index
        index[encoder_length:] = index[encoder_length:] / prediction_index
        return index.reshape(1, len(index), 1).repeat(batch_size, 1, 1)

    @staticmethod
    def get_attention_mask_full(
        time_steps: int, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        """
        Returns causal mask to apply for self-attention layer.
        """
        eye = torch.eye(time_steps, dtype=dtype, device=device)
        mask = torch.cumsum(eye.unsqueeze(0).repeat(batch_size, 1, 1), dim=1)
        return mask < 1

    @staticmethod
    def get_attention_mask_future(
        encoder_length: int,
        decoder_length: int,
        batch_size: int,
        device: str,
        full_attention: bool,
    ) -> torch.Tensor:
        """
        Returns causal mask to apply for self-attention layer that acts on future input only.
        The model will attend to all `False` values.
        """
        if full_attention:
            # attend to entire past and future input
            decoder_mask = torch.zeros(
                (decoder_length, decoder_length), dtype=torch.bool, device=device
            )
        else:
            # attend only to past steps relative to forecasting step in the future
            # indices to which is attended
            attend_step = torch.arange(decoder_length, device=device)
            # indices for which is predicted
            predict_step = torch.arange(0, decoder_length, device=device)[:, None]
            # do not attend to steps to self or after prediction
            decoder_mask = attend_step >= predict_step
        # attend to all past input
        encoder_mask = torch.zeros(
            batch_size, encoder_length, dtype=torch.bool, device=device
        )
        # combine masks along attended time - first encoder and then decoder

        mask = torch.cat(
            (
                encoder_mask.unsqueeze(1).expand(-1, decoder_length, -1),
                decoder_mask.unsqueeze(0).expand(batch_size, -1, -1),
            ),
            dim=2,
        )
        return mask

    @io_processor
    def forward(
        self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
    ) -> torch.Tensor:
        """TFT model forward pass.

        Parameters
        ----------
        x_in
            comes as tuple `(x_past, x_future, x_static)` where `x_past` is the input/past chunk and `x_future`
            is the output/future chunk. Input dimensions are `(n_samples, n_time_steps, n_variables)`.
            `x_static` is only read if the model has static variables; it is indexed as a 3-D tensor whose last
            axis enumerates the static variables.

        Returns
        -------
        torch.Tensor
            the output tensor of shape `(n_samples, output_chunk_length, n_targets, loss_size)`.

        Notes
        -----
        The attention weights and the variable selection weights of the call are stored on the module
        (`_attn_out_weights`, `_static_covariate_var`, `_encoder_sparse_weights`, `_decoder_sparse_weights`).
        """
        x_cont_past, x_cont_future, x_static = x_in
        dim_samples, dim_time, dim_variable = 0, 1, 2
        device = x_in[0].device

        batch_size = x_cont_past.shape[dim_samples]
        encoder_length = self.input_chunk_length
        decoder_length = self.output_chunk_length
        time_steps = encoder_length + decoder_length

        # avoid unnecessary regeneration of attention mask (and relative index).
        # The cached tensors are plain attributes, not buffers, so they do not follow the module when it is
        # moved to another device. They are therefore rebuilt not only when the batch size changes but also
        # when they live on a different device (or, for the relative index, have a different dtype) than
        # the current input.
        cache_is_stale = (
            batch_size != self.batch_size_last
            or self.attention_mask is None
            or self.attention_mask.device != device
        )
        if self.add_relative_index and not cache_is_stale:
            cache_is_stale = (
                self.relative_index is None
                or self.relative_index.device != device
                or self.relative_index.dtype != x_cont_past.dtype
            )

        if cache_is_stale:
            self.attention_mask = self.get_attention_mask_future(
                encoder_length=encoder_length,
                decoder_length=decoder_length,
                batch_size=batch_size,
                device=device,
                full_attention=self.full_attention,
            )
            if self.add_relative_index:
                self.relative_index = self.get_relative_index(
                    encoder_length=encoder_length,
                    decoder_length=decoder_length,
                    batch_size=batch_size,
                    device=device,
                    dtype=x_cont_past.dtype,
                )

            self.batch_size_last = batch_size

        if self.add_relative_index:
            # append the relative index as an additional variable; the first `encoder_length` steps go to the
            # past input, the last `decoder_length` steps to the future input (which may be `None` so far)
            x_cont_past = torch.cat(
                [
                    ts[:, :encoder_length, :]
                    for ts in [x_cont_past, self.relative_index]
                    if ts is not None
                ],
                dim=dim_variable,
            )
            x_cont_future = torch.cat(
                [
                    ts[:, -decoder_length:, :]
                    for ts in [x_cont_future, self.relative_index]
                    if ts is not None
                ],
                dim=dim_variable,
            )

        # split the stacked inputs into one `(..., 1)` tensor per named variable
        input_vectors_past = {
            name: x_cont_past[..., idx].unsqueeze(-1)
            for idx, name in enumerate(self.encoder_variables)
        }
        input_vectors_future = {
            name: x_cont_future[..., idx].unsqueeze(-1)
            for idx, name in enumerate(self.decoder_variables)
        }

        # Embedding and variable selection
        if self.static_variables:
            # categorical static covariate embeddings
            if self.categorical_static_variables:
                static_embedding = self.input_embeddings(
                    torch.cat(
                        [
                            x_static[:, :, idx]
                            for idx, name in enumerate(self.static_variables)
                            if name in self.categorical_static_variables
                        ],
                        dim=1,
                    ).int()
                )
            else:
                static_embedding = {}
            # add numerical static covariates
            static_embedding.update(
                {
                    name: x_static[:, :, idx]
                    for idx, name in enumerate(self.static_variables)
                    if name in self.numeric_static_variables
                }
            )
            static_embedding, static_covariate_var = self.static_covariates_vsn(
                static_embedding
            )
        else:
            # without static variables an all-zero static embedding is used
            static_embedding = torch.zeros(
                (x_cont_past.shape[0], self.hidden_size),
                dtype=x_cont_past.dtype,
                device=device,
            )
            static_covariate_var = None

        # static context for the variable selection, broadcast over all time steps
        static_context_expanded = self.expand_static_context(
            context=self.static_context_grn(static_embedding), time_steps=time_steps
        )

        # variable selection on the past (encoder) and future (decoder) inputs
        embeddings_varying_encoder, encoder_sparse_weights = self.encoder_vsn(
            x=input_vectors_past,
            context=static_context_expanded[:, :encoder_length],
        )
        embeddings_varying_decoder, decoder_sparse_weights = self.decoder_vsn(
            x=input_vectors_future,
            context=static_context_expanded[:, encoder_length:],
        )

        # LSTM
        # calculate initial state: the static context is shared by all LSTM layers
        input_hidden = (
            self.static_context_hidden_encoder_grn(static_embedding)
            .expand(self.lstm_layers, -1, -1)
            .contiguous()
        )
        input_cell = (
            self.static_context_cell_encoder_grn(static_embedding)
            .expand(self.lstm_layers, -1, -1)
            .contiguous()
        )

        # run local lstm encoder
        encoder_out, (hidden, cell) = self.lstm_encoder(
            input=embeddings_varying_encoder, hx=(input_hidden, input_cell)
        )

        # run local lstm decoder, starting from the final encoder state
        decoder_out, _ = self.lstm_decoder(
            input=embeddings_varying_decoder, hx=(hidden, cell)
        )

        lstm_layer = torch.cat([encoder_out, decoder_out], dim=dim_time)
        input_embeddings = torch.cat(
            [embeddings_varying_encoder, embeddings_varying_decoder], dim=dim_time
        )

        # post lstm GateAddNorm
        lstm_out = self.post_lstm_gan(x=lstm_layer, skip=input_embeddings)

        # static enrichment
        static_context_enriched = self.static_context_enrichment(static_embedding)
        attn_input = self.static_enrichment_grn(
            x=lstm_out,
            context=self.expand_static_context(
                context=static_context_enriched, time_steps=time_steps
            ),
        )

        # multi-head attention: queries are the future steps only, keys/values span past and future
        attn_out, attn_out_weights = self.multihead_attn(
            q=attn_input[:, encoder_length:],
            k=attn_input,
            v=attn_input,
            mask=self.attention_mask,
        )

        # skip connection over attention
        attn_out = self.post_attn_gan(
            x=attn_out,
            skip=attn_input[:, encoder_length:],
        )

        # feed-forward
        out = self.feed_forward_block(x=attn_out)

        # skip connection over temporal fusion decoder from LSTM post _GateAddNorm
        out = self.pre_output_gan(
            x=out,
            skip=lstm_out[:, encoder_length:],
        )

        # generate output for n_targets and loss_size elements for loss evaluation
        out = self.output_layer(out)
        out = out.view(
            batch_size, self.output_chunk_length, self.n_targets, self.loss_size
        )

        # keep the weights of this call on the module
        self._attn_out_weights = attn_out_weights
        self._static_covariate_var = static_covariate_var
        self._encoder_sparse_weights = encoder_sparse_weights
        self._decoder_sparse_weights = decoder_sparse_weights
        return out


[docs]class TFTModel(MixedCovariatesTorchModel): def __init__( self, input_chunk_length: int, output_chunk_length: int, output_chunk_shift: int = 0, hidden_size: Union[int, List[int]] = 16, lstm_layers: int = 1, num_attention_heads: int = 4, full_attention: bool = False, feed_forward: str = "GatedResidualNetwork", dropout: float = 0.1, hidden_continuous_size: int = 8, categorical_embedding_sizes: Optional[ Dict[str, Union[int, Tuple[int, int]]] ] = None, add_relative_index: bool = False, loss_fn: Optional[nn.Module] = None, likelihood: Optional[Likelihood] = None, norm_type: Union[str, nn.Module] = "LayerNorm", use_static_covariates: bool = True, **kwargs, ): """Temporal Fusion Transformers (TFT) for Interpretable Time Series Forecasting. This is an implementation of the TFT architecture, as outlined in [1]_. The internal sub models are adopted from `pytorch-forecasting's TemporalFusionTransformer <https://pytorch-forecasting.readthedocs.io/en/latest/models.html>`_ implementation. This model supports past covariates (known for `input_chunk_length` points before prediction time), future covariates (known for `output_chunk_length` points after prediction time), static covariates, as well as probabilistic forecasting. The TFT applies multi-head attention queries on future inputs from mandatory ``future_covariates``. Specifying future encoders with ``add_encoders`` (read below) can automatically generate future covariates and allows to use the model without having to pass any ``future_covariates`` to :func:`fit()` and :func:`predict()`. By default, this model uses the ``QuantileRegression`` likelihood, which means that its forecasts are probabilistic; it is recommended to call :func`predict()` with ``num_samples >> 1`` to get meaningful results. Parameters ---------- input_chunk_length Number of time steps in the past to take as a model input (per chunk). Applies to the target series, and past and/or future covariates (if the model supports it). 
Also called: Encoder length output_chunk_length Number of time steps predicted at once (per chunk) by the internal model. Also, the number of future values from future covariates to use as a model input (if the model supports future covariates). It is not the same as forecast horizon `n` used in `predict()`, which is the desired number of prediction points generated using either a one-shot- or autoregressive forecast. Setting `n <= output_chunk_length` prevents auto-regression. This is useful when the covariates don't extend far enough into the future, or to prohibit the model from using future values of past and / or future covariates for prediction (depending on the model's covariate support). Also called: Decoder length output_chunk_shift Optionally, the number of steps to shift the start of the output chunk into the future (relative to the input chunk end). This will create a gap between the input and output. If the model supports `future_covariates`, the future values are extracted from the shifted output chunk. Predictions will start `output_chunk_shift` steps after the end of the target `series`. If `output_chunk_shift` is set, the model cannot generate autoregressive predictions (`n > output_chunk_length`). hidden_size Hidden state size of the TFT. It is the main hyper-parameter and common across the internal TFT architecture. lstm_layers Number of layers for the Long Short Term Memory (LSTM) Encoder and Decoder (1 is a good default). num_attention_heads Number of attention heads (4 is a good default) full_attention If ``False``, only attends to previous time steps in the decoder. If ``True`` attends to previous, current, and future time steps. Defaults to ``False``. feed_forward A feedforward network is a fully-connected layer with an activation. TFT Can be one of the glu variant's FeedForward Network (FFN)[2]. The glu variant's FeedForward Network are a series of FFNs designed to work better with Transformer based models. 
Defaults to ``"GatedResidualNetwork"``. ["GLU", "Bilinear", "ReGLU", "GEGLU", "SwiGLU", "ReLU", "GELU"] or the TFT original FeedForward Network ["GatedResidualNetwork"]. dropout Fraction of neurons affected by dropout. This is compatible with Monte Carlo dropout at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at prediction time). hidden_continuous_size Default for hidden size for processing continuous variables categorical_embedding_sizes A dictionary used to construct embeddings for categorical static covariates. The keys are the column names of the categorical static covariates. Each value is either a single integer or a tuple of integers. For a single integer give the number of unique categories (n) of the corresponding variable. For example ``{"some_column": 64}``. The embedding size will be automatically determined by ``min(round(1.6 * n**0.56), 100)``. For a tuple of integers, give (number of unique categories, embedding size). For example ``{"some_column": (64, 8)}``. Note that ``TorchForecastingModels`` only support numeric data. Consider transforming/encoding your data with `darts.dataprocessing.transformers.static_covariates_transformer.StaticCovariatesTransformer`. add_relative_index Whether to add positional values to future covariates. Defaults to ``False``. This allows to use the TFTModel without having to pass future_covariates to :func:`fit()` and :func:`train()`. It gives a value to the position of each step from input and output chunk relative to the prediction point. The values are normalized with ``input_chunk_length``. loss_fn: nn.Module PyTorch loss function used for training. By default, the TFT model is probabilistic and uses a ``likelihood`` instead (``QuantileRegression``). To make the model deterministic, you can set the ` `likelihood`` to None and give a ``loss_fn`` argument. likelihood The likelihood model to be used for probabilistic forecasts. 
By default, the TFT uses a ``QuantileRegression`` likelihood. norm_type: str | nn.Module The type of LayerNorm variant to use. Default: ``LayerNorm``. Available options are ["LayerNorm", "RMSNorm", "LayerNormNoBias"], or provide a custom nn.Module. use_static_covariates Whether the model should use static covariate information in case the input `series` passed to ``fit()`` contain static covariates. If ``True``, and static covariates are available at fitting time, will enforce that all target `series` have the same static covariate dimensionality in ``fit()`` and ``predict()``. **kwargs Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and Darts' :class:`TorchForecastingModel`. torch_metrics A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``. optimizer_cls The PyTorch optimizer class to be used. Default: ``torch.optim.Adam``. optimizer_kwargs Optionally, some keyword arguments for the PyTorch optimizer (e.g., ``{'lr': 1e-3}`` for specifying a learning rate). Otherwise, the default values of the selected ``optimizer_cls`` will be used. Default: ``None``. lr_scheduler_cls Optionally, the PyTorch learning rate scheduler class to be used. Specifying ``None`` corresponds to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. use_reversible_instance_norm Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [3]_. It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. 
Used for creating checkpoints and saving tensorboard data. If not specified, defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the process ID (preventing models spawned at the same time by different processes from sharing the same model_name). E.g., ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. log_tensorboard If set, use Tensorboard to log the different parameters. The logs will be located in: ``"{work_dir}/darts_logs/{model_name}/logs/"``. Default: ``False``. nr_epochs_val_period Number of epochs to wait before evaluating the validation loss (if a validation ``TimeSeries`` is passed to the :func:`fit()` method). Default: ``1``. force_reset If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will be discarded). Default: ``False``. save_checkpoints Whether to automatically save the untrained model and checkpoints from training. To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where :class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`, :class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using :func:`save()` and loaded using :func:`load()`. Default: ``False``. add_encoders A large number of past and future covariates can be automatically generated with `add_encoders`. This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to transform the generated covariates. This all happens under one hood and only needs to be specified at model creation. 
Read :meth:`SequentialEncoder <darts.dataprocessing.encoders.SequentialEncoder>` to find out more about ``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features: .. highlight:: python .. code-block:: python def encode_year(idx): return (idx.year - 1950) / 50 add_encoders={ 'cyclic': {'future': ['month']}, 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, 'transformer': Scaler(), 'tz': 'CET' } .. random_state Control the randomness of the weight's initialization. Check this `link <https://scikit-learn.org/stable/glossary.html#term-random_state>`_ for more details. Default: ``None``. pl_trainer_kwargs By default :class:`TorchForecastingModel` creates a PyTorch Lightning Trainer with several useful presets that performs the training, validation and prediction processes. These presets include automatic checkpointing, tensorboard logging, setting the torch device and more. With ``pl_trainer_kwargs`` you can add additional kwargs to instantiate the PyTorch Lightning trainer object. Check the `PL Trainer documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`_ for more information about the supported kwargs. Default: ``None``. Running on GPU(s) is also possible using ``pl_trainer_kwargs`` by specifying keys ``"accelerator", "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. 
For more info, see here: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags , and https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu_basic.html#train-on-multiple-gpus With parameter ``"callbacks"`` you can add custom or PyTorch-Lightning built-in callbacks to Darts' :class:`TorchForecastingModel`. Below is an example for adding EarlyStopping to the training process. The model will stop training early if the validation loss `val_loss` does not improve beyond specifications. For more information on callbacks, visit: `PyTorch Lightning Callbacks <https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html>`_ .. highlight:: python .. code-block:: python from pytorch_lightning.callbacks.early_stopping import EarlyStopping # stop training when validation loss does not decrease more than 0.05 (`min_delta`) over # a period of 5 epochs (`patience`) my_stopper = EarlyStopping( monitor="val_loss", patience=5, min_delta=0.05, mode='min', ) pl_trainer_kwargs={"callbacks": [my_stopper]} .. Note that you can also use a custom PyTorch Lightning Trainer for training and prediction with optional parameter ``trainer`` in :func:`fit()` and :func:`predict()`. show_warnings Whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of your forecasting use case. Default: ``False``. References ---------- .. [1] https://arxiv.org/pdf/1912.09363.pdf .. [2] Shazeer, Noam, "GLU Variants Improve Transformer", 2020. arXiv https://arxiv.org/abs/2002.05202. .. [3] T. Kim et al. 
"Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p Examples -------- >>> from darts.datasets import WeatherDataset >>> from darts.models import TFTModel >>> series = WeatherDataset().load() >>> # predicting atmospheric pressure >>> target = series['p (mbar)'][:100] >>> # optionally, past observed rainfall (pretending to be unknown beyond index 100) >>> past_cov = series['rain (mm)'][:100] >>> # future temperatures (pretending this component is a forecast) >>> future_cov = series['T (degC)'][:106] >>> # by default, TFTModel is trained using a `QuantileRegression` making it a probabilistic forecasting model >>> model = TFTModel( >>> input_chunk_length=6, >>> output_chunk_length=6, >>> n_epochs=5, >>> ) >>> # future_covariates are mandatory for `TFTModel` >>> model.fit(target, past_covariates=past_cov, future_covariates=future_cov) >>> # TFTModel is probabilistic by definition; using `num_samples >> 1` to generate probabilistic forecasts >>> pred = model.predict(6, num_samples=100) >>> # shape : (forecast horizon, components, num_samples) >>> pred.all_values().shape (6, 1, 100) >>> # showing the first 3 samples for each timestamp >>> pred.all_values()[:,:,:3] array([[[-0.06414202, -0.7188093 , 0.52541292]], [[ 0.02928407, -0.40867163, 1.19650033]], [[ 0.77252372, -0.50859694, 0.360166 ]], [[ 0.9586113 , 1.24147138, -0.01625545]], [[ 1.06863863, 0.2987822 , -0.69213369]], [[-0.83076568, -0.25780816, -0.28318784]]]) .. note:: `TFT example notebook <https://unit8co.github.io/darts/examples/13-TFT-examples.html>`_ presents techniques that can be used to improve the forecasts quality compared to this simple usage example. 
""" model_kwargs = {key: val for key, val in self.model_params.items()} if likelihood is None and loss_fn is None: # This is the default if no loss information is provided model_kwargs["loss_fn"] = None model_kwargs["likelihood"] = QuantileRegression() super().__init__(**self._extract_torch_model_params(**model_kwargs)) # extract pytorch lightning module kwargs self.pl_module_params = self._extract_pl_module_params(**model_kwargs) self.hidden_size = hidden_size self.lstm_layers = lstm_layers self.num_attention_heads = num_attention_heads self.full_attention = full_attention self.feed_forward = feed_forward self.dropout = dropout self.hidden_continuous_size = hidden_continuous_size self.categorical_embedding_sizes = ( categorical_embedding_sizes if categorical_embedding_sizes is not None else {} ) self.add_relative_index = add_relative_index self.output_dim: Optional[Tuple[int, int]] = None self.norm_type = norm_type self._considers_static_covariates = use_static_covariates def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Module: """ `train_sample` contains the following tensors: (past_target, past_covariates, historic_future_covariates, future_covariates, static_covariates, future_target) each tensor has shape (n_timesteps, n_variables) - past/historic tensors have shape (input_chunk_length, n_variables) - future tensors have shape (output_chunk_length, n_variables) - static covariates have shape (component, static variable) Darts Interpretation of pytorch-forecasting's TimeSeriesDataSet: time_varying_knowns : future_covariates (including historic_future_covariates) time_varying_unknowns : past_targets, past_covariates time_varying_encoders : [past_targets, past_covariates, historic_future_covariates, future_covariates] time_varying_decoders : [historic_future_covariates, future_covariates] `variable_meta` is used in TFT to access specific variables """ ( past_target, past_covariate, historic_future_covariate, future_covariate, 
static_covariates, future_target, ) = train_sample # add a covariate placeholder so that relative index will be included if self.add_relative_index: time_steps = self.input_chunk_length + self.output_chunk_length expand_future_covariate = np.arange(time_steps).reshape((time_steps, 1)) historic_future_covariate = np.concatenate( [ ts[: self.input_chunk_length] for ts in [historic_future_covariate, expand_future_covariate] if ts is not None ], axis=1, ) future_covariate = np.concatenate( [ ts[-self.output_chunk_length :] for ts in [future_covariate, expand_future_covariate] if ts is not None ], axis=1, ) self.output_dim = ( (future_target.shape[1], 1) if self.likelihood is None else (future_target.shape[1], self.likelihood.num_parameters) ) tensors = [ past_target, past_covariate, historic_future_covariate, # for time varying encoders future_covariate, future_target, # for time varying decoders static_covariates, # for static encoder ] type_names = [ "past_target", "past_covariate", "historic_future_covariate", "future_covariate", "future_target", "static_covariate", ] variable_names = [ "target", "past_covariate", "future_covariate", "future_covariate", "target", "static_covariate", ] variables_meta = { "input": { type_name: [f"{var_name}_{i}" for i in range(tensor.shape[1])] for type_name, var_name, tensor in zip( type_names, variable_names, tensors ) if tensor is not None }, "model_config": {}, } reals_input = [] categorical_input = [] time_varying_encoder_input = [] time_varying_decoder_input = [] static_input = [] static_input_numeric = [] static_input_categorical = [] categorical_embedding_sizes = {} for input_var in type_names: if input_var in variables_meta["input"]: vars_meta = variables_meta["input"][input_var] if input_var in [ "past_target", "past_covariate", "historic_future_covariate", ]: time_varying_encoder_input += vars_meta reals_input += vars_meta elif input_var in ["future_covariate"]: time_varying_decoder_input += vars_meta reals_input += 
vars_meta elif input_var in ["static_covariate"]: if ( self.static_covariates is None ): # when training with fit_from_dataset static_cols = pd.Index( [i for i in range(static_covariates.shape[1])] ) else: static_cols = self.static_covariates.columns numeric_mask = ~static_cols.isin(self.categorical_embedding_sizes) for idx, (static_var, col_name, is_numeric) in enumerate( zip(vars_meta, static_cols, numeric_mask) ): static_input.append(static_var) if is_numeric: static_input_numeric.append(static_var) reals_input.append(static_var) else: # get embedding sizes for each categorical variable embedding = self.categorical_embedding_sizes[col_name] raise_if_not( isinstance(embedding, (int, tuple)), "Dict values of `categorical_embedding_sizes` must either be integers or tuples. Read " "the TFTModel documentation for more information.", logger, ) if isinstance(embedding, int): embedding = (embedding, get_embedding_size(n=embedding)) categorical_embedding_sizes[vars_meta[idx]] = embedding static_input_categorical.append(static_var) categorical_input.append(static_var) variables_meta["model_config"]["reals_input"] = list(dict.fromkeys(reals_input)) variables_meta["model_config"]["categorical_input"] = list( dict.fromkeys(categorical_input) ) variables_meta["model_config"]["time_varying_encoder_input"] = list( dict.fromkeys(time_varying_encoder_input) ) variables_meta["model_config"]["time_varying_decoder_input"] = list( dict.fromkeys(time_varying_decoder_input) ) variables_meta["model_config"]["static_input"] = list( dict.fromkeys(static_input) ) variables_meta["model_config"]["static_input_numeric"] = list( dict.fromkeys(static_input_numeric) ) variables_meta["model_config"]["static_input_categorical"] = list( dict.fromkeys(static_input_categorical) ) n_static_components = ( len(static_covariates) if static_covariates is not None else 0 ) self.categorical_embedding_sizes = categorical_embedding_sizes return _TFTModule( output_dim=self.output_dim, 
variables_meta=variables_meta, num_static_components=n_static_components, hidden_size=self.hidden_size, lstm_layers=self.lstm_layers, dropout=self.dropout, num_attention_heads=self.num_attention_heads, full_attention=self.full_attention, feed_forward=self.feed_forward, hidden_continuous_size=self.hidden_continuous_size, categorical_embedding_sizes=self.categorical_embedding_sizes, add_relative_index=self.add_relative_index, norm_type=self.norm_type, **self.pl_module_params, ) def _build_train_dataset( self, target: Sequence[TimeSeries], past_covariates: Optional[Sequence[TimeSeries]], future_covariates: Optional[Sequence[TimeSeries]], max_samples_per_ts: Optional[int], ) -> MixedCovariatesSequentialDataset: raise_if( future_covariates is None and not self.add_relative_index, "TFTModel requires future covariates. The model applies multi-head attention queries on future " "inputs. Consider specifying a future encoder with `add_encoders` or setting `add_relative_index` " "to `True` at model creation (read TFT model docs for more information). " "These will automatically generate `future_covariates` from indexes.", logger, ) return MixedCovariatesSequentialDataset( target_series=target, past_covariates=past_covariates, future_covariates=future_covariates, input_chunk_length=self.input_chunk_length, output_chunk_length=self.output_chunk_length, output_chunk_shift=self.output_chunk_shift, max_samples_per_ts=max_samples_per_ts, use_static_covariates=self.uses_static_covariates, ) def _verify_train_dataset_type(self, train_dataset: TrainingDataset): raise_if_not( isinstance(train_dataset, MixedCovariatesTrainingDataset), "TFTModel requires a training dataset of type MixedCovariatesTrainingDataset.", ) @property def supports_multivariate(self) -> bool: return True @property def supports_static_covariates(self) -> bool: return True