"""
Time-series Dense Encoder (TiDE)
--------------------------------
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn

from darts.logging import get_logger, raise_log
from darts.models.forecasting.pl_forecasting_module import (
    PLMixedCovariatesModule,
    io_processor,
)
from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel
from darts.utils.torch import MonteCarloDropout

MixedCovariatesTrainTensorType = Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]
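# The six tensors in a train sample are, in order: past_target, past_covariates,
# historic_future_covariates, future_covariates, static_covariates, future_target
# (matching the unpacking in `TiDEModel._create_model`).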


logger = get_logger(__name__)


class _ResidualBlock(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_size: int,
        dropout: float,
        use_layer_norm: bool,
    ):
        """Pytorch module implementing the Residual Block from the TiDE paper."""
        super().__init__()

        # dense layer with ReLU activation with dropout
        self.dense = nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_dim),
            MonteCarloDropout(dropout),
        )

        # linear skip connection from input to output of self.dense
        self.skip = nn.Linear(input_dim, output_dim)

        # layer normalization as output
        if use_layer_norm:
            self.layer_norm = nn.LayerNorm(output_dim)
        else:
            self.layer_norm = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # residual connection
        x = self.dense(x) + self.skip(x)

        # layer normalization
        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return x
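
# A minimal shape sketch for `_ResidualBlock` (illustrative only; the sizes below are
# hypothetical, not part of the library API):
#
#     block = _ResidualBlock(
#         input_dim=8, output_dim=4, hidden_size=16, dropout=0.1, use_layer_norm=True
#     )
#     out = block(torch.rand(32, 8))  # -> torch.Size([32, 4])
#
# The linear skip path lets the block fall back to a purely linear mapping when the
# dense path contributes little, which stabilizes deep encoder/decoder stacks.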


class _TideModule(PLMixedCovariatesModule):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        future_cov_dim: int,
        static_cov_dim: int,
        nr_params: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        decoder_output_dim: int,
        hidden_size: int,
        temporal_decoder_hidden: int,
        temporal_width_past: int,
        temporal_width_future: int,
        use_layer_norm: bool,
        dropout: float,
        temporal_hidden_size_past: Optional[int] = None,
        temporal_hidden_size_future: Optional[int] = None,
        **kwargs,
    ):
        """Pytorch module implementing the TiDE architecture.

        Parameters
        ----------
        input_dim
            The number of input components (target + optional past covariates + optional future covariates).
        output_dim
            Number of output components in the target.
        future_cov_dim
            Number of future covariates.
        static_cov_dim
            Number of static covariates.
        nr_params
            The number of parameters of the likelihood (or 1 if no likelihood is used).
        num_encoder_layers
            Number of stacked Residual Blocks in the encoder.
        num_decoder_layers
            Number of stacked Residual Blocks in the decoder.
        decoder_output_dim
            The number of output components of the decoder.
        hidden_size
            The width of the hidden layers in the encoder/decoder Residual Blocks.
        temporal_decoder_hidden
            The width of the hidden layers in the temporal decoder.
        temporal_width_past
            The width of the past covariate embedding space.
        temporal_width_future
            The width of the future covariate embedding space.
        temporal_hidden_size_past
            The width of the hidden layers in the past covariate projection Residual Block. If `None`,
            defaults to `hidden_size`.
        temporal_hidden_size_future
            The width of the hidden layers in the future covariate projection Residual Block. If `None`,
            defaults to `hidden_size`.
        use_layer_norm
            Whether to use layer normalization in the Residual Blocks.
        dropout
            Dropout probability.
        **kwargs
            all parameters required for :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
            base class.

        Inputs
        ------
        x
            Tuple of Tensors `(x_past, x_future, x_static)` where `x_past` is the input/past chunk, `x_future`
            is the output/future chunk, and `x_static` are the static covariates. Input dimensions are
            `(batch_size, time_steps, components)`.

        Outputs
        -------
        y
            Tensor of shape `(batch_size, output_chunk_length, output_dim, nr_params)`

        """

        super().__init__(**kwargs)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.past_cov_dim = input_dim - output_dim - future_cov_dim
        self.future_cov_dim = future_cov_dim
        self.static_cov_dim = static_cov_dim
        self.nr_params = nr_params
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.decoder_output_dim = decoder_output_dim
        self.hidden_size = hidden_size
        self.temporal_decoder_hidden = temporal_decoder_hidden
        self.use_layer_norm = use_layer_norm
        self.dropout = dropout
        self.temporal_width_past = temporal_width_past
        self.temporal_width_future = temporal_width_future
        self.temporal_hidden_size_past = temporal_hidden_size_past or hidden_size
        self.temporal_hidden_size_future = temporal_hidden_size_future or hidden_size

        # past covariates handling: either feature projection, raw features, or no features
        self.past_cov_projection = None
        if self.past_cov_dim and temporal_width_past:
            # residual block for past covariates feature projection
            self.past_cov_projection = _ResidualBlock(
                input_dim=self.past_cov_dim,
                output_dim=temporal_width_past,
                hidden_size=self.temporal_hidden_size_past,
                use_layer_norm=use_layer_norm,
                dropout=dropout,
            )
            past_covariates_flat_dim = self.input_chunk_length * temporal_width_past
        elif self.past_cov_dim:
            # skip projection and use raw features
            past_covariates_flat_dim = self.input_chunk_length * self.past_cov_dim
        else:
            past_covariates_flat_dim = 0

        # future covariates handling: either feature projection, raw features, or no features
        self.future_cov_projection = None
        if future_cov_dim and self.temporal_width_future:
            # residual block for future covariates feature projection
            self.future_cov_projection = _ResidualBlock(
                input_dim=future_cov_dim,
                output_dim=temporal_width_future,
                hidden_size=self.temporal_hidden_size_future,
                use_layer_norm=use_layer_norm,
                dropout=dropout,
            )
            historical_future_covariates_flat_dim = (
                self.input_chunk_length + self.output_chunk_length
            ) * temporal_width_future
        elif future_cov_dim:
            # skip projection and use raw features
            historical_future_covariates_flat_dim = (
                self.input_chunk_length + self.output_chunk_length
            ) * future_cov_dim
        else:
            historical_future_covariates_flat_dim = 0

        encoder_dim = (
            self.input_chunk_length * output_dim
            + past_covariates_flat_dim
            + historical_future_covariates_flat_dim
            + static_cov_dim
        )
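        # e.g. (hypothetical sizes): with input_chunk_length=24, output_chunk_length=12,
        # output_dim=1, past covariates projected to temporal_width_past=4, future
        # covariates projected to temporal_width_future=4, and static_cov_dim=2:
        #   encoder_dim = 24 * 1 + 24 * 4 + (24 + 12) * 4 + 2 = 266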

        self.encoders = nn.Sequential(
            _ResidualBlock(
                input_dim=encoder_dim,
                output_dim=hidden_size,
                hidden_size=hidden_size,
                use_layer_norm=use_layer_norm,
                dropout=dropout,
            ),
            *[
                _ResidualBlock(
                    input_dim=hidden_size,
                    output_dim=hidden_size,
                    hidden_size=hidden_size,
                    use_layer_norm=use_layer_norm,
                    dropout=dropout,
                )
                for _ in range(num_encoder_layers - 1)
            ],
        )

        self.decoders = nn.Sequential(
            *[
                _ResidualBlock(
                    input_dim=hidden_size,
                    output_dim=hidden_size,
                    hidden_size=hidden_size,
                    use_layer_norm=use_layer_norm,
                    dropout=dropout,
                )
                for _ in range(num_decoder_layers - 1)
            ],
            # add decoder output layer
            _ResidualBlock(
                input_dim=hidden_size,
                output_dim=decoder_output_dim
                * self.output_chunk_length
                * self.nr_params,
                hidden_size=hidden_size,
                use_layer_norm=use_layer_norm,
                dropout=dropout,
            ),
        )

        decoder_input_dim = decoder_output_dim * self.nr_params
        if temporal_width_future and future_cov_dim:
            decoder_input_dim += temporal_width_future
        elif future_cov_dim:
            decoder_input_dim += future_cov_dim

        self.temporal_decoder = _ResidualBlock(
            input_dim=decoder_input_dim,
            output_dim=output_dim * self.nr_params,
            hidden_size=temporal_decoder_hidden,
            use_layer_norm=use_layer_norm,
            dropout=dropout,
        )

        self.lookback_skip = nn.Linear(
            self.input_chunk_length, self.output_chunk_length * self.nr_params
        )
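
        # `lookback_skip` acts on the time axis: in `forward()` the lookback target is
        # transposed to (batch_size, output_dim, input_chunk_length), so this layer maps
        # input_chunk_length -> output_chunk_length * nr_params per target component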

    @io_processor
    def forward(
        self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
    ) -> torch.Tensor:
        """TiDE model forward pass.
        Parameters
        ----------
        x_in
            comes as tuple `(x_past, x_future, x_static)` where `x_past` is the input/past chunk and `x_future`
            is the output/future chunk. Input dimensions are `(batch_size, time_steps, components)`
        Returns
        -------
        torch.Tensor
            The output Tensor of shape `(batch_size, output_chunk_length, output_dim, nr_params)`
        """

        # x has shape (batch_size, input_chunk_length, input_dim), with components ordered
        # as [target, past covariates, historic future covariates]
        # x_future_covariates has shape (batch_size, output_chunk_length, future_cov_dim)
        # x_static_covariates has shape (batch_size, static_cov_dim)
        x, x_future_covariates, x_static_covariates = x_in

        x_lookback = x[:, :, : self.output_dim]

        # future covariates: feature projection or raw features
        # the historic future covariates are extracted from `x` and concatenated with the
        # future covariates along the time dimension
        if self.future_cov_dim:
            x_dynamic_future_covariates = torch.cat(
                [
                    x[:, :, -self.future_cov_dim :],
                    x_future_covariates,
                ],
                dim=1,
            )
            if self.temporal_width_future:
                # project input features across all input and output time steps
                x_dynamic_future_covariates = self.future_cov_projection(
                    x_dynamic_future_covariates
                )
        else:
            x_dynamic_future_covariates = None

        # past covariates: feature projection or raw features
        # the past covariates are embedded in `x`
        if self.past_cov_dim:
            x_dynamic_past_covariates = x[
                :,
                :,
                self.output_dim : self.output_dim + self.past_cov_dim,
            ]
            if self.temporal_width_past:
                # project input features across all input time steps
                x_dynamic_past_covariates = self.past_cov_projection(
                    x_dynamic_past_covariates
                )
        else:
            x_dynamic_past_covariates = None

        # setup input to encoder
        encoded = [
            x_lookback,
            x_dynamic_past_covariates,
            x_dynamic_future_covariates,
            x_static_covariates,
        ]
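        # flatten each tensor (keeping the batch dimension) and concatenate into a single
        # vector of length `encoder_dim` per batch element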
        encoded = [t.flatten(start_dim=1) for t in encoded if t is not None]
        encoded = torch.cat(encoded, dim=1)

        # encode, decode, reshape
        encoded = self.encoders(encoded)
        decoded = self.decoders(encoded)

        # reshape to (batch_size, output_chunk_length, decoder_output_dim * nr_params)
        decoded = decoded.view(x.shape[0], self.output_chunk_length, -1)

        # concatenate the decoded representation with the future covariates of the output
        # chunk, then temporally decode
        temporal_decoder_input = [
            decoded,
            (
                x_dynamic_future_covariates[:, -self.output_chunk_length :, :]
                if self.future_cov_dim > 0
                else None
            ),
        ]
        temporal_decoder_input = [t for t in temporal_decoder_input if t is not None]

        temporal_decoder_input = torch.cat(temporal_decoder_input, dim=2)
        temporal_decoded = self.temporal_decoder(temporal_decoder_input)
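        # temporal_decoded has shape (batch_size, output_chunk_length, output_dim * nr_params)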

        # pass x_lookback through self.lookback_skip with the last two dimensions swapped:
        # the skip connection maps the time dimension (input_chunk_length ->
        # output_chunk_length * nr_params), which nn.Linear expects as the last axis
        skip = self.lookback_skip(x_lookback.transpose(1, 2)).transpose(1, 2)

        # add skip connection
        y = temporal_decoded + skip.reshape_as(temporal_decoded)

        y = y.view(-1, self.output_chunk_length, self.output_dim, self.nr_params)
        return y


class TiDEModel(MixedCovariatesTorchModel):
    def __init__(
        self,
        input_chunk_length: int,
        output_chunk_length: int,
        output_chunk_shift: int = 0,
        num_encoder_layers: int = 1,
        num_decoder_layers: int = 1,
        decoder_output_dim: int = 16,
        hidden_size: int = 128,
        temporal_width_past: int = 4,
        temporal_width_future: int = 4,
        temporal_hidden_size_past: Optional[int] = None,
        temporal_hidden_size_future: Optional[int] = None,
        temporal_decoder_hidden: int = 32,
        use_layer_norm: bool = False,
        dropout: float = 0.1,
        use_static_covariates: bool = True,
        **kwargs,
    ):
        """An implementation of the TiDE model, as presented in [1]_.

        TiDE is similar to Transformers (implemented in :class:`TransformerModel`), but attempts to provide better
        performance at lower computational cost by introducing multilayer perceptron (MLP)-based encoder-decoders
        without attention.

        This model supports past covariates (known for `input_chunk_length` points before prediction time), future
        covariates (known for `output_chunk_length` points after prediction time), static covariates, as well as
        probabilistic forecasting.

        The encoder and decoder are implemented as a series of residual blocks. The number of residual blocks in
        the encoder and decoder can be controlled via ``num_encoder_layers`` and ``num_decoder_layers``
        respectively. The width of the layers in the residual blocks can be controlled via ``hidden_size``.
        Similarly, the width of the layers in the temporal decoder can be controlled via
        ``temporal_decoder_hidden``.

        Parameters
        ----------
        input_chunk_length
            Number of time steps in the past to take as a model input (per chunk). Applies to the target
            series, and past and/or future covariates (if the model supports it).
        output_chunk_length
            Number of time steps predicted at once (per chunk) by the internal model. Also, the number of future
            values from future covariates to use as a model input (if the model supports future covariates). It is
            not the same as forecast horizon `n` used in `predict()`, which is the desired number of prediction
            points generated using either a one-shot- or autoregressive forecast. Setting `n <= output_chunk_length`
            prevents auto-regression. This is useful when the covariates don't extend far enough into the future,
            or to prohibit the model from using future values of past and / or future covariates for prediction
            (depending on the model's covariate support).
        output_chunk_shift
            Optionally, the number of steps to shift the start of the output chunk into the future (relative to the
            input chunk end). This will create a gap between the input and output. If the model supports
            `future_covariates`, the future values are extracted from the shifted output chunk. Predictions will
            start `output_chunk_shift` steps after the end of the target `series`. If `output_chunk_shift` is set,
            the model cannot generate autoregressive predictions (`n > output_chunk_length`).
        num_encoder_layers
            The number of residual blocks in the encoder.
        num_decoder_layers
            The number of residual blocks in the decoder.
        decoder_output_dim
            The dimensionality of the output of the decoder.
        hidden_size
            The width of the layers in the residual blocks of the encoder and decoder.
        temporal_width_past
            The width of the output layer in the past covariate projection residual block. If `0`,
            will bypass feature projection and use the raw feature data.
        temporal_width_future
            The width of the output layer in the future covariate projection residual block. If `0`,
            will bypass feature projection and use the raw feature data.
        temporal_hidden_size_past
            The width of the hidden layer in the past covariate projection residual block. If not specified,
            defaults to `hidden_size`, which is the width of the hidden layer in the encoder and decoder. This is
            likely to be too large in many cases, so it is recommended to set this parameter explicitly.
        temporal_hidden_size_future
            The width of the hidden layer in the future covariate projection residual block. If not specified,
            defaults to `hidden_size`, which is the width of the hidden layer in the encoder and decoder. This is
            likely to be too large in many cases, so it is recommended to set this parameter explicitly.
        temporal_decoder_hidden
            The width of the layers in the temporal decoder.
        use_layer_norm
            Whether to use layer normalization in the residual blocks.
        dropout
            The dropout probability to be used in fully connected layers. This is compatible with Monte Carlo
            dropout at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at
            prediction time).
        use_static_covariates
            Whether the model should use static covariate information in case the input `series` passed to
            ``fit()`` contain static covariates. If ``True``, and static covariates are available at fitting time,
            will enforce that all target `series` have the same static covariate dimensionality in ``fit()``
            and ``predict()``.
        **kwargs
            Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
            Darts' :class:`TorchForecastingModel` base class.
        loss_fn
            PyTorch loss function used for training.
            This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
            Default: ``torch.nn.MSELoss()``.
        likelihood
            One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
            probabilistic forecasts. Default: ``None``.
        torch_metrics
            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be
            found at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
        optimizer_cls
            The PyTorch optimizer class to be used. Default: ``torch.optim.Adam``.
        optimizer_kwargs
            Optionally, some keyword arguments for the PyTorch optimizer (e.g., ``{'lr': 1e-3}``
            for specifying a learning rate). Otherwise, the default values of the selected ``optimizer_cls``
            will be used. Default: ``None``.
        lr_scheduler_cls
            Optionally, the PyTorch learning rate scheduler class to be used. Specifying ``None`` corresponds
            to using a constant learning rate. Default: ``None``.
        lr_scheduler_kwargs
            Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``.
        use_reversible_instance_norm
            Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_.
            It is only applied to the features of the target series and not the covariates.
        batch_size
            Number of time series (input and output sequences) used in each training pass. Default: ``32``.
        n_epochs
            Number of epochs over which to train the model. Default: ``100``.
        model_name
            Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
            of the name is formatted with the local date and time, while PID is the process ID (preventing models
            spawned at the same time by different processes from sharing the same model_name). E.g.,
            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
        work_dir
            Path of the working directory, where to save checkpoints and Tensorboard summaries.
            Default: current working directory.
        log_tensorboard
            If set, use Tensorboard to log the different parameters. The logs will be located in:
            ``"{work_dir}/darts_logs/{model_name}/logs/"``. Default: ``False``.
        nr_epochs_val_period
            Number of epochs to wait before evaluating the validation loss (if a validation
            ``TimeSeries`` is passed to the :func:`fit()` method). Default: ``1``.
        force_reset
            If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints
            will be discarded). Default: ``False``.
        save_checkpoints
            Whether to automatically save the untrained model and checkpoints from training. To load the model
            from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where :class:`MyModelClass` is the
            :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`,
            :class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using
            :func:`save()` and loaded using :func:`load()`. Default: ``False``.
        add_encoders
            A large number of past and future covariates can be automatically generated with `add_encoders`.
            This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
            will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added
            to transform the generated covariates. This happens all under one hood and only needs to be specified
            at model creation.
            Read :meth:`SequentialEncoder <darts.dataprocessing.encoders.SequentialEncoder>` to find out more about
            ``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features:

            .. highlight:: python
            .. code-block:: python

                def encode_year(idx):
                    return (idx.year - 1950) / 50

                add_encoders={
                    'cyclic': {'future': ['month']},
                    'datetime_attribute': {'future': ['hour', 'dayofweek']},
                    'position': {'past': ['relative'], 'future': ['relative']},
                    'custom': {'past': [encode_year]},
                    'transformer': Scaler(),
                    'tz': 'CET'
                }
            ..
        random_state
            Control the randomness of the weights initialization. Check this
            `link <https://scikit-learn.org/stable/glossary.html#term-random_state>`_ for more details.
            Default: ``None``.
        pl_trainer_kwargs
            By default :class:`TorchForecastingModel` creates a PyTorch Lightning Trainer with several useful
            presets that performs the training, validation and prediction processes. These presets include
            automatic checkpointing, tensorboard logging, setting the torch device and more.
            With ``pl_trainer_kwargs`` you can add additional kwargs to instantiate the PyTorch Lightning trainer
            object. Check the `PL Trainer documentation
            <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`_ for more information about
            the supported kwargs. Default: ``None``.
            Running on GPU(s) is also possible using ``pl_trainer_kwargs`` by specifying keys ``"accelerator",
            "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the
            ``pl_trainer_kwargs`` dict:

            - ``{"accelerator": "cpu"}`` for CPU,
            - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer),
            - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS.

            For more info, see here:
            https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags , and
            https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu_basic.html#train-on-multiple-gpus

            With parameter ``"callbacks"`` you can add custom or PyTorch-Lightning built-in callbacks to Darts'
            :class:`TorchForecastingModel`. Below is an example for adding EarlyStopping to the training process.
            The model will stop training early if the validation loss `val_loss` does not improve beyond
            specifications. For more information on callbacks, visit:
            `PyTorch Lightning Callbacks
            <https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html>`_

            .. highlight:: python
            .. code-block:: python

                from pytorch_lightning.callbacks.early_stopping import EarlyStopping

                # stop training when validation loss does not decrease more than 0.05 (`min_delta`) over
                # a period of 5 epochs (`patience`)
                my_stopper = EarlyStopping(
                    monitor="val_loss",
                    patience=5,
                    min_delta=0.05,
                    mode='min',
                )

                pl_trainer_kwargs={"callbacks": [my_stopper]}
            ..

            Note that you can also use a custom PyTorch Lightning Trainer for training and prediction with optional
            parameter ``trainer`` in :func:`fit()` and :func:`predict()`.
        show_warnings
            Whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of your
            forecasting use case. Default: ``False``.

        References
        ----------
        .. [1] A. Das et al. "Long-term Forecasting with TiDE: Time-series Dense Encoder",
               http://arxiv.org/abs/2304.08424
        .. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
               Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p

        Examples
        --------
        >>> from darts.datasets import WeatherDataset
        >>> from darts.models import TiDEModel
        >>> series = WeatherDataset().load()
        >>> # predicting atmospheric pressure
        >>> target = series['p (mbar)'][:100]
        >>> # optionally, use past observed rainfall (pretending to be unknown beyond index 100)
        >>> past_cov = series['rain (mm)'][:100]
        >>> # optionally, use future temperatures (pretending this component is a forecast)
        >>> future_cov = series['T (degC)'][:106]
        >>> model = TiDEModel(
        >>>     input_chunk_length=6,
        >>>     output_chunk_length=6,
        >>>     n_epochs=20
        >>> )
        >>> model.fit(target, past_covariates=past_cov, future_covariates=future_cov)
        >>> pred = model.predict(6)
        >>> pred.values()
        array([[1008.1667634 ],
               [ 997.08337201],
               [1017.72035839],
               [1005.10790392],
               [ 998.90537286],
               [1005.91534452]])

        .. note::
            `TiDE example notebook <https://unit8co.github.io/darts/examples/18-TiDE-examples.html>`_ presents
            techniques that can be used to improve the forecasts quality compared to this simple usage example.
        """
        if temporal_width_past < 0 or temporal_width_future < 0:
            raise_log(
                ValueError(
                    "`temporal_width_past` and `temporal_width_future` must be >= 0."
                ),
                logger=logger,
            )
        super().__init__(**self._extract_torch_model_params(**self.model_params))

        # extract pytorch lightning module kwargs
        self.pl_module_params = self._extract_pl_module_params(**self.model_params)

        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.decoder_output_dim = decoder_output_dim
        self.hidden_size = hidden_size
        self.temporal_width_past = temporal_width_past
        self.temporal_width_future = temporal_width_future
        self.temporal_hidden_size_past = temporal_hidden_size_past or hidden_size
        self.temporal_hidden_size_future = temporal_hidden_size_future or hidden_size
        self.temporal_decoder_hidden = temporal_decoder_hidden
        self._considers_static_covariates = use_static_covariates
        self.use_layer_norm = use_layer_norm
        self.dropout = dropout

    def _create_model(
        self, train_sample: MixedCovariatesTrainTensorType
    ) -> torch.nn.Module:
        (
            past_target,
            past_covariates,
            historic_future_covariates,
            future_covariates,
            static_covariates,
            future_target,
        ) = train_sample

        # target, past covariates, historic future covariates
        input_dim = (
            past_target.shape[1]
            + (past_covariates.shape[1] if past_covariates is not None else 0)
            + (
                historic_future_covariates.shape[1]
                if historic_future_covariates is not None
                else 0
            )
        )

        output_dim = future_target.shape[1]

        future_cov_dim = (
            future_covariates.shape[1] if future_covariates is not None else 0
        )
        static_cov_dim = (
            static_covariates.shape[0] * static_covariates.shape[1]
            if static_covariates is not None
            else 0
        )

        nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters

        past_cov_dim = input_dim - output_dim - future_cov_dim
        if past_cov_dim and self.temporal_width_past >= past_cov_dim:
            logger.warning(
                f"number of `past_covariates` features is <= `temporal_width_past`, leading to feature expansion. "
                f"number of covariates: {past_cov_dim}, `temporal_width_past={self.temporal_width_past}`."
            )
        if future_cov_dim and self.temporal_width_future >= future_cov_dim:
            logger.warning(
                f"number of `future_covariates` features is <= `temporal_width_future`, leading to feature "
                f"expansion. number of covariates: {future_cov_dim}, "
                f"`temporal_width_future={self.temporal_width_future}`."
            )

        return _TideModule(
            input_dim=input_dim,
            output_dim=output_dim,
            future_cov_dim=future_cov_dim,
            static_cov_dim=static_cov_dim,
            nr_params=nr_params,
            num_encoder_layers=self.num_encoder_layers,
            num_decoder_layers=self.num_decoder_layers,
            decoder_output_dim=self.decoder_output_dim,
            hidden_size=self.hidden_size,
            temporal_width_past=self.temporal_width_past,
            temporal_width_future=self.temporal_width_future,
            temporal_hidden_size_past=self.temporal_hidden_size_past,
            temporal_hidden_size_future=self.temporal_hidden_size_future,
            temporal_decoder_hidden=self.temporal_decoder_hidden,
            use_layer_norm=self.use_layer_norm,
            dropout=self.dropout,
            **self.pl_module_params,
        )

    @property
    def supports_static_covariates(self) -> bool:
        return True

    @property
    def supports_multivariate(self) -> bool:
        return True

    def _check_ckpt_parameters(self, tfm_save):
        # new parameters were added that will break loading weights
        new_params = ["temporal_hidden_size_past", "temporal_hidden_size_future"]
        for param in new_params:
            if param not in tfm_save.model_params:
                tfm_save.model_params[param] = None
        super()._check_ckpt_parameters(tfm_save)
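

# Usage sketch for Monte Carlo dropout (assumes a fitted `model`, e.g. as in the
# class docstring example above):
#
#     # keep the MonteCarloDropout layers active at inference time, so that
#     # repeated stochastic samples approximate model uncertainty
#     pred = model.predict(n=6, num_samples=100, mc_dropout=True)
#
# `num_samples` and `mc_dropout` are arguments of `TorchForecastingModel.predict()`;
# `mc_dropout=True` is typically combined with `num_samples > 1` to draw multiple
# stochastic forecasts.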