Source code for darts.models.forecasting.dlinear

"""
D-Linear
--------
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn

from darts.logging import raise_if
from darts.models.forecasting.pl_forecasting_module import (
    PLMixedCovariatesModule,
    io_processor,
)
from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel

MixedCovariatesTrainTensorType = Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]


class _MovingAvg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """

    def __init__(self, kernel_size, stride):
        super().__init__()
        # asymmetrical padding, shorther on the ts start side
        if kernel_size % 2 == 0:
            self.padding_size_left = kernel_size // 2 - 1
            self.padding_size_right = kernel_size // 2
        else:
            self.padding_size_left = (kernel_size - 1) // 2
            self.padding_size_right = (kernel_size - 1) // 2
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on the both ends of time series with the extremities values
        front = x[:, 0:1, :].repeat(1, self.padding_size_left, 1)
        end = x[:, -1:, :].repeat(1, self.padding_size_right, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class _SeriesDecomp(nn.Module):
    """
    Series decomposition block
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = _MovingAvg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


class _DLinearModule(PLMixedCovariatesModule):
    """
    DLinear module
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        future_cov_dim: int,
        static_cov_dim: int,
        nr_params: int,
        shared_weights: bool,
        kernel_size: int,
        const_init: bool,
        **kwargs,
    ):
        """PyTorch module implementing the DLinear architecture.

        Parameters
        ----------
        input_dim
            The number of input components (target + optional covariate)
        output_dim
            Number of output components in the target
        future_cov_dim
            Number of components in the future covariates
        static_cov_dim
            Dimensionality of the static covariates (either component-specific or shared)
        nr_params
            The number of parameters of the likelihood (or 1 if no likelihood is used).
        shared_weights
            Whether to use shared weights for the components of the series.
            ** Ignores covariates when True. **
        kernel_size
            The size of the kernel for the moving average
        const_init
            Whether to initialize the weights to 1/in_len
        **kwargs
            all parameters required for :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
            base class.

        Inputs
        ------
        x of shape `(batch_size, input_chunk_length)`
            Tensor containing the input sequence.

        Outputs
        -------
        y of shape `(batch_size, output_chunk_length, target_size/output_dim, nr_params)`
            Tensor containing the output of the NBEATS module.
        """

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.future_cov_dim = future_cov_dim
        self.static_cov_dim = static_cov_dim
        self.nr_params = nr_params
        self.const_init = const_init

        # Decomposition Kernel Size
        self.decomposition = _SeriesDecomp(kernel_size)
        self.shared_weights = shared_weights

        def _create_linear_layer(in_dim, out_dim):
            layer = nn.Linear(in_dim, out_dim)
            if self.const_init:
                layer.weight = nn.Parameter(
                    (1.0 / in_dim) * torch.ones(layer.weight.shape)
                )
            return layer

        if self.shared_weights:
            layer_in_dim = self.input_chunk_length
            layer_out_dim = self.output_chunk_length * self.nr_params
        else:
            layer_in_dim = self.input_chunk_length * self.input_dim
            layer_out_dim = self.output_chunk_length * self.output_dim * self.nr_params

        self.linear_seasonal = _create_linear_layer(layer_in_dim, layer_out_dim)
        self.linear_trend = _create_linear_layer(layer_in_dim, layer_out_dim)

        if self.future_cov_dim != 0:
            # future covariates layer acts on time steps independently
            self.linear_fut_cov = _create_linear_layer(
                self.future_cov_dim, self.output_dim * self.nr_params
            )
        if self.static_cov_dim != 0:
            self.linear_static_cov = _create_linear_layer(
                self.static_cov_dim, layer_out_dim
            )

    @io_processor
    def forward(
        self, x_in: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
    ):
        """
        x_in
            comes as tuple `(x_past, x_future, x_static)` where `x_past` is the input/past chunk and `x_future`
            is the output/future chunk. Input dimensions are `(n_samples, n_time_steps, n_variables)`
        """

        x, x_future, x_static = x_in  # x: (batch, in_len, in_dim)
        batch, _, _ = x.shape

        if self.shared_weights:
            # discard covariates, to ensure that in_dim == out_dim
            x = x[:, :, : self.output_dim]

            # extract trend
            res, trend = self.decomposition(x)

            # permute to (batch, in_dim, in_len) and apply linear layer on last dimension
            seasonal_output = self.linear_seasonal(res.permute(0, 2, 1))
            trend_output = self.linear_trend(trend.permute(0, 2, 1))

            x = seasonal_output + trend_output

            # extract nr_params
            x = x.view(batch, self.output_dim, self.output_chunk_length, self.nr_params)

            # permute back to (batch, out_len, out_dim, nr_params)
            x = x.permute(0, 2, 1, 3)

        else:
            res, trend = self.decomposition(x)

            # (in_len * in_dim) => (out_len * out_dim * out_nr_params)
            seasonal_output = self.linear_seasonal(res.view(batch, -1))
            trend_output = self.linear_trend(trend.view(batch, -1))
            seasonal_output = seasonal_output.view(
                batch, self.output_chunk_length, self.output_dim * self.nr_params
            )
            trend_output = trend_output.view(
                batch, self.output_chunk_length, self.output_dim * self.nr_params
            )

            x = seasonal_output + trend_output

            if self.future_cov_dim != 0:
                # x_future might be shorter than output_chunk_length when n < output_chunk_length
                # so we need to pad it with zeros at the end to match the output_chunk_length
                x_future = torch.nn.functional.pad(
                    input=x_future,
                    pad=(0, 0, 0, self.output_chunk_length - x_future.shape[1]),
                    mode="constant",
                    value=0,
                )

                fut_cov_output = self.linear_fut_cov(x_future)
                x = x + fut_cov_output.view(
                    batch, self.output_chunk_length, self.output_dim * self.nr_params
                )

            if self.static_cov_dim != 0:
                static_cov_output = self.linear_static_cov(x_static.reshape(batch, -1))
                x = x + static_cov_output.view(
                    batch, self.output_chunk_length, self.output_dim * self.nr_params
                )

            # extract nr_params
            x = x.view(batch, self.output_chunk_length, self.output_dim, self.nr_params)

        return x


[docs]class DLinearModel(MixedCovariatesTorchModel): def __init__( self, input_chunk_length: int, output_chunk_length: int, output_chunk_shift: int = 0, shared_weights: bool = False, kernel_size: int = 25, const_init: bool = True, use_static_covariates: bool = True, **kwargs, ): """An implementation of the DLinear model, as presented in [1]_. This implementation is improved by allowing the optional use of past covariates (known for `input_chunk_length` points before prediction time), future covariates (known for `output_chunk_length` points after prediction time) and static covariates, as well as supporting probabilistic forecasting. Parameters ---------- input_chunk_length Number of time steps in the past to take as a model input (per chunk). Applies to the target series, and past and/or future covariates (if the model supports it). output_chunk_length Number of time steps predicted at once (per chunk) by the internal model. Also, the number of future values from future covariates to use as a model input (if the model supports future covariates). It is not the same as forecast horizon `n` used in `predict()`, which is the desired number of prediction points generated using either a one-shot- or autoregressive forecast. Setting `n <= output_chunk_length` prevents auto-regression. This is useful when the covariates don't extend far enough into the future, or to prohibit the model from using future values of past and / or future covariates for prediction (depending on the model's covariate support). output_chunk_shift Optionally, the number of steps to shift the start of the output chunk into the future (relative to the input chunk end). This will create a gap between the input and output. If the model supports `future_covariates`, the future values are extracted from the shifted output chunk. Predictions will start `output_chunk_shift` steps after the end of the target `series`. If `output_chunk_shift` is set, the model cannot generate autoregressive predictions (`n > output_chunk_length`). shared_weights Whether to use shared weights for all components of multivariate series. .. warning:: When set to True, covariates will be ignored as a 1-to-1 mapping is required between input dimensions and output dimensions. .. Default: False. kernel_size The size of the kernel for the moving average (default=25). If the size of the kernel is even, the padding will be asymmetrical (shorter on the start/left side). const_init Whether to initialize the weights to 1/in_len. If False, the default PyTorch initialization is used (default='True'). use_static_covariates Whether the model should use static covariate information in case the input `series` passed to ``fit()`` contain static covariates. If ``True``, and static covariates are available at fitting time, will enforce that all target `series` have the same static covariate dimensionality in ``fit()`` and ``predict()``. **kwargs Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and Darts' :class:`TorchForecastingModel`. loss_fn PyTorch loss function used for training. This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified. Default: ``torch.nn.MSELoss()``. likelihood One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for probabilistic forecasts. Default: ``None``. torch_metrics A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``. optimizer_cls The PyTorch optimizer class to be used. Default: ``torch.optim.Adam``. optimizer_kwargs Optionally, some keyword arguments for the PyTorch optimizer (e.g., ``{'lr': 1e-3}`` for specifying a learning rate). Otherwise, the default values of the selected ``optimizer_cls`` will be used. Default: ``None``. lr_scheduler_cls Optionally, the PyTorch learning rate scheduler class to be used. Specifying ``None`` corresponds to using a constant learning rate. Default: ``None``. lr_scheduler_kwargs Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: ``None``. use_reversible_instance_norm Whether to use reversible instance normalization `RINorm` against distribution shift as shown in [2]_. It is only applied to the features of the target series and not the covariates. batch_size Number of time series (input and output sequences) used in each training pass. Default: ``32``. n_epochs Number of epochs over which to train the model. Default: ``100``. model_name Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part of the name is formatted with the local date and time, while PID is the processed ID (preventing models spawned at the same time by different processes to share the same model_name). E.g., ``"2021-06-14_09_53_32_torch_model_run_44607"``. work_dir Path of the working directory, where to save checkpoints and Tensorboard summaries. Default: current working directory. log_tensorboard If set, use Tensorboard to log the different parameters. The logs will be located in: ``"{work_dir}/darts_logs/{model_name}/logs/"``. Default: ``False``. nr_epochs_val_period Number of epochs to wait before evaluating the validation loss (if a validation ``TimeSeries`` is passed to the :func:`fit()` method). Default: ``1``. force_reset If set to ``True``, any previously-existing model with the same name will be reset (all checkpoints will be discarded). Default: ``False``. save_checkpoints Whether to automatically save the untrained model and checkpoints from training. To load the model from checkpoint, call :func:`MyModelClass.load_from_checkpoint()`, where :class:`MyModelClass` is the :class:`TorchForecastingModel` class that was used (such as :class:`TFTModel`, :class:`NBEATSModel`, etc.). If set to ``False``, the model can still be manually saved using :func:`save()` and loaded using :func:`load()`. Default: ``False``. add_encoders A large number of past and future covariates can be automatically generated with `add_encoders`. This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to transform the generated covariates. This happens all under one hood and only needs to be specified at model creation. Read :meth:`SequentialEncoder <darts.dataprocessing.encoders.SequentialEncoder>` to find out more about ``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features: .. highlight:: python .. code-block:: python def encode_year(idx): return (idx.year - 1950) / 50 add_encoders={ 'cyclic': {'future': ['month']}, 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, 'transformer': Scaler(), 'tz': 'CET' } .. random_state Control the randomness of the weights initialization. Check this `link <https://scikit-learn.org/stable/glossary.html#term-random_state>`_ for more details. Default: ``None``. pl_trainer_kwargs By default :class:`TorchForecastingModel` creates a PyTorch Lightning Trainer with several useful presets that performs the training, validation and prediction processes. These presets include automatic checkpointing, tensorboard logging, setting the torch device and more. With ``pl_trainer_kwargs`` you can add additional kwargs to instantiate the PyTorch Lightning trainer object. Check the `PL Trainer documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`_ for more information about the supported kwargs. Default: ``None``. Running on GPU(s) is also possible using ``pl_trainer_kwargs`` by specifying keys ``"accelerator", "devices", and "auto_select_gpus"``. Some examples for setting the devices inside the ``pl_trainer_kwargs`` dict: - ``{"accelerator": "cpu"}`` for CPU, - ``{"accelerator": "gpu", "devices": [i]}`` to use only GPU ``i`` (``i`` must be an integer), - ``{"accelerator": "gpu", "devices": -1, "auto_select_gpus": True}`` to use all available GPUS. For more info, see here: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags , and https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu_basic.html#train-on-multiple-gpus With parameter ``"callbacks"`` you can add custom or PyTorch-Lightning built-in callbacks to Darts' :class:`TorchForecastingModel`. Below is an example for adding EarlyStopping to the training process. The model will stop training early if the validation loss `val_loss` does not improve beyond specifications. For more information on callbacks, visit: `PyTorch Lightning Callbacks <https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html>`_ .. highlight:: python .. code-block:: python from pytorch_lightning.callbacks.early_stopping import EarlyStopping # stop training when validation loss does not decrease more than 0.05 (`min_delta`) over # a period of 5 epochs (`patience`) my_stopper = EarlyStopping( monitor="val_loss", patience=5, min_delta=0.05, mode='min', ) pl_trainer_kwargs={"callbacks": [my_stopper]} .. Note that you can also use a custom PyTorch Lightning Trainer for training and prediction with optional parameter ``trainer`` in :func:`fit()` and :func:`predict()`. show_warnings whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of your forecasting use case. Default: ``False``. References ---------- .. [1] Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2022). Are Transformers Effective for Time Series Forecasting?. arXiv preprint arXiv:2205.13504. .. [2] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p Examples -------- >>> from darts.datasets import WeatherDataset >>> from darts.models import DLinearModel >>> series = WeatherDataset().load() >>> # predicting atmospheric pressure >>> target = series['p (mbar)'][:100] >>> # optionally, use past observed rainfall (pretending to be unknown beyond index 100) >>> past_cov = series['rain (mm)'][:100] >>> # optionally, use future temperatures (pretending this component is a forecast) >>> future_cov = series['T (degC)'][:106] >>> # predict 6 pressure values using the 12 past values of pressure and rainfall, as well as the 6 temperature >>> # values corresponding to the forecasted period >>> model = DLinearModel( >>> input_chunk_length=6, >>> output_chunk_length=6, >>> n_epochs=20, >>> ) >>> model.fit(target, past_covariates=past_cov, future_covariates=future_cov) >>> pred = model.predict(6) >>> pred.values() array([[667.20957388], [666.76986848], [666.67733306], [666.06625381], [665.8529289 ], [665.75320573]]) .. note:: This simple usage example produces poor forecasts. In order to obtain better performance, user should transform the input data, increase the number of epochs, use a validation set, optimize the hyper- parameters, ... """ super().__init__(**self._extract_torch_model_params(**self.model_params)) # extract pytorch lightning module kwargs self.pl_module_params = self._extract_pl_module_params(**self.model_params) self.shared_weights = shared_weights self.kernel_size = kernel_size self.const_init = const_init self._considers_static_covariates = use_static_covariates def _create_model( self, train_sample: MixedCovariatesTrainTensorType ) -> torch.nn.Module: # samples are made of # (past_target, past_covariates, historic_future_covariates, # future_covariates, static_covariates, future_target) raise_if( self.shared_weights and (train_sample[1] is not None or train_sample[2] is not None), "Covariates have been provided, but the model has been built with shared_weights=True. " "Please set shared_weights=False to use covariates.", ) input_dim = train_sample[0].shape[1] + sum( # add past covariates dim and historic future covariates dim, if present train_sample[i].shape[1] if train_sample[i] is not None else 0 for i in (1, 2) ) future_cov_dim = train_sample[3].shape[1] if train_sample[3] is not None else 0 if train_sample[4] is None: static_cov_dim = 0 else: # account for component-specific or shared static covariates representation static_cov_dim = train_sample[4].shape[0] * train_sample[4].shape[1] output_dim = train_sample[-1].shape[1] nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters return _DLinearModule( input_dim=input_dim, output_dim=output_dim, future_cov_dim=future_cov_dim, static_cov_dim=static_cov_dim, nr_params=nr_params, shared_weights=self.shared_weights, kernel_size=self.kernel_size, const_init=self.const_init, **self.pl_module_params, ) @property def supports_multivariate(self) -> bool: return True @property def supports_static_covariates(self) -> bool: return True @property def supports_future_covariates(self) -> bool: return not self.shared_weights @property def supports_past_covariates(self) -> bool: return not self.shared_weights