Time-Series Mixer (TSMixer)

class darts.models.forecasting.tsmixer_model.TSMixerModel(input_chunk_length, output_chunk_length, output_chunk_shift=0, hidden_size=64, ff_size=64, num_blocks=2, activation='ReLU', dropout=0.1, norm_type='LayerNorm', normalize_before=False, use_static_covariates=True, **kwargs)

Bases: MixedCovariatesTorchModel

Time-Series Mixer (TSMixer): An All-MLP Architecture for Time Series.

This is an implementation of the TSMixer architecture, as outlined in [1]. A major part of the architecture was adopted from this PyTorch implementation. Additional changes were applied to increase model performance and efficiency.

TSMixer forecasts time series data by integrating historical time series data, future known inputs, and static contextual information. It uses a combination of conditional feature mixing and mixer layers to process and combine these different types of data for effective forecasting.

This model supports past covariates (known for input_chunk_length points before prediction time), future covariates (known for output_chunk_length points after prediction time), static covariates, as well as probabilistic forecasting.

  • input_chunk_length (int) – Number of time steps in the past to take as a model input (per chunk). Applies to the target series, and past and/or future covariates (if the model supports it). Also called: Encoder length

  • output_chunk_length (int) – Number of time steps predicted at once (per chunk) by the internal model. Also, the number of future values from future covariates to use as a model input (if the model supports future covariates). It is not the same as the forecast horizon n used in predict(), which is the desired number of prediction points generated using either a one-shot or an autoregressive forecast. Setting n <= output_chunk_length prevents auto-regression. This is useful when the covariates don’t extend far enough into the future, or to prohibit the model from using future values of past and / or future covariates for prediction (depending on the model’s covariate support). Also called: Decoder length

  • output_chunk_shift (int) – Optionally, the number of steps to shift the start of the output chunk into the future (relative to the input chunk end). This will create a gap between the input and output. If the model supports future_covariates, the future values are extracted from the shifted output chunk. Predictions will start output_chunk_shift steps after the end of the target series. If output_chunk_shift is set, the model cannot generate autoregressive predictions (n > output_chunk_length).

  • hidden_size (int) – The hidden state size / size of the second feed-forward layer in the feature mixing MLP.

  • ff_size (int) – The size of the first feed-forward layer in the feature mixing MLP.

  • num_blocks (int) – The number of mixer blocks in the model. The number includes the first block and all subsequent blocks.

  • activation (str) – The name of the activation function to use in the mixer layers. Default: “ReLU”. Must be one of “ReLU”, “RReLU”, “PReLU”, “ELU”, “Softplus”, “Tanh”, “SELU”, “LeakyReLU”, “Sigmoid”, “GELU”.

  • dropout (float) – Fraction of neurons affected by dropout. This is compatible with Monte Carlo dropout at inference time for model uncertainty estimation (enabled with mc_dropout=True at prediction time).

  • norm_type (Union[str, Module]) – The type of LayerNorm variant to use. Default: “LayerNorm”. If a string, must be one of “LayerNormNoBias”, “LayerNorm”, “TimeBatchNorm2d”. Otherwise, must be a custom nn.Module.

  • normalize_before (bool) – Whether to apply layer normalization before or after the mixer layer.

  • use_static_covariates (bool) – Whether the model should use static covariate information in case the input series passed to fit() contain static covariates. If True, and static covariates are available at fitting time, will enforce that all target series have the same static covariate dimensionality in fit() and predict().

  • **kwargs – Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and Darts’ TorchForecastingModel.

  • loss_fn – PyTorch loss function used for training. This parameter will be ignored for probabilistic models if the likelihood parameter is specified. Default: torch.nn.MSELoss().

  • likelihood – One of Darts’ Likelihood models to be used for probabilistic forecasts. Default: None.

  • torch_metrics – A torch metric or a MetricCollection used for evaluation. A full list of available metrics can be found at https://torchmetrics.readthedocs.io/en/latest/. Default: None.

  • optimizer_cls – The PyTorch optimizer class to be used. Default: torch.optim.Adam.

  • optimizer_kwargs – Optionally, some keyword arguments for the PyTorch optimizer (e.g., {'lr': 1e-3} for specifying a learning rate). Otherwise, the default values of the selected optimizer_cls will be used. Default: None.

  • lr_scheduler_cls – Optionally, the PyTorch learning rate scheduler class to be used. Specifying None corresponds to using a constant learning rate. Default: None.

  • lr_scheduler_kwargs – Optionally, some keyword arguments for the PyTorch learning rate scheduler. Default: None.

  • use_reversible_instance_norm – Whether to use reversible instance normalization (RINorm) against distribution shift, as shown in [3]. It is only applied to the features of the target series and not the covariates.

  • batch_size – Number of time series (input and output sequences) used in each training pass. Default: 32.

  • n_epochs – Number of epochs over which to train the model. Default: 100.

  • model_name – Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified, defaults to the following string "YYYY-mm-dd_HH_MM_SS_torch_model_run_PID", where the initial part of the name is formatted with the local date and time, while PID is the process ID (preventing models spawned at the same time by different processes from sharing the same model_name). E.g., "2021-06-14_09_53_32_torch_model_run_44607".

  • work_dir – Path of the working directory, where to save checkpoints and tensorboard summaries. Default: current working directory.

  • log_tensorboard – If set, use tensorboard to log the different parameters. The logs will be located in: "{work_dir}/darts_logs/{model_name}/logs/". Default: False.

  • nr_epochs_val_period – Number of epochs to wait before evaluating the validation loss (if a validation TimeSeries is passed to the fit() method). Default: 1.

  • force_reset – If set to True, any previously-existing model with the same name will be reset (all checkpoints will be discarded). Default: False.

  • save_checkpoints – Whether to automatically save the untrained model and checkpoints from training. To load the model from checkpoint, call MyModelClass.load_from_checkpoint(), where MyModelClass is the TorchForecastingModel class that was used (such as TFTModel, NBEATSModel, etc.). If set to False, the model can still be manually saved using save() and loaded using load(). Default: False.

  • add_encoders

    A large number of past and future covariates can be automatically generated with add_encoders. This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that will be used as index encoders. Additionally, a transformer such as Darts’ Scaler can be added to transform the generated covariates. This happens all under one hood and only needs to be specified at model creation. Read SequentialEncoder to find out more about add_encoders. Default: None. An example showing some of add_encoders features:

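    from darts.dataprocessing.transformers import Scaler

    # custom encoder: maps the year of each index entry to a scaled value
    def encode_year(idx):
        return (idx.year - 1950) / 50

    add_encoders = {
        'cyclic': {'future': ['month']},
        'datetime_attribute': {'future': ['hour', 'dayofweek']},
        'position': {'past': ['relative'], 'future': ['relative']},
        'custom': {'past': [encode_year]},
        'transformer': Scaler(),
        'tz': 'CET'
    }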

  • random_state – Controls the randomness of the weights’ initialization. Check this link for more details. Default: None.

  • pl_trainer_kwargs

    By default TorchForecastingModel creates a PyTorch Lightning Trainer with several useful presets that performs the training, validation and prediction processes. These presets include automatic checkpointing, tensorboard logging, setting the torch device and more. With pl_trainer_kwargs you can add additional kwargs to instantiate the PyTorch Lightning trainer object. Check the PL Trainer documentation for more information about the supported kwargs. Default: None. Running on GPU(s) is also possible using pl_trainer_kwargs by specifying keys "accelerator", "devices", and "auto_select_gpus". Some examples for setting the devices inside the pl_trainer_kwargs dict:

    • {"accelerator": "cpu"} for CPU,

    • {"accelerator": "gpu", "devices": [i]} to use only GPU i (i must be an integer),

    • {"accelerator": "gpu", "devices": -1, "auto_select_gpus": True} to use all available GPUs.

    For more info, see here: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags , and https://pytorch-lightning.readthedocs.io/en/stable/accelerators/gpu_basic.html#train-on-multiple-gpus

    With parameter "callbacks" you can add custom or PyTorch-Lightning built-in callbacks to Darts’ TorchForecastingModel. Below is an example for adding EarlyStopping to the training process. The model will stop training early if the validation loss val_loss does not improve beyond specifications. For more information on callbacks, visit: PyTorch Lightning Callbacks

    from pytorch_lightning.callbacks.early_stopping import EarlyStopping
    # stop training when validation loss does not decrease more than 0.05 (`min_delta`) over
    # a period of 5 epochs (`patience`)
    my_stopper = EarlyStopping(
        monitor="val_loss",
        patience=5,
        min_delta=0.05,
        mode="min",
    )
    pl_trainer_kwargs = {"callbacks": [my_stopper]}

    Note that you can also use a custom PyTorch Lightning Trainer for training and prediction with optional parameter trainer in fit() and predict().

  • show_warnings – Whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of your forecasting use case. Default: False.
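The interaction between the forecast horizon n, output_chunk_length, and output_chunk_shift described in the parameter list can be illustrated with a small, library-independent sketch. The helper name prediction_passes is hypothetical and not part of Darts; it only counts how many forward passes a chunk-based model would need, assuming each pass yields output_chunk_length steps.

```python
import math


def prediction_passes(n: int, output_chunk_length: int, output_chunk_shift: int = 0) -> int:
    """Count the forward passes needed to produce `n` forecast steps.

    Hypothetical helper for illustration only; not part of the Darts API.
    """
    if n < 1 or output_chunk_length < 1:
        raise ValueError("n and output_chunk_length must be positive")
    if output_chunk_shift > 0 and n > output_chunk_length:
        # A shifted output chunk leaves a gap after the input chunk, so the
        # model cannot feed its own predictions back in (no auto-regression).
        raise ValueError("auto-regression is not possible with output_chunk_shift > 0")
    # One pass covers up to output_chunk_length steps; more passes mean auto-regression.
    return math.ceil(n / output_chunk_length)


print(prediction_passes(n=6, output_chunk_length=6))   # one-shot forecast
print(prediction_passes(n=13, output_chunk_length=6))  # auto-regressive forecast
```

A result of 1 corresponds to a one-shot forecast; anything larger implies auto-regression, which in turn requires covariates that extend far enough into the future.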





>>> from darts.datasets import WeatherDataset
>>> from darts.models import TSMixerModel
>>> series = WeatherDataset().load()
>>> # predicting temperatures
>>> target = series['T (degC)'][:100]
>>> # optionally, use past observed rainfall (pretending to be unknown beyond index 100)
>>> past_cov = series['rain (mm)'][:100]
>>> # optionally, use future atmospheric pressure (pretending this component is a forecast)
>>> future_cov = series['p (mbar)'][:106]
>>> model = TSMixerModel(
>>>     input_chunk_length=6,
>>>     output_chunk_length=6,
>>>     use_reversible_instance_norm=True,
>>>     n_epochs=20
>>> )
>>> model.fit(target, past_covariates=past_cov, future_covariates=future_cov)
>>> pred = model.predict(6)
>>> pred.values()
    [4.4122863 ],



  • considers_static_covariates – Whether the model considers static covariates, if there are any.

  • extreme_lags – An 8-tuple containing in order: (min target lag, max target lag, min past covariate lag, max past covariate lag, min future covariate lag, max future covariate lag, output shift, max target lag train (only for RNNModel)).

  • min_train_samples – The minimum number of samples for training the model.

  • output_chunk_length – Number of time steps predicted at once by the model, not defined for statistical models.

  • output_chunk_shift – Number of time steps that the output/prediction starts after the end of the input.

  • supports_future_covariates – Whether model supports future covariates

  • supports_likelihood_parameter_prediction – Whether model instance supports direct prediction of likelihood parameters

  • supports_multivariate – Whether the model considers more than one variate in the time series.

  • supports_optimized_historical_forecasts – Whether the model supports optimized historical forecasts

  • supports_past_covariates – Whether model supports past covariates

  • supports_probabilistic_prediction – Checks if the forecasting model with this configuration supports probabilistic predictions.

  • supports_sample_weight – Whether model supports sample weight for training.

  • supports_static_covariates – Whether model supports static covariates

  • supports_transferrable_series_prediction – Whether the model supports prediction for any input series.

  • uses_future_covariates – Whether the model uses future covariates, once fitted.

  • uses_past_covariates – Whether the model uses past covariates, once fitted.

  • uses_static_covariates – Whether the model uses static covariates, once fitted.







backtest(series, past_covariates=None, future_covariates=None, historical_forecasts=None, forecast_horizon=1, num_samples=1, train_length=None, start=None, start_format='value', stride=1, retrain=True, overlap_end=False, last_points_only=False, metric=<function mape>, reduction=<function mean>, verbose=False, show_warnings=True, predict_likelihood_parameters=False, enable_optimization=True, data_transformers=None, metric_kwargs=None, fit_kwargs=None, predict_kwargs=None, sample_weight=None)

Compute error values that the model produced for historical forecasts on (potentially multiple) series.

If historical_forecasts are provided, the metric(s) (given by the metric function) is evaluated directly on all forecasts and actual values. The same series and last_points_only value must be passed that were used to generate the historical forecasts. Finally, the method returns an optional reduction (the mean by default) of all these metric scores.

If historical_forecasts is None, it first generates the historical forecasts with the parameters given below (see ForecastingModel.historical_forecasts() for more info) and then evaluates as described above.

The metric(s) can be further customized with metric_kwargs (e.g. control the aggregation over components, time steps, multiple series, other required arguments such as q for quantile metrics, …).
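Conceptually, a backtest computes one metric score per historical forecast and then applies the reduction to those scores. The following is a minimal, library-independent sketch of that idea, not Darts' implementation: the mape function below is a simplified stand-in for the Darts metric, and the values are made up for illustration.

```python
from statistics import mean


def mape(actual, forecast):
    """Mean absolute percentage error (in percent) for two equal-length sequences."""
    return 100.0 * mean(abs((a - f) / a) for a, f in zip(actual, forecast))


# Made-up actual values and two made-up historical forecasts over different windows.
actuals = [[100.0, 200.0], [200.0, 400.0]]
forecasts = [[110.0, 180.0], [190.0, 420.0]]

# One score per historical forecast, then reduce them to a single backtest score.
scores = [mape(a, f) for a, f in zip(actuals, forecasts)]
backtest_score = mean(scores)
print(scores, backtest_score)
```

With reduction=None, the analogue of scores would be returned instead of the single reduced value.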

  • series (Union[TimeSeries, Sequence[TimeSeries]]) – A (sequence of) target time series used to successively train (if retrain is not False) and compute the historical forecasts.

  • past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a (sequence of) past-observed covariate time series for every input time series in series. This applies only if the model supports past covariates.

  • future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a (sequence of) future-known covariate time series for every input time series in series. This applies only if the model supports future covariates.

  • historical_forecasts (Union[TimeSeries, Sequence[TimeSeries], Sequence[Sequence[TimeSeries]], None]) – Optionally, the (or a sequence of / a sequence of sequences of) historical forecasts time series to be evaluated. Corresponds to the output of historical_forecasts(). The same series and last_points_only values must be passed that were used to generate the historical forecasts. If provided, will skip historical forecasting and ignore all parameters except series, last_points_only, metric, and reduction.

  • forecast_horizon (int) – The forecast horizon for the predictions.

  • num_samples (int) – Number of times a prediction is sampled from a probabilistic model. Use values >1 only for probabilistic models.

  • train_length (Optional[int, None]) – Optionally, use a fixed length / number of time steps for every constructed training set (rolling window mode). Only effective when retrain is not False. The default is None, where it uses all time steps up until the prediction time (expanding window mode). If larger than the number of available time steps, uses the expanding mode. Needs to be at least min_train_series_length.

  • start (Union[Timestamp, float, int, None]) –

    Optionally, the first point in time at which a prediction is computed. This parameter supports: float, int, pandas.Timestamp, and None. If a float, it is the proportion of the time series that should lie before the first prediction point. If an int, it is either the index position of the first prediction point for series with a pd.DatetimeIndex, or the index value for series with a pd.RangeIndex. The latter can be changed to the index position with start_format="position". If a pandas.Timestamp, it is the time stamp of the first prediction point. If None, the first prediction point will automatically be set to:

    • the first predictable point if retrain is False, or retrain is a Callable and the first predictable point is earlier than the first trainable point.

    • the first trainable point if retrain is True or int (given train_length), or retrain is a Callable and the first trainable point is earlier than the first predictable point.

    • the first trainable point (given train_length) otherwise

    Note: If start is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of stride ahead of start. Raises a ValueError, if no valid start point exists.

    Note: If the model uses a shifted output (output_chunk_shift > 0), then the first predicted point is also shifted by output_chunk_shift points into the future.

    Note: If start is outside the possible historical forecasting times, will ignore the parameter (default behavior with None) and start at the first trainable/predictable point.

  • start_format (Literal[‘position’, ‘value’]) – Defines the start format. If set to 'position', start corresponds to the index position of the first predicted point and can range from (-len(series), len(series) - 1). If set to 'value', start corresponds to the index value/label of the first predicted point. Will raise an error if the value is not in series’ index. Default: 'value'.

  • stride (int) – The number of time steps between two consecutive predictions.

  • retrain (Union[bool, int, Callable[…, bool]]) –

    Whether and/or on which condition to retrain the model before predicting. This parameter supports 3 different types: bool, (positive) int, and Callable (returning a bool). In the case of bool: retrain the model at each step (True), or never retrain the model (False). In the case of int: the model is retrained every retrain iterations. In the case of Callable: the model is retrained whenever callable returns True. The callable must have the following positional arguments:

    • counter (int): current retrain iteration

    • pred_time (pd.Timestamp or int): timestamp of forecast time (end of the training series)

    • train_series (TimeSeries): train series up to pred_time

    • past_covariates (TimeSeries): past_covariates series up to pred_time

    • future_covariates (TimeSeries): future_covariates series up to min(pred_time + series.freq * forecast_horizon, series.end_time())

    Note: if any optional *_covariates are not passed to historical_forecast, None will be passed to the corresponding retrain function argument. Note: some models require being retrained every time and do not support anything other than retrain=True. Note: also controls the retraining of the data_transformers.

  • overlap_end (bool) – Whether the returned forecasts can go beyond the series’ end or not.

  • last_points_only (bool) – Whether to return only the last point of each historical forecast. If set to True, the method returns a single TimeSeries (for each time series in series) containing the successive point forecasts. Otherwise, returns a list of historical TimeSeries forecasts.

  • metric (Union[Callable[…, Union[float, list[float], ndarray, list[ndarray]]], list[Callable[…, Union[float, list[float], ndarray, list[ndarray]]]]]) – A metric function or a list of metric functions. Each metric must either be a Darts metric (see here), or a custom metric that has an identical signature as Darts’ metrics, uses decorators multi_ts_support() and multivariate_support(), and returns the metric score.

  • reduction (Optional[Callable[…, float], None]) – A function used to combine the individual error scores obtained when last_points_only is set to False. When providing several metric functions, the function will receive the argument axis = 1 to obtain single value for each metric function. If explicitly set to None, the method will return a list of the individual error scores instead. Set to np.mean by default.

  • verbose (bool) – Whether to print the progress.

  • show_warnings (bool) – Whether to show warnings related to historical forecasts optimization, or parameters start and train_length.

  • predict_likelihood_parameters (bool) – If set to True, the model predicts the parameters of its likelihood instead of the target. Only supported for probabilistic models with a likelihood, num_samples = 1 and n<=output_chunk_length. Default: False.

  • enable_optimization (bool) – Whether to use the optimized version of historical_forecasts when supported and available. Default: True.

  • data_transformers (Optional[dict[str, Union[BaseDataTransformer, Pipeline]], None]) –

    Optionally, a dictionary of BaseDataTransformer or Pipeline to apply to the corresponding series (possible keys: “series”, “past_covariates”, “future_covariates”). If provided, all input series must be in the un-transformed space. For fittable transformer / pipeline:

    • if retrain=True, the data transformer is re-fit on the training data at each historical forecast step.

    • if retrain=False, the data transformer transforms the series once before all the forecasts.

    The fitted transformer is used to transform the input during both training and prediction. If the transformation is invertible, the forecasts will be inverse-transformed. Only effective when historical_forecasts=None.

  • metric_kwargs (Union[dict[str, Any], list[dict[str, Any]], None]) – Additional arguments passed to metric(), such as ‘n_jobs’ for parallelization, ‘component_reduction’ for reducing the component wise metrics, seasonality ‘m’ for scaled metrics, etc. Will pass arguments to each metric separately and only if they are present in the corresponding metric signature. Parameter ‘insample’ for scaled metrics (e.g. mase, rmsse, …) is ignored, as it is handled internally.

  • fit_kwargs (Optional[dict[str, Any], None]) – Optionally, some additional arguments passed to the model fit() method.

  • predict_kwargs (Optional[dict[str, Any], None]) – Optionally, some additional arguments passed to the model predict() method.

  • sample_weight (Union[TimeSeries, Sequence[TimeSeries], str, None]) – Optionally, some sample weights to apply to the target series labels for training. Only effective when retrain is not False. They are applied per observation, per label (each step in output_chunk_length), and per component. If a series or sequence of series, then those weights are used. If the weight series only have a single component / column, then the weights are applied globally to all components in series. Otherwise, for component-specific weights, the number of components must match those of series. If a string, then the weights are generated using built-in weighting functions. The available options are “linear” or “exponential” decay - the further in the past, the lower the weight. The weights are computed per time series.
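The Callable form of retrain described above takes five positional arguments and returns a bool. A minimal sketch of such a condition follows; the function name and the "every third iteration" rule are hypothetical, and it assumes that counter starts at 0.

```python
def retrain_every_third(counter, pred_time, train_series, past_covariates, future_covariates):
    """Return True when the model should be retrained before this forecast.

    Hypothetical condition for illustration: retrain on iterations 0, 3, 6, ...
    The remaining arguments are accepted to match the documented signature.
    """
    return counter % 3 == 0


# Simulate how the condition would be evaluated over six forecast iterations.
decisions = [retrain_every_third(i, None, None, None, None) for i in range(6)]
print(decisions)
```

Such a function would be passed as retrain=retrain_every_third; as noted above, any covariates that are not provided arrive as None.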

Return type

Union[float, ndarray, list[float], list[ndarray]]


  • float – A single backtest score for single uni/multivariate series, a single metric function and:

    • historical_forecasts generated with last_points_only=True

    • historical_forecasts generated with last_points_only=False and using a backtest reduction

  • np.ndarray – A numpy array of backtest scores. For single series and one of:

    • a single metric function, historical_forecasts generated with last_points_only=False and backtest reduction=None. The output has shape (n forecasts, *).

    • multiple metric functions and historical_forecasts generated with last_points_only=False. The output has shape (*, n metrics) when using a backtest reduction, and (n forecasts, *, n metrics) when reduction=None

    • multiple uni/multivariate series including series_reduction and at least one of component_reduction=None or time_reduction=None for “per time step metrics”

  • List[float] – Same as for type float but for a sequence of series. The returned metric list has length len(series) with the float metric for each input series.

  • List[np.ndarray] – Same as for type np.ndarray but for a sequence of series. The returned metric list has length len(series) with the np.ndarray metrics for each input series.

property considers_static_covariates: bool

Whether the model considers static covariates, if there are any.

Return type

bool

property epochs_trained: int

Return type

int

property extreme_lags: tuple[Optional[int], Optional[int], Optional[int], Optional[int], Optional[int], Optional[int], int, Optional[int]]

An 8-tuple containing in order: (min target lag, max target lag, min past covariate lag, max past covariate lag, min future covariate lag, max future covariate lag, output shift, max target lag train (only for RNNModel)). If 0 is the index of the first prediction, then all lags are relative to this index.

See examples below.

If the model wasn’t fitted with:
  • target (concerning RegressionModels only): then the first element should be None.

  • past covariates: then the third and fourth elements should be None.

  • future covariates: then the fifth and sixth elements should be None.

Should be overridden by models that use past or future covariates, and/or by models whose minimum and maximum target lags potentially differ from -1 and 0.


The maximum target lag (second value) cannot be None and is always larger than or equal to 0.


>>> model = LinearRegressionModel(lags=3, output_chunk_length=2)
>>> model.fit(train_series)
>>> model.extreme_lags
(-3, 1, None, None, None, None, 0, None)
>>> model = LinearRegressionModel(lags=3, output_chunk_length=2, output_chunk_shift=2)
>>> model.fit(train_series)
>>> model.extreme_lags
(-3, 1, None, None, None, None, 2, None)
>>> model = LinearRegressionModel(lags=[-3, -5], lags_past_covariates = 4, output_chunk_length=7)
>>> model.fit(train_series, past_covariates=past_covariates)
>>> model.extreme_lags
(-5, 6, -4, -1, None, None, 0, None)
>>> model = LinearRegressionModel(lags=[-3, -5], lags_future_covariates = [4, 6], output_chunk_length=7)
>>> model.fit(train_series, future_covariates=future_covariates)
>>> model.extreme_lags
(-5, 6, None, None, 4, 6, 0, None)
>>> model = NBEATSModel(input_chunk_length=10, output_chunk_length=7)
>>> model.fit(train_series)
>>> model.extreme_lags
(-10, 6, None, None, None, None, 0, None)
>>> model = NBEATSModel(input_chunk_length=10, output_chunk_length=7, lags_future_covariates=[4, 6])
>>> model.fit(train_series, future_covariates)
>>> model.extreme_lags
(-10, 6, None, None, 4, 6, 0, None)
Return type

tuple[Optional[int, None], Optional[int, None], Optional[int, None], Optional[int, None], Optional[int, None], Optional[int, None], int, Optional[int, None]]
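For a model fitted on a target series only, the documented examples follow a simple pattern: the minimum target lag is the negated input length, the maximum target lag is output_chunk_length - 1, and the output shift is reported separately. A minimal sketch that reproduces those example tuples is shown below; the helper name is hypothetical and this is not the Darts implementation.

```python
def target_only_extreme_lags(input_chunk_length, output_chunk_length, output_chunk_shift=0):
    """Build the 8-tuple for a model fitted on a target series without covariates.

    Hypothetical helper mirroring the documented examples; not the Darts implementation.
    """
    min_target_lag = -input_chunk_length      # earliest input step, relative to index 0
    max_target_lag = output_chunk_length - 1  # last step of the output chunk
    return (min_target_lag, max_target_lag, None, None, None, None, output_chunk_shift, None)


print(target_only_extreme_lags(10, 7))
print(target_only_extreme_lags(3, 2, output_chunk_shift=2))
```

For a regression model created with lags=3, the input length in this sketch corresponds to the number of lags.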

fit(series, past_covariates=None, future_covariates=None, val_series=None, val_past_covariates=None, val_future_covariates=None, trainer=None, verbose=None, epochs=0, max_samples_per_ts=None, dataloader_kwargs=None, sample_weight=None, val_sample_weight=None)

Fit/train the model on one or multiple series.

This method wraps around fit_from_dataset(), constructing a default training dataset for this model. If you need more control on how the series are sliced for training, consider calling fit_from_dataset() with a custom darts.utils.data.TrainingDataset.

Training is performed with a PyTorch Lightning Trainer. It uses a default Trainer object from presets and pl_trainer_kwargs used at model creation. You can also use a custom Trainer with optional parameter trainer. For more information on PyTorch Lightning Trainers check out this link.

This function can be called several times to do some extra training. If epochs is specified, the model will be trained for that many additional epochs.

Below, all possible parameters are documented, but not all models support all parameters. For instance, all the PastCovariatesTorchModel support only past_covariates and not future_covariates. Darts will complain if you try fitting a model with the wrong covariates argument.

When handling covariates, Darts will try to use the time axes of the target and the covariates to come up with the right time slices. So the covariates can be longer than needed; as long as the time axes are correct Darts will handle them correctly. It will also complain if their time span is not sufficient.
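fit() builds its default training dataset by slicing each series into (input, output) windows, and max_samples_per_ts caps how many of the most recent windows are kept. A minimal, library-independent sketch of that slicing is shown below. The function name slice_windows and the stride of 1 are assumptions made for illustration; Darts' own dataset classes may differ in detail.

```python
def slice_windows(values, input_chunk_length, output_chunk_length,
                  output_chunk_shift=0, max_samples_per_ts=None):
    """Cut one series into (input, output) training pairs with a stride of 1.

    Illustrative only: the stride-1 layout is an assumption, not Darts' dataset code.
    """
    span = input_chunk_length + output_chunk_shift + output_chunk_length
    samples = []
    for start in range(len(values) - span + 1):
        inp = values[start:start + input_chunk_length]
        out_start = start + input_chunk_length + output_chunk_shift
        samples.append((inp, values[out_start:out_start + output_chunk_length]))
    if max_samples_per_ts is not None:
        # Keep only the most recent samples, as described for max_samples_per_ts.
        samples = samples[-max_samples_per_ts:]
    return samples


series = list(range(10))
all_samples = slice_windows(series, input_chunk_length=4, output_chunk_length=2)
recent = slice_windows(series, 4, 2, max_samples_per_ts=2)
print(len(all_samples), recent)
```

A series shorter than input_chunk_length + output_chunk_shift + output_chunk_length yields no samples in this sketch, which is why a minimum series length is needed for training.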

  • series (Union[TimeSeries, Sequence[TimeSeries]]) – A series or sequence of series serving as target (i.e. what the model will be trained to forecast)

  • past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a series or sequence of series specifying past-observed covariates

  • future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a series or sequence of series specifying future-known covariates

  • val_series (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, one or a sequence of validation target series, which will be used to compute the validation loss throughout training and keep track of the best performing models.

  • val_past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the past covariates corresponding to the validation series (must match covariates)

  • val_future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the future covariates corresponding to the validation series (must match covariates)

  • val_sample_weight (Union[TimeSeries, Sequence[TimeSeries], str, None]) – Same as for sample_weight but for the evaluation dataset.

  • trainer (Optional[Trainer, None]) – Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom trainer will override Darts’ default trainer.

  • verbose (Optional[bool, None]) – Whether to print the progress. Ignored if there is a ProgressBar callback in pl_trainer_kwargs.

  • epochs (int) – If specified, will train the model for epochs (additional) epochs, irrespective of what n_epochs was provided to the model constructor.

  • max_samples_per_ts (Optional[int, None]) – Optionally, a maximum number of samples to use per time series. Models are trained in a supervised fashion by constructing slices of (input, output) examples. On long time series, this can result in unnecessarily large number of training samples. This parameter upper-bounds the number of training samples per time series (taking only the most recent samples in each series). Leaving to None does not apply any upper bound.

  • dataloader_kwargs (Optional[dict[str, Any], None]) –

    Optionally, a dictionary of keyword arguments used to create the PyTorch DataLoader instances for the training and validation datasets. For more information on DataLoader, check out this link. By default, Darts configures parameters (“batch_size”, “shuffle”, “drop_last”, “collate_fn”, “pin_memory”) for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.

  • sample_weight (Union[TimeSeries, Sequence[TimeSeries], str, None]) – Optionally, some sample weights to apply to the target series labels. They are applied per observation, per label (each step in output_chunk_length), and per component. If a series or sequence of series, then those weights are used. If the weight series only have a single component / column, then the weights are applied globally to all components in series. Otherwise, for component-specific weights, the number of components must match those of series. If a string, then the weights are generated using built-in weighting functions. The available options are “linear” or “exponential” decay - the further in the past, the lower the weight. The weights are computed globally based on the length of the longest series in series. Then for each series, the weights are extracted from the end of the global weights. This gives a common time weighting across all series.


Fitted model.

Return type


fit_from_dataset(train_dataset, val_dataset=None, trainer=None, verbose=None, epochs=0, dataloader_kwargs=None)

Train the model with a specific darts.utils.data.TrainingDataset instance. These datasets implement a PyTorch Dataset, and specify how the target and covariates are sliced for training. If you are not sure which training dataset to use, consider calling fit() instead, which will create a default training dataset appropriate for this model.

Training is performed with a PyTorch Lightning Trainer. It uses a default Trainer object from presets and pl_trainer_kwargs used at model creation. You can also use a custom Trainer via the optional trainer parameter. For more information on PyTorch Lightning Trainers, see the PyTorch Lightning documentation.

This function can be called several times to do some extra training. If epochs is specified, the model is trained for that many additional epochs.

  • train_dataset (TrainingDataset) – A training dataset with a type matching this model (e.g. PastCovariatesTrainingDataset for PastCovariatesTorchModel).

  • val_dataset (Optional[TrainingDataset, None]) – A dataset with a type matching this model (e.g. PastCovariatesTrainingDataset for PastCovariatesTorchModel), representing the validation set (to track the validation loss).

  • trainer (Optional[Trainer, None]) – Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom trainer will override Darts’ default trainer.

  • verbose (Optional[bool, None]) – Whether to print the progress. Ignored if there is a ProgressBar callback in pl_trainer_kwargs.

  • epochs (int) – If specified, will train the model for epochs (additional) epochs, irrespective of what n_epochs was provided to the model constructor.

  • dataloader_kwargs (Optional[dict[str, Any], None]) –

    Optionally, a dictionary of keyword arguments used to create the PyTorch DataLoader instances for the training and validation datasets. For more information on DataLoader, see the PyTorch documentation. By default, Darts configures parameters (“batch_size”, “shuffle”, “drop_last”, “collate_fn”, “pin_memory”) for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.


Fitted model.

Return type


generate_fit_encodings(series, past_covariates=None, future_covariates=None)

Generates the covariate encodings that were used/generated for fitting the model and returns a tuple of past, and future covariates series with the original and encoded covariates stacked together. The encodings are generated by the encoders defined at model creation with parameter add_encoders. Pass the same series, past_covariates, and future_covariates that you used to train/fit the model.

  • series (Union[TimeSeries, Sequence[TimeSeries]]) – The series or sequence of series with the target values used when fitting the model.

  • past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the series or sequence of series with the past-observed covariates used when fitting the model.

  • future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the series or sequence of series with the future-known covariates used when fitting the model.


A tuple of (past covariates, future covariates). Each covariate contains the original as well as the encoded covariates.

Return type

Tuple[Union[TimeSeries, Sequence[TimeSeries]], Union[TimeSeries, Sequence[TimeSeries]]]

generate_fit_predict_encodings(n, series, past_covariates=None, future_covariates=None)

Generates covariate encodings for training and inference/prediction and returns a tuple of past, and future covariates series with the original and encoded covariates stacked together. The encodings are generated by the encoders defined at model creation with parameter add_encoders. Pass the same series, past_covariates, and future_covariates that you intend to use for training and prediction.

  • n (int) – The number of prediction time steps after the end of series intended to be used for prediction.

  • series (Union[TimeSeries, Sequence[TimeSeries]]) – The series or sequence of series with target values intended to be used for training and prediction.

  • past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the past-observed covariates series intended to be used for training and prediction. The dimensions must match those of the covariates used for training.

  • future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the future-known covariates series intended to be used for prediction. The dimensions must match those of the covariates used for training.


A tuple of (past covariates, future covariates). Each covariate contains the original as well as the encoded covariates.

Return type

Tuple[Union[TimeSeries, Sequence[TimeSeries]], Union[TimeSeries, Sequence[TimeSeries]]]

generate_predict_encodings(n, series, past_covariates=None, future_covariates=None)

Generates covariate encodings for the inference/prediction set and returns a tuple of past, and future covariates series with the original and encoded covariates stacked together. The encodings are generated by the encoders defined at model creation with parameter add_encoders. Pass the same series, past_covariates, and future_covariates that you intend to use for prediction.

  • n (int) – The number of prediction time steps after the end of series intended to be used for prediction.

  • series (Union[TimeSeries, Sequence[TimeSeries]]) – The series or sequence of series with target values intended to be used for prediction.

  • past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the past-observed covariates series intended to be used for prediction. The dimensions must match those of the covariates used for training.

  • future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the future-known covariates series intended to be used for prediction. The dimensions must match those of the covariates used for training.


A tuple of (past covariates, future covariates). Each covariate contains the original as well as the encoded covariates.

Return type

Tuple[Union[TimeSeries, Sequence[TimeSeries]], Union[TimeSeries, Sequence[TimeSeries]]]

classmethod gridsearch(parameters, series, past_covariates=None, future_covariates=None, forecast_horizon=None, stride=1, start=None, start_format='value', last_points_only=False, show_warnings=True, val_series=None, use_fitted_values=False, metric=<function mape>, reduction=<function mean>, verbose=False, n_jobs=1, n_random_samples=None, data_transformers=None, fit_kwargs=None, predict_kwargs=None, sample_weight=None)

Find the best hyper-parameters among a given set using a grid search.

This function has 3 modes of operation: Expanding window mode, split mode and fitted value mode. The three modes of operation evaluate every possible combination of hyper-parameter values provided in the parameters dictionary by instantiating the model_class subclass of ForecastingModel with each combination, and returning the best-performing model with regard to the metric function. The metric function is expected to return an error value, thus the model resulting in the smallest metric output will be chosen.

The relationship of the training data and test data depends on the mode of operation.

Expanding window mode (activated when forecast_horizon is passed): For every hyperparameter combination, the model is repeatedly trained and evaluated on different splits of series. This process is accomplished by using the backtest() function as a subroutine to produce historic forecasts starting from start that are compared against the ground truth values of series. Note that the model is retrained for every single prediction, thus this mode is slower.

Split window mode (activated when val_series is passed): For every hyper-parameter combination, the model is trained on series and evaluated on val_series.

Fitted value mode (activated when use_fitted_values is set to True): For every hyper-parameter combination, the model is trained on series and evaluated on the resulting fitted values. Not all models have fitted values, and this method raises an error if the model doesn’t have a fitted_values member. The fitted values are the result of the fit of the model on series. Comparing with the fitted values can be a quick way to assess the model, but one cannot see if the model is overfitting the series.

Derived classes must ensure that a single instance of a model will not share parameters with the other instances, e.g., saving models in the same path. Otherwise, an unexpected behavior can arise while running several models in parallel (when n_jobs != 1). If this cannot be avoided, then gridsearch should be redefined, forcing n_jobs = 1.

Currently this method only supports deterministic predictions (i.e. when models’ predictions have only 1 sample).
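The search itself can be pictured with a small standard-library sketch: expand the parameters dictionary into every combination, optionally sample a subset as n_random_samples does, score each combination, and keep the one with the lowest error. This is not the Darts implementation; `expand_grid`, `pick_best`, and `toy_error` are hypothetical names, the toy error function stands in for fitting and evaluating a model, and the rounding applied to a float ratio is an assumption.

```python
import itertools
import random

def expand_grid(parameters):
    """Turn {name: [values]} into a list of {name: value} combinations."""
    names = list(parameters)
    return [dict(zip(names, combo))
            for combo in itertools.product(*(parameters[n] for n in names))]

def pick_best(parameters, evaluate, n_random_samples=None, seed=0):
    """Return (best_params, best_score); a lower score is better."""
    grid = expand_grid(parameters)
    if n_random_samples is not None:
        # int: number of combinations; float: fraction of the full grid
        k = (n_random_samples if isinstance(n_random_samples, int)
             else max(1, int(n_random_samples * len(grid))))
        grid = random.Random(seed).sample(grid, k)
    scored = [(evaluate(p), p) for p in grid]
    best_score, best_params = min(scored, key=lambda t: t[0])
    return best_params, best_score

# Toy stand-in for "fit the model and compute the metric".
def toy_error(params):
    return abs(params["hidden_size"] - 32) + 10 * params["dropout"]

grid = {"hidden_size": [16, 32, 64], "dropout": [0.0, 0.1]}
best_params, best_score = pick_best(grid, toy_error)
```

With three values for hidden_size and two for dropout, the full grid has six combinations, so the cost of an exhaustive search grows multiplicatively with each added hyper-parameter.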

  • model_class – The ForecastingModel subclass to be tuned for ‘series’.

  • parameters (dict) – A dictionary containing as keys hyperparameter names, and as values lists of values for the respective hyperparameter.

  • series (TimeSeries) – The target series used as input and target for training.

  • past_covariates (Optional[TimeSeries, None]) – Optionally, a past-observed covariate series. This applies only if the model supports past covariates.

  • future_covariates (Optional[TimeSeries, None]) – Optionally, a future-known covariate series. This applies only if the model supports future covariates.

  • forecast_horizon (Optional[int, None]) – The integer value of the forecasting horizon. Activates expanding window mode.

  • stride (int) – Only used in expanding window mode. The number of time steps between two consecutive predictions.

  • start (Union[Timestamp, float, int, None]) –

    Only used in expanding window mode. Optionally, the first point in time at which a prediction is computed. This parameter supports: float, int, pandas.Timestamp, and None. If a float, it is the proportion of the time series that should lie before the first prediction point. If an int, it is either the index position of the first prediction point for series with a pd.DatetimeIndex, or the index value for series with a pd.RangeIndex. The latter can be changed to the index position with start_format=”position”. If a pandas.Timestamp, it is the time stamp of the first prediction point. If None, the first prediction point will automatically be set to:

    • the first predictable point if retrain is False, or retrain is a Callable and the first predictable point is earlier than the first trainable point.

    • the first trainable point if retrain is True or int (given train_length), or retrain is a Callable and the first trainable point is earlier than the first predictable point.

    • the first trainable point (given train_length) otherwise

    Note: If start is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of stride ahead of start. Raises a ValueError, if no valid start point exists.

    Note: If the model uses a shifted output (output_chunk_shift > 0), then the first predicted point is also shifted by output_chunk_shift points into the future.

    Note: If start is outside the possible historical forecasting times, will ignore the parameter (default behavior with None) and start at the first trainable/predictable point.

  • start_format (Literal[‘position’, ‘value’]) – Only used in expanding window mode. Defines the start format. Only effective when start is an integer and series is indexed with a pd.RangeIndex. If set to ‘position’, start corresponds to the index position of the first predicted point and can range from (-len(series), len(series) - 1). If set to ‘value’, start corresponds to the index value/label of the first predicted point. Will raise an error if the value is not in series’ index. Default: 'value'

  • last_points_only (bool) – Only used in expanding window mode. Whether to use the whole forecasts or only the last point of each forecast to compute the error.

  • show_warnings (bool) – Only used in expanding window mode. Whether to show warnings related to the start parameter.

  • val_series (Optional[TimeSeries, None]) – The TimeSeries instance used for validation in split mode. If provided, this series must start right after the end of series; so that a proper comparison of the forecast can be made.

  • use_fitted_values (bool) – If True, uses the comparison with the fitted values. Raises an error if fitted_values is not an attribute of model_class.

  • metric (Callable[[TimeSeries, TimeSeries], float]) –

    A metric function that returns the error between two TimeSeries as a float value. Must either be one of Darts’ “aggregated over time” metrics (see darts.metrics), or a custom metric that takes two TimeSeries as input and returns the error.

  • reduction (Callable[[ndarray], float]) – A reduction function (mapping array to float) describing how to aggregate the errors obtained on the different validation series when backtesting. By default it’ll compute the mean of errors.

  • verbose – Whether to print the progress.

  • n_jobs (int) – The number of jobs to run in parallel. Parallel jobs are created only when there are two or more parameters combinations to evaluate. Each job will instantiate, train, and evaluate a different instance of the model. Defaults to 1 (sequential). Setting the parameter to -1 means using all the available cores.

  • n_random_samples (Union[int, float, None]) – The number/ratio of hyperparameter combinations to select from the full parameter grid. This will perform a random search instead of using the full grid. If an integer, n_random_samples is the number of parameter combinations selected from the full grid and must be between 0 and the total number of parameter combinations. If a float, n_random_samples is the ratio of parameter combinations selected from the full grid and must be between 0 and 1. Defaults to None, for which random selection will be ignored.

  • data_transformers (Optional[dict[str, Union[BaseDataTransformer, Pipeline]], None]) –

    Optionally, a dictionary of BaseDataTransformer or Pipeline to apply to the corresponding series (possible keys: “series”, “past_covariates”, “future_covariates”). If provided, all input series must be in the un-transformed space. For fittable transformer / pipeline:

    • if retrain=True, the data transformer is re-fit on the training data at each historical forecast step.

    • if retrain=False, the data transformer transforms the series once before all the forecasts.

    The fitted transformer is used to transform the input during both training and prediction. If the transformation is invertible, the forecasts will be inverse-transformed.

  • fit_kwargs (Optional[dict[str, Any], None]) – Additional arguments passed to the model fit() method.

  • predict_kwargs (Optional[dict[str, Any], None]) – Additional arguments passed to the model predict() method.

  • sample_weight (Union[TimeSeries, str, None]) – Optionally, some sample weights to apply to the target series labels for training. Only effective when retrain is not False. They are applied per observation, per label (each step in output_chunk_length), and per component. If a series, then those weights are used. If the weight series only have a single component / column, then the weights are applied globally to all components in series. Otherwise, for component-specific weights, the number of components must match those of series. If a string, then the weights are generated using built-in weighting functions. The available options are “linear” or “exponential” decay - the further in the past, the lower the weight.


A tuple containing an untrained model_class instance created from the best-performing hyper-parameters, along with a dictionary containing these best hyper-parameters, and metric score for the best hyper-parameters.

Return type

ForecastingModel, Dict, float
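The difference between start_format='value' and start_format='position' for an integer start on a range-indexed series can be sketched with a plain Python range standing in for a pd.RangeIndex. `resolve_start` is a hypothetical helper written for this illustration; it is not part of Darts and only mirrors the behavior described for start and start_format above.

```python
def resolve_start(index, start, start_format="value"):
    """Return (position, value) of the first prediction point."""
    if start_format == "value":
        # start is an index label and must exist in the index
        if start not in index:
            raise ValueError(f"{start} is not in the index")
        return index.index(start), start
    if start_format == "position":
        # start is a position and may be negative, like list indexing
        if not -len(index) <= start <= len(index) - 1:
            raise ValueError("position out of range")
        return start % len(index), index[start]
    raise ValueError(f"unknown start_format: {start_format}")

index = range(10, 110)  # stand-in for a RangeIndex starting at 10
by_value = resolve_start(index, 20, "value")
by_position = resolve_start(index, 20, "position")
```

Here the same integer 20 selects the eleventh point when read as a value and the twenty-first point when read as a position, which is why the two formats only coincide for an index that starts at 0 with step 1.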

historical_forecasts(series, past_covariates=None, future_covariates=None, forecast_horizon=1, num_samples=1, train_length=None, start=None, start_format='value', stride=1, retrain=True, overlap_end=False, last_points_only=True, verbose=False, show_warnings=True, predict_likelihood_parameters=False, enable_optimization=True, data_transformers=None, fit_kwargs=None, predict_kwargs=None, sample_weight=None)

Generates historical forecasts by simulating predictions at various points in time throughout the history of the provided (potentially multiple) series. This process involves retrospectively applying the model to different time steps, as if the forecasts were made in real-time at those specific moments. This allows for an evaluation of the model’s performance over the entire duration of the series, providing insights into its predictive accuracy and robustness across different historical periods.

There are two main modes for this method:

  • Re-training Mode (Default, retrain=True): The model is re-trained at each step of the simulation, and generates a forecast using the updated model. In case of multiple series, the model is re-trained on each series independently (global training is not yet supported).

  • Pre-trained Mode (retrain=False): The forecasts are generated at each step of the simulation without re-training. It is only supported for pre-trained global forecasting models. This mode is significantly faster as it skips the re-training step.

By choosing the appropriate mode, you can balance between computational efficiency and the need for up-to-date model training.

Re-training Mode: This mode repeatedly builds a training set by either expanding from the beginning of the series or by using a fixed-length train_length (the start point can also be configured with start and start_format). The model is then trained on this training set, and a forecast of length forecast_horizon is generated. Subsequently, the end of the training set is moved forward by stride time steps, and the process is repeated.

Pre-trained Mode: This mode is only supported for pre-trained global forecasting models. It uses the same simulation steps as in the Re-training Mode (ignoring train_length), but generates the forecasts directly without re-training.

By default, with last_points_only=True, this method returns a single time series (or a sequence of time series) composed of the last point from each historical forecast. This time series will thus have a frequency of series.freq * stride. If last_points_only=False, it will instead return a list (or a sequence of lists) of the full historical forecast series each with frequency series.freq.
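The simulation steps can be sketched with integer positions only. The helper below is a hypothetical illustration, not Darts code: it lists, for each step, the training window and the positions being forecast, for either an expanding window or a fixed train_length, and it assumes forecasts must stay inside the series (as with overlap_end=False).

```python
def simulation_steps(n_points, first_pred, forecast_horizon=1,
                     stride=1, train_length=None):
    """List (train_start, train_end, forecast_positions) per step.

    Positions are integer indices into a series of n_points values;
    train_end is exclusive and is also the first forecast position.
    """
    steps = []
    pred = first_pred
    while pred + forecast_horizon <= n_points:
        # expanding window starts at 0; a fixed train_length slides along
        start = 0 if train_length is None else max(0, pred - train_length)
        steps.append((start, pred, list(range(pred, pred + forecast_horizon))))
        pred += stride
    return steps

steps = simulation_steps(n_points=12, first_pred=6,
                         forecast_horizon=3, stride=2)
# last_points_only=True keeps only the final position of each forecast
last_points = [positions[-1] for _, _, positions in steps]
```

The retained last points are spaced stride positions apart, which matches the series.freq * stride frequency of the returned series.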

  • series (Union[TimeSeries, Sequence[TimeSeries]]) – A (sequence of) target time series used to successively train (if retrain is not False) and compute the historical forecasts.

  • past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a (sequence of) past-observed covariate time series for every input time series in series. This applies only if the model supports past covariates.

  • future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a (sequence of) future-known covariate time series for every input time series in series. This applies only if the model supports future covariates.

  • forecast_horizon (int) – The forecast horizon for the predictions.

  • num_samples (int) – Number of times a prediction is sampled from a probabilistic model. Use values >1 only for probabilistic models.

  • train_length (Optional[int, None]) – Optionally, use a fixed length / number of time steps for every constructed training set (rolling window mode). Only effective when retrain is not False. The default is None, where it uses all time steps up until the prediction time (expanding window mode). If larger than the number of available time steps, uses the expanding mode. Needs to be at least min_train_series_length.

  • start (Union[Timestamp, float, int, None]) –

    Optionally, the first point in time at which a prediction is computed. This parameter supports: float, int, pandas.Timestamp, and None. If a float, it is the proportion of the time series that should lie before the first prediction point. If an int, it is either the index position of the first prediction point for series with a pd.DatetimeIndex, or the index value for series with a pd.RangeIndex. The latter can be changed to the index position with start_format=”position”. If a pandas.Timestamp, it is the time stamp of the first prediction point. If None, the first prediction point will automatically be set to:

    • the first predictable point if retrain is False, or retrain is a Callable and the first predictable point is earlier than the first trainable point.

    • the first trainable point if retrain is True or int (given train_length), or retrain is a Callable and the first trainable point is earlier than the first predictable point.

    • the first trainable point (given train_length) otherwise

    Note: If start is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of stride ahead of start. Raises a ValueError, if no valid start point exists.

    Note: If the model uses a shifted output (output_chunk_shift > 0), then the first predicted point is also shifted by output_chunk_shift points into the future.

    Note: If start is outside the possible historical forecasting times, will ignore the parameter (default behavior with None) and start at the first trainable/predictable point.

  • start_format (Literal[‘position’, ‘value’]) – Defines the start format. If set to 'position', start corresponds to the index position of the first predicted point and can range from (-len(series), len(series) - 1). If set to 'value', start corresponds to the index value/label of the first predicted point. Will raise an error if the value is not in series’ index. Default: 'value'.

  • stride (int) – The number of time steps between two consecutive predictions.

  • retrain (Union[bool, int, Callable[…, bool]]) –

    Whether and/or on which condition to retrain the model before predicting. This parameter supports 3 different types: bool, (positive) int, and Callable (returning a bool). In the case of bool: retrain the model at each step (True), or never retrain the model (False). In the case of int: the model is retrained every retrain iterations. In the case of Callable: the model is retrained whenever callable returns True. The callable must have the following positional arguments:

    • counter (int): current retrain iteration

    • pred_time (pd.Timestamp or int): timestamp of forecast time (end of the training series)

    • train_series (TimeSeries): train series up to pred_time

    • past_covariates (TimeSeries): past_covariates series up to pred_time

    • future_covariates (TimeSeries): future_covariates series up to min(pred_time + series.freq * forecast_horizon, series.end_time())

    Note: if any optional *_covariates are not passed to historical_forecasts, None will be passed to the corresponding retrain function argument. Note: some models require being retrained every time and do not support anything other than retrain=True. Note: also controls the retraining of the data_transformers.

  • overlap_end (bool) – Whether the returned forecasts can go beyond the series’ end or not.

  • last_points_only (bool) – Whether to return only the last point of each historical forecast. If set to True, the method returns a single TimeSeries (for each time series in series) containing the successive point forecasts. Otherwise, returns a list of historical TimeSeries forecasts.

  • verbose (bool) – Whether to print the progress.

  • show_warnings (bool) – Whether to show warnings related to historical forecasts optimization, or parameters start and train_length.

  • predict_likelihood_parameters (bool) – If set to True, the model predicts the parameters of its likelihood instead of the target. Only supported for probabilistic models with a likelihood, num_samples = 1 and n<=output_chunk_length. Default: False.

  • enable_optimization (bool) – Whether to use the optimized version of historical_forecasts when supported and available. Default: True.

  • data_transformers (Optional[dict[str, Union[BaseDataTransformer, Pipeline]], None]) –

    Optionally, a dictionary of BaseDataTransformer or Pipeline to apply to the corresponding series (possible keys: “series”, “past_covariates”, “future_covariates”). If provided, all input series must be in the un-transformed space. For fittable transformer / pipeline:

    • if retrain=True, the data transformer is re-fit on the training data at each historical forecast step.

    • if retrain=False, the data transformer transforms the series once before all the forecasts.

    The fitted transformer is used to transform the input during both training and prediction. If the transformation is invertible, the forecasts will be inverse-transformed.

  • fit_kwargs (Optional[dict[str, Any], None]) – Optionally, some additional arguments passed to the model fit() method.

  • predict_kwargs (Optional[dict[str, Any], None]) – Optionally, some additional arguments passed to the model predict() method.

  • sample_weight (Union[TimeSeries, Sequence[TimeSeries], str, None]) – Optionally, some sample weights to apply to the target series labels for training. Only effective when retrain is not False. They are applied per observation, per label (each step in output_chunk_length), and per component. If a series or sequence of series, then those weights are used. If the weight series only have a single component / column, then the weights are applied globally to all components in series. Otherwise, for component-specific weights, the number of components must match those of series. If a string, then the weights are generated using built-in weighting functions. The available options are “linear” or “exponential” decay - the further in the past, the lower the weight. The weights are computed per time series.

Return type

Union[TimeSeries, list[TimeSeries], list[list[TimeSeries]]]


  • TimeSeries – A single historical forecast for a single series and last_points_only=True: it contains only the predictions at step forecast_horizon from all historical forecasts.

  • List[TimeSeries] – A list of historical forecasts for:

    • a sequence (list) of series and last_points_only=True: for each series, it contains only the predictions at step forecast_horizon from all historical forecasts.

    • a single series and last_points_only=False: for each historical forecast, it contains the entire horizon forecast_horizon.

  • List[List[TimeSeries]] – A list of lists of historical forecasts for a sequence of series and last_points_only=False. For each series, and historical forecast, it contains the entire horizon forecast_horizon. The outer list is over the series provided in the input sequence, and the inner lists contain the historical forecasts for each series.
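A callable passed as retrain receives the five positional arguments listed for that parameter. The sketch below is a minimal, hypothetical example that retrains on every third iteration; the function name is made up, and it assumes that counter starts at 0. The dry run uses None placeholders instead of real TimeSeries objects.

```python
def retrain_every_third_step(counter, pred_time, train_series,
                             past_covariates, future_covariates):
    """Return True when the model should be retrained at this step.

    Only `counter` is used here; the other arguments are accepted because
    the callable is invoked with all five positional arguments.
    """
    return counter % 3 == 0

# Dry run with placeholder values instead of real TimeSeries objects.
decisions = [retrain_every_third_step(i, None, None, None, None)
             for i in range(6)]
```

Such a function would be passed as retrain=retrain_every_third_step. A condition based on pred_time or on the length of train_series follows the same pattern.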

property input_chunk_length: int
Return type


property likelihood: Optional[Likelihood]
Return type

Optional[Likelihood, None]

static load(path, pl_trainer_kwargs=None, **kwargs)

Loads a model from a given file path.

Example for loading a general save from RNNModel:

from darts.models import RNNModel

model_loaded = RNNModel.load(path)

Example for loading an RNNModel to GPU that was trained on CPU:

from darts.models import RNNModel

model_loaded = RNNModel.load(path, pl_trainer_kwargs={"accelerator": "gpu"})

Example for loading an RNNModel to CPU that was saved on GPU:

from darts.models import RNNModel

model_loaded = RNNModel.load(path, map_location="cpu", pl_trainer_kwargs={"accelerator": "cpu"})
  • path (str) – Path from which to load the model. If no path was specified when saving the model, the automatically generated path ending with “.pt” has to be provided.

  • pl_trainer_kwargs (Optional[dict, None]) – Optionally, a set of kwargs to create a new Lightning Trainer used to configure the model for downstream tasks (e.g. prediction). Some examples include specifying the batch size or moving the model to CPU/GPU(s). Check the Lightning Trainer documentation for more information about the supported kwargs.

  • **kwargs

    Additional kwargs for PyTorch Lightning’s LightningModule.load_from_checkpoint() method, such as map_location to load the model onto a different device than the one on which it was saved. For more information, read the official documentation.

Return type


static load_from_checkpoint(model_name, work_dir=None, file_name=None, best=True, **kwargs)

Load the model from automatically saved checkpoints under ‘{work_dir}/darts_logs/{model_name}/checkpoints/’. This method is used for models that were created with save_checkpoints=True.

If you manually saved your model, consider using load().

Example for loading an RNNModel from checkpoint (model_name is the model_name used at model creation):

from darts.models import RNNModel

model_loaded = RNNModel.load_from_checkpoint(model_name, best=True)

If file_name is given, returns the model saved under ‘{work_dir}/darts_logs/{model_name}/checkpoints/{file_name}’.

If file_name is not given, will try to restore the best checkpoint (if best is True) or the most recent checkpoint (if best is False) from ‘{work_dir}/darts_logs/{model_name}/checkpoints/’.
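The folder layout described above can be written out with the standard library. This is only a sketch of the path pattern, not Darts code: `checkpoint_dir` and `checkpoint_path` are hypothetical helpers, and the model and file names are placeholders.

```python
import os

def checkpoint_dir(model_name, work_dir=None):
    """Folder that holds the automatically saved checkpoints."""
    base = os.getcwd() if work_dir is None else work_dir
    return os.path.join(base, "darts_logs", model_name, "checkpoints")

def checkpoint_path(model_name, file_name, work_dir=None):
    """Full path of one checkpoint file inside that folder."""
    return os.path.join(checkpoint_dir(model_name, work_dir), file_name)

# "my_tsmixer" and "last.ckpt" are placeholder names for illustration.
path = checkpoint_path("my_tsmixer", "last.ckpt", work_dir="runs")
```

Listing the directory returned by `checkpoint_dir` is a quick way to see which file names are available to pass as file_name.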

Example for loading an RNNModel checkpoint to CPU that was saved on GPU:

from darts.models import RNNModel

model_loaded = RNNModel.load_from_checkpoint(model_name, best=True, map_location="cpu")
  • model_name (str) – The name of the model, used to retrieve the checkpoints folder’s name.

  • work_dir (Optional[str, None]) – Working directory (containing the checkpoints folder). Defaults to current working directory.

  • file_name (Optional[str, None]) – The name of the checkpoint file. If not specified, use the most recent one.

  • best (bool) – If set, will retrieve the best model (according to validation loss) instead of the most recent one. Ignored when file_name is given.

  • **kwargs

    Additional kwargs for PyTorch Lightning’s LightningModule.load_from_checkpoint() method, such as map_location to load the model onto a different device than the one from which it was saved. For more information, read the official documentation.


The corresponding trained TorchForecastingModel.

Return type


load_weights(path, load_encoders=True, skip_checks=False, **kwargs)

Loads the weights from a manually saved model (saved with save()).

Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders and perform sanity checks on the model parameters.

  • path (str) – Path from which to load the model’s weights. If no path was specified when saving the model, the automatically generated path ending with “.pt” has to be provided.

  • load_encoders (bool) – If set, will load the encoders from the model to enable direct call of fit() or predict(). Default: True.

  • skip_checks (bool) – If set, will disable the loading of the encoders and the sanity checks on model parameters (not recommended). Cannot be used with load_encoders=True. Default: False.

  • **kwargs

    Additional kwargs for PyTorch’s load() method, such as map_location to load the model onto a different device than the one from which it was saved. For more information, read the official documentation.

load_weights_from_checkpoint(model_name=None, work_dir=None, file_name=None, best=True, strict=True, load_encoders=True, skip_checks=False, **kwargs)

Load only the weights from automatically saved checkpoints under ‘{work_dir}/darts_logs/{model_name}/ checkpoints/’. This method is used for models that were created with save_checkpoints=True and that need to be re-trained or fine-tuned with different optimizer or learning rate scheduler. However, it can also be used to load weights for inference.

To resume an interrupted training, please consider using load_from_checkpoint(), which also reloads the trainer, optimizer and learning rate scheduler states.

For manually saved models, consider using load() or load_weights() instead.

Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders and perform sanity checks on the model parameters.

  • model_name (Optional[str, None]) – The name of the model, used to retrieve the checkpoints folder’s name. Default: self.model_name.

  • work_dir (Optional[str, None]) – Working directory (containing the checkpoints folder). Defaults to current working directory.

  • file_name (Optional[str, None]) – The name of the checkpoint file. If not specified, use the most recent one.

  • best (bool) – If set, will retrieve the best model (according to validation loss) instead of the most recent one. Ignored when file_name is given. Default: True.

  • strict (bool) –

    If set, strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict(). Default: True. For more information, read the official documentation.

  • load_encoders (bool) – If set, will load the encoders from the model to enable direct call of fit() or predict(). Default: True.

  • skip_checks (bool) – If set, will disable the loading of the encoders and the sanity checks on model parameters (not recommended). Cannot be used with load_encoders=True. Default: False.

  • **kwargs

    Additional kwargs for PyTorch’s load() method, such as map_location to load the model onto a different device than the one from which it was saved. For more information, read the official documentation.
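The selection rules for model_name, work_dir, file_name and best can be sketched in plain Python. This is a minimal illustration, not Darts' implementation: the helper name find_checkpoint is hypothetical, and the "best-" / "last-" file name prefixes are an assumption based on the naming collision warning in save().

```python
import os


def find_checkpoint(model_name, work_dir=None, file_name=None, best=True):
    """Return the path of the checkpoint file that would be selected (sketch)."""
    work_dir = work_dir or os.getcwd()
    ckpt_dir = os.path.join(work_dir, "darts_logs", model_name, "checkpoints")

    # An explicit file name wins; `best` is ignored in that case.
    if file_name is not None:
        return os.path.join(ckpt_dir, file_name)

    # Assumption: checkpoint file names start with "best-" or "last-".
    prefix = "best-" if best else "last-"
    candidates = [f for f in os.listdir(ckpt_dir) if f.startswith(prefix)]
    if not candidates:
        raise FileNotFoundError(f"no '{prefix}*' checkpoint in {ckpt_dir}")

    # Pick the most recently modified candidate.
    newest = max(
        candidates, key=lambda f: os.path.getmtime(os.path.join(ckpt_dir, f))
    )
    return os.path.join(ckpt_dir, newest)
```

Calling find_checkpoint("my_model", best=False) would look for the most recent "last-" file under ./darts_logs/my_model/checkpoints/.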

lr_find(series, past_covariates=None, future_covariates=None, val_series=None, val_past_covariates=None, val_future_covariates=None, sample_weight=None, val_sample_weight=None, trainer=None, verbose=None, epochs=0, max_samples_per_ts=None, dataloader_kwargs=None, min_lr=1e-08, max_lr=1, num_training=100, mode='exponential', early_stop_threshold=4.0)

A wrapper around PyTorch Lightning’s Tuner.lr_find(). Performs a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate. For more information on PyTorch Lightning’s Tuner check out this link. It is recommended to increase the number of epochs if the tuner did not give satisfactory results. Consider creating a new model object with the suggested learning rate for example using model creation parameters optimizer_cls, optimizer_kwargs, lr_scheduler_cls, and lr_scheduler_kwargs.

Example using an NBEATSModel:

import torch
from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel

series = AirPassengersDataset().load()
train, val = series[:-18], series[-18:]
model = NBEATSModel(input_chunk_length=12, output_chunk_length=6, random_state=42)
# run the learning rate tuner
results = model.lr_find(series=train, val_series=val)
# plot the results
results.plot(suggest=True, show=True)
# create a new model with the suggested learning rate
model = NBEATSModel(
    input_chunk_length=12,
    output_chunk_length=6,
    random_state=42,
    optimizer_cls=torch.optim.Adam,
    optimizer_kwargs={"lr": results.suggestion()},
)

  • series (Union[TimeSeries, Sequence[TimeSeries]]) – A series or sequence of series serving as target (i.e. what the model will be trained to forecast)

  • past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a series or sequence of series specifying past-observed covariates

  • future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a series or sequence of series specifying future-known covariates

  • val_series (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, one or a sequence of validation target series, which will be used to compute the validation loss throughout training and keep track of the best performing models.

  • val_past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the past covariates corresponding to the validation series (must match covariates)

  • val_future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the future covariates corresponding to the validation series (must match covariates)

  • sample_weight (Union[TimeSeries, Sequence[TimeSeries], str, None]) – Optionally, some sample weights to apply to the target series labels. They are applied per observation, per label (each step in output_chunk_length), and per component. If a series or sequence of series, then those weights are used. If the weight series only have a single component / column, then the weights are applied globally to all components in series. Otherwise, for component-specific weights, the number of components must match those of series. If a string, then the weights are generated using built-in weighting functions. The available options are “linear” or “exponential” decay - the further in the past, the lower the weight. The weights are computed globally based on the length of the longest series in series. Then for each series, the weights are extracted from the end of the global weights. This gives a common time weighting across all series.

  • val_sample_weight (Union[TimeSeries, Sequence[TimeSeries], str, None]) – Same as for sample_weight but for the evaluation dataset.

  • trainer (Optional[Trainer, None]) – Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom trainer will override Darts’ default trainer.

  • verbose (Optional[bool, None]) – Whether to print the progress. Ignored if there is a ProgressBar callback in pl_trainer_kwargs.

  • epochs (int) – If specified, will train the model for epochs (additional) epochs, irrespective of what n_epochs was provided to the model constructor.

  • max_samples_per_ts (Optional[int, None]) – Optionally, a maximum number of samples to use per time series. Models are trained in a supervised fashion by constructing slices of (input, output) examples. On long time series, this can result in unnecessarily large number of training samples. This parameter upper-bounds the number of training samples per time series (taking only the most recent samples in each series). Leaving to None does not apply any upper bound.

  • dataloader_kwargs (Optional[dict[str, Any], None]) –

    Optionally, a dictionary of keyword arguments used to create the PyTorch DataLoader instances for the training and validation datasets. For more information on DataLoader, check out this link. By default, Darts configures parameters (“batch_size”, “shuffle”, “drop_last”, “collate_fn”, “pin_memory”) for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.

  • min_lr (float) – minimum learning rate to investigate

  • max_lr (float) – maximum learning rate to investigate

  • num_training (int) – number of learning rates to test

  • mode (str) – Search strategy to update learning rate after each batch: ‘exponential’: Increases the learning rate exponentially. ‘linear’: Increases the learning rate linearly.

  • early_stop_threshold (float) – Threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None
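To make min_lr, max_lr, num_training, mode and early_stop_threshold more concrete, the sketch below generates the candidate learning rates and applies the stopping rule described above. It is an illustration under assumptions, not Lightning's implementation: the function names lr_sweep and should_stop are hypothetical, and the exact interpolation Lightning uses may differ.

```python
def lr_sweep(min_lr=1e-8, max_lr=1.0, num_training=100, mode="exponential"):
    """Return the list of learning rates tested during the range test (sketch)."""
    if num_training < 2:
        raise ValueError("num_training must be at least 2")
    if mode == "exponential":
        # Constant multiplicative step between consecutive learning rates.
        ratio = max_lr / min_lr
        return [min_lr * ratio ** (i / (num_training - 1)) for i in range(num_training)]
    if mode == "linear":
        # Constant additive step between consecutive learning rates.
        step = (max_lr - min_lr) / (num_training - 1)
        return [min_lr + i * step for i in range(num_training)]
    raise ValueError(f"unknown mode: {mode!r}")


def should_stop(loss, best_loss, early_stop_threshold=4.0):
    """Stop when the loss exceeds early_stop_threshold * best_loss."""
    if early_stop_threshold is None:
        return False  # early stopping disabled
    return loss > early_stop_threshold * best_loss


# Four candidates between 1e-4 and 1e-1, spaced by a factor of 10.
candidates = lr_sweep(min_lr=1e-4, max_lr=1e-1, num_training=4)
```

With mode='exponential', small learning rates are sampled much more densely than with mode='linear', which is usually what a range test spanning several orders of magnitude needs.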


_LRFinder object of Lightning containing the results of the LR sweep.

Return type


property min_train_samples: int

The minimum number of samples for training the model.

Return type


property model_created: bool
Return type


property model_params: dict
Return type


property output_chunk_length: int

Number of time steps predicted at once by the model, not defined for statistical models.

Return type


property output_chunk_shift: int

Number of time steps that the output/prediction starts after the end of the input.

Return type


predict(n, series=None, past_covariates=None, future_covariates=None, trainer=None, batch_size=None, verbose=None, n_jobs=1, roll_size=None, num_samples=1, dataloader_kwargs=None, mc_dropout=False, predict_likelihood_parameters=False, show_warnings=True)

Predict the n time steps following the end of the training series, or of the specified series.

Prediction is performed with a PyTorch Lightning Trainer. It uses a default Trainer object from presets and pl_trainer_kwargs used at model creation. You can also use a custom Trainer with optional parameter trainer. For more information on PyTorch Lightning Trainers check out this link.

Below, all possible parameters are documented, but not all models support all parameters. For instance, all the PastCovariatesTorchModel support only past_covariates and not future_covariates. Darts will complain if you try calling predict() on a model with the wrong covariates argument.

Darts will also complain if the provided covariates do not have a sufficient time span. In general, not all models require the same covariates’ time spans:

  • Models relying on past covariates require the last input_chunk_length of the past_covariates
    points to be known at prediction time. For horizon values n > output_chunk_length, these models
    require at least the next n - output_chunk_length future values to be known as well.
  • Models relying on future covariates require the next n values to be known.
    In addition (for DualCovariatesTorchModel and MixedCovariatesTorchModel), they also
    require the “historic” values of these future covariates (over the past input_chunk_length).

When handling covariates, Darts will try to use the time axes of the target and the covariates to come up with the right time slices. So the covariates can be longer than needed; as long as the time axes are correct Darts will handle them correctly. It will also complain if their time span is not sufficient.
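The covariate time span rules above can be turned into a small calculator. This is a sketch that only encodes the rules as stated in this section; the helper name required_covariate_steps is hypothetical, and it ignores output_chunk_shift.

```python
def required_covariate_steps(n, input_chunk_length, output_chunk_length):
    """Number of covariate points needed before / after the prediction time (sketch)."""
    return {
        "past_covariates": {
            # the last input_chunk_length points must be known
            "history": input_chunk_length,
            # only needed when the forecast is auto-regressive (n > output_chunk_length)
            "future": max(0, n - output_chunk_length),
        },
        "future_covariates": {
            # "historic" values, for Dual/MixedCovariatesTorchModel
            "history": input_chunk_length,
            # the next n values must be known
            "future": n,
        },
    }


# Forecasting 10 steps with a model that sees 24 steps and emits 6 at a time.
needs = required_covariate_steps(n=10, input_chunk_length=24, output_chunk_length=6)
```

In this example the past covariates must extend 4 steps beyond the end of the target series, because the second auto-regressive iteration consumes them as input.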

  • n (int) – The number of time steps after the end of the training time series for which to produce predictions

  • series (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a series or sequence of series, representing the history of the target series whose future is to be predicted. If specified, the method returns the forecasts of these series. Otherwise, the method returns the forecast of the (single) training series.

  • past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the past-observed covariates series needed as inputs for the model. They must match the covariates used for training in terms of dimension.

  • future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, the future-known covariates series needed as inputs for the model. They must match the covariates used for training in terms of dimension.

  • trainer (Optional[Trainer, None]) – Optionally, a custom PyTorch-Lightning Trainer object to perform prediction. Using a custom trainer will override Darts’ default trainer.

  • batch_size (Optional[int, None]) – Size of batches during prediction. Defaults to the models’ training batch_size value.

  • verbose (Optional[bool, None]) – Whether to print the progress. Ignored if there is a ProgressBar callback in pl_trainer_kwargs.

  • n_jobs (int) – The number of jobs to run in parallel. -1 means using all processors. Defaults to 1.

  • roll_size (Optional[int, None]) – For self-consuming predictions, i.e. n > output_chunk_length, determines how many outputs of the model are fed back into it at every iteration of feeding the predicted target (and optionally future covariates) back into the model. If this parameter is not provided, it will be set to output_chunk_length by default.

  • num_samples (int) – Number of times a prediction is sampled from a probabilistic model. Must be 1 for deterministic models.

  • dataloader_kwargs (Optional[dict[str, Any], None]) –

    Optionally, a dictionary of keyword arguments used to create the PyTorch DataLoader instance for the inference/prediction dataset. For more information on DataLoader, check out this link. By default, Darts configures parameters (“batch_size”, “shuffle”, “drop_last”, “collate_fn”, “pin_memory”) for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.

  • mc_dropout (bool) – Optionally, enable Monte Carlo dropout for predictions using neural network based models. This allows Bayesian approximation by specifying an implicit prior over learned models.

  • predict_likelihood_parameters (bool) – If set to True, the model predicts the parameters of its likelihood instead of the target. Only supported for probabilistic models with a likelihood, num_samples = 1 and n<=output_chunk_length. Default: False.

  • show_warnings (bool) – Optionally, control whether warnings are shown. Not effective for all models.
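The interaction between n, output_chunk_length and roll_size can be illustrated with a toy auto-regressive loop. This is a simplified sketch, not Darts' prediction code: toy_model and autoregressive_forecast are hypothetical names, the "model" just counts upward from the last input value, and covariates are left out.

```python
def toy_model(window, output_chunk_length):
    """Stand-in for a trained model: continues counting from the last input value."""
    last = window[-1]
    return [last + i + 1 for i in range(output_chunk_length)]


def autoregressive_forecast(history, n, input_chunk_length, output_chunk_length,
                            roll_size=None):
    """Forecast n steps, feeding roll_size predictions back per iteration (sketch)."""
    if roll_size is None:
        roll_size = output_chunk_length  # documented default
    if not 1 <= roll_size <= output_chunk_length:
        raise ValueError("roll_size must be in [1, output_chunk_length]")

    window = list(history[-input_chunk_length:])
    if n <= output_chunk_length:
        # One model call is enough: no auto-regression.
        return toy_model(window, output_chunk_length)[:n]

    forecast = []
    while len(forecast) < n:
        chunk = toy_model(window, output_chunk_length)
        kept = chunk[:roll_size]  # only these outputs are fed back
        forecast.extend(kept)
        window = (window + kept)[-input_chunk_length:]
    return forecast[:n]


prediction = autoregressive_forecast(
    history=[1, 2, 3, 4], n=5, input_chunk_length=3, output_chunk_length=3, roll_size=2
)
```

A smaller roll_size means more model calls for the same n, but each call relies on fewer of its own earlier predictions at the far end of the output chunk.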


One or several time series containing the forecasts of series, or the forecast of the training series if series is not specified and the model has been trained on a single series.

Return type

Union[TimeSeries, Sequence[TimeSeries]]

predict_from_dataset(n, input_series_dataset, trainer=None, batch_size=None, verbose=None, n_jobs=1, roll_size=None, num_samples=1, dataloader_kwargs=None, mc_dropout=False, predict_likelihood_parameters=False)

This method allows for predicting with a specific darts.utils.data.InferenceDataset instance. These datasets implement a PyTorch Dataset, and specify how the target and covariates are sliced for inference. In most cases, you’ll rather want to call predict() instead, which will create an appropriate InferenceDataset for you.

Prediction is performed with a PyTorch Lightning Trainer. It uses a default Trainer object from presets and pl_trainer_kwargs used at model creation. You can also use a custom Trainer with optional parameter trainer. For more information on PyTorch Lightning Trainers check out this link.

  • n (int) – The number of time steps after the end of the training time series for which to produce predictions

  • input_series_dataset (InferenceDataset) – Optionally, a series or sequence of series, representing the history of the target series whose future is to be predicted. If specified, the method returns the forecasts of these series. Otherwise, the method returns the forecast of the (single) training series.

  • trainer (Optional[Trainer, None]) – Optionally, a custom PyTorch-Lightning Trainer object to perform prediction. Using a custom trainer will override Darts’ default trainer.

  • batch_size (Optional[int, None]) – Size of batches during prediction. Defaults to the model’s batch_size value.

  • verbose (Optional[bool, None]) – Whether to print the progress. Ignored if there is a ProgressBar callback in pl_trainer_kwargs.

  • n_jobs (int) – The number of jobs to run in parallel. -1 means using all processors. Defaults to 1.

  • roll_size (Optional[int, None]) – For self-consuming predictions, i.e. n > output_chunk_length, determines how many outputs of the model are fed back into it at every iteration of feeding the predicted target (and optionally future covariates) back into the model. If this parameter is not provided, it will be set to output_chunk_length by default.

  • num_samples (int) – Number of times a prediction is sampled from a probabilistic model. Must be 1 for deterministic models.

  • dataloader_kwargs (Optional[dict[str, Any], None]) –

    Optionally, a dictionary of keyword arguments used to create the PyTorch DataLoader instance for the inference/prediction dataset. For more information on DataLoader, check out this link. By default, Darts configures parameters (“batch_size”, “shuffle”, “drop_last”, “collate_fn”, “pin_memory”) for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.

  • mc_dropout (bool) – Optionally, enable Monte Carlo dropout for predictions using neural network based models. This allows Bayesian approximation by specifying an implicit prior over learned models.

  • predict_likelihood_parameters (bool) – If set to True, the model predicts the parameters of its likelihood instead of the target. Only supported for probabilistic models with a likelihood, num_samples = 1 and n<=output_chunk_length. Default: False


Returns one or more forecasts for time series.

Return type



Resets the model object and removes all stored data - model, checkpoints, loggers and training history.

residuals(series, past_covariates=None, future_covariates=None, historical_forecasts=None, forecast_horizon=1, num_samples=1, train_length=None, start=None, start_format='value', stride=1, retrain=True, overlap_end=False, last_points_only=True, metric=<function err>, verbose=False, show_warnings=True, predict_likelihood_parameters=False, enable_optimization=True, data_transformers=None, metric_kwargs=None, fit_kwargs=None, predict_kwargs=None, sample_weight=None, values_only=False)

Compute the residuals that the model produced for historical forecasts on (potentially multiple) series.

This function computes the difference (or one of Darts’ “per time step” metrics) between the actual observations from series and the fitted values obtained by training the model on series (or using a pre-trained model with retrain=False). Not all models support fitted values, so we use historical forecasts as an approximation for them.

In sequence this method performs:

  • use pre-computed historical_forecasts or compute historical forecasts for each series (see historical_forecasts() for more details). How the historical forecasts are generated can be configured with parameters num_samples, train_length, start, start_format, forecast_horizon, stride, retrain, last_points_only, fit_kwargs, and predict_kwargs.

  • compute a backtest using a “per time step” metric between the historical forecasts and series per component/column and time step (see backtest() for more details). By default, uses the residuals err() (error) as a metric.

  • create and return TimeSeries (or simply a np.ndarray with values_only=True) with the time index from historical forecasts, and values from the metrics per component and time step.

This method works for single or multiple univariate or multivariate series. It uses the median prediction (when dealing with stochastic forecasts).
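The three steps above can be sketched on plain Python lists. This is an illustration only: naive_forecast stands in for a trained model, the helper names are hypothetical, the series is indexed by integer position, and the residual is taken as actual minus forecast, which is how the default err metric is understood here.

```python
def naive_forecast(history, horizon):
    """Stand-in model: repeat the last observed value."""
    return [history[-1]] * horizon


def historical_forecasts(values, start, forecast_horizon=1, stride=1):
    """Step 1: {time index: forecast}, keeping only the last point of each forecast."""
    out = {}
    # A forecast made at pred_time covers pred_time .. pred_time + horizon - 1.
    for pred_time in range(start, len(values) - forecast_horizon + 1, stride):
        forecast = naive_forecast(values[:pred_time], forecast_horizon)
        out[pred_time + forecast_horizon - 1] = forecast[-1]
    return out


def residuals(values, forecasts):
    """Steps 2 and 3: per-time-step error, indexed like the historical forecasts."""
    return {t: values[t] - predicted for t, predicted in forecasts.items()}


series_values = [1, 2, 4, 7, 11]
forecasts = historical_forecasts(series_values, start=2)
errors = residuals(series_values, forecasts)
```

The result only covers the time steps for which a historical forecast exists, which is why the residuals are shorter than the input series.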

  • series (Union[TimeSeries, Sequence[TimeSeries]]) – A (sequence of) target time series used to successively train (if retrain is not False) and compute the historical forecasts.

  • past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a (sequence of) past-observed covariate time series for every input time series in series. This applies only if the model supports past covariates.

  • future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a (sequence of) future-known covariate time series for every input time series in series. This applies only if the model supports future covariates.

  • historical_forecasts (Union[TimeSeries, Sequence[TimeSeries], Sequence[Sequence[TimeSeries]], None]) – Optionally, the (or a sequence of / a sequence of sequences of) historical forecasts time series to be evaluated. Corresponds to the output of historical_forecasts(). The same series and last_points_only values must be passed that were used to generate the historical forecasts. If provided, will skip historical forecasting and ignore all parameters except series, last_points_only, metric, and reduction.

  • forecast_horizon (int) – The forecast horizon for the predictions.

  • num_samples (int) – Number of times a prediction is sampled from a probabilistic model. Use values >1 only for probabilistic models.

  • train_length (Optional[int, None]) – Optionally, use a fixed length / number of time steps for every constructed training set (rolling window mode). Only effective when retrain is not False. The default is None, where it uses all time steps up until the prediction time (expanding window mode). If larger than the number of available time steps, uses the expanding mode. Needs to be at least min_train_series_length.

  • start (Union[Timestamp, float, int, None]) –

    Optionally, the first point in time at which a prediction is computed. This parameter supports: float, int, pandas.Timestamp, and None. If a float, it is the proportion of the time series that should lie before the first prediction point. If an int, it is either the index position of the first prediction point for series with a pd.DatetimeIndex, or the index value for series with a pd.RangeIndex. The latter can be changed to the index position with start_format=”position”. If a pandas.Timestamp, it is the time stamp of the first prediction point. If None, the first prediction point will automatically be set to:

    • the first predictable point if retrain is False, or retrain is a Callable and the first predictable point is earlier than the first trainable point.

    • the first trainable point if retrain is True or int (given train_length), or retrain is a Callable and the first trainable point is earlier than the first predictable point.

    • the first trainable point (given train_length) otherwise

    Note: If start is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of stride ahead of start. Raises a ValueError, if no valid start point exists.

    Note: If the model uses a shifted output (output_chunk_shift > 0), then the first predicted point is also shifted by output_chunk_shift points into the future.

    Note: If start is outside the possible historical forecasting times, will ignore the parameter (default behavior with None) and start at the first trainable/predictable point.

  • start_format (Literal[‘position’, ‘value’]) – Defines the start format. If set to 'position', start corresponds to the index position of the first predicted point and can range from (-len(series), len(series) - 1). If set to 'value', start corresponds to the index value/label of the first predicted point. Will raise an error if the value is not in series’ index. Default: 'value'.

  • stride (int) – The number of time steps between two consecutive predictions.

  • retrain (Union[bool, int, Callable[…, bool]]) –

    Whether and/or on which condition to retrain the model before predicting. This parameter supports 3 different types: bool, (positive) int, and Callable (returning a bool). In the case of bool: retrain the model at each step (True), or never retrain the model (False). In the case of int: the model is retrained every retrain iterations. In the case of Callable: the model is retrained whenever callable returns True. The callable must have the following positional arguments:

    • counter (int): current retrain iteration

    • pred_time (pd.Timestamp or int): timestamp of forecast time (end of the training series)

    • train_series (TimeSeries): train series up to pred_time

    • past_covariates (TimeSeries): past_covariates series up to pred_time

    • future_covariates (TimeSeries): future_covariates series up to min(pred_time + series.freq * forecast_horizon, series.end_time())

    Note: if any optional *_covariates are not passed to historical_forecast, None will be passed to the corresponding retrain function argument. Note: some models require being retrained every time and do not support anything other than retrain=True. Note: also controls the retraining of the data_transformers.

  • overlap_end (bool) – Whether the returned forecasts can go beyond the series’ end or not.

  • last_points_only (bool) – Whether to return only the last point of each historical forecast. If set to True, the method returns a single TimeSeries (for each time series in series) containing the successive point forecasts. Otherwise, returns a list of historical TimeSeries forecasts.

  • metric (Callable[…, Union[float, list[float], ndarray, list[ndarray]]]) –

    Either one of Darts’ “per time step” metrics (see here), or a custom metric that has an identical signature as Darts’ “per time step” metrics, uses decorators multi_ts_support() and multivariate_support(), and returns one value per time step.

  • verbose (bool) – Whether to print the progress.

  • show_warnings (bool) – Whether to show warnings related to historical forecasts optimization, or parameters start and train_length.

  • predict_likelihood_parameters (bool) – If set to True, the model predicts the parameters of its likelihood instead of the target. Only supported for probabilistic models with a likelihood, num_samples = 1 and n<=output_chunk_length. Default: False.

  • enable_optimization (bool) – Whether to use the optimized version of historical_forecasts when supported and available. Default: True.

  • data_transformers (Optional[dict[str, Union[BaseDataTransformer, Pipeline]], None]) –

    Optionally, a dictionary of BaseDataTransformer or Pipeline to apply to the corresponding series (possible keys: “series”, “past_covariates”, “future_covariates”). If provided, all input series must be in the un-transformed space. For fittable transformer / pipeline:

    • if retrain=True, the data transformer is re-fit on the training data at each historical forecast step.

    • if retrain=False, the data transformer transforms the series once before all the forecasts.

    The fitted transformer is used to transform the input during both training and prediction. If the transformation is invertible, the forecasts will be inverse-transformed. Only effective when historical_forecasts=None.

  • metric_kwargs (Optional[dict[str, Any], None]) – Additional arguments passed to metric(), such as ‘n_jobs’ for parallelization, ‘m’ for scaled metrics, etc. Will pass arguments only if they are present in the corresponding metric signature. Ignores reduction arguments “series_reduction”, “component_reduction”, “time_reduction”, and parameter ‘insample’ for scaled metrics (e.g. mase, rmsse, …), as they are handled internally.

  • fit_kwargs (Optional[dict[str, Any], None]) – Optionally, some additional arguments passed to the model fit() method.

  • predict_kwargs (Optional[dict[str, Any], None]) – Optionally, some additional arguments passed to the model predict() method.

  • sample_weight (Union[TimeSeries, Sequence[TimeSeries], str, None]) – Optionally, some sample weights to apply to the target series labels for training. Only effective when retrain is not False. They are applied per observation, per label (each step in output_chunk_length), and per component. If a series or sequence of series, then those weights are used. If the weight series only have a single component / column, then the weights are applied globally to all components in series. Otherwise, for component-specific weights, the number of components must match those of series. If a string, then the weights are generated using built-in weighting functions. The available options are “linear” or “exponential” decay - the further in the past, the lower the weight. The weights are computed per time series.

  • values_only (bool) – Whether to return the residuals as np.ndarray. If False, returns residuals as TimeSeries.
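The three accepted types of the retrain parameter (bool, int, Callable) can be illustrated with a small dispatcher. This is a sketch under assumptions, not Darts' code: should_retrain and retrain_on_long_history are hypothetical names, and the assumption that an int retrains on iterations 0, retrain, 2 * retrain, ... is only one plausible reading of "every retrain iterations".

```python
def should_retrain(retrain, counter, pred_time=None, train_series=None,
                   past_covariates=None, future_covariates=None):
    """Decide whether to retrain at this historical forecast iteration (sketch)."""
    # bool must be checked before int, since bool is a subclass of int.
    if isinstance(retrain, bool):
        return retrain
    if isinstance(retrain, int):
        if retrain <= 0:
            raise ValueError("an int retrain must be positive")
        return counter % retrain == 0  # assumed phase: retrain on iteration 0
    # Callable: receives the five documented positional arguments.
    return bool(
        retrain(counter, pred_time, train_series, past_covariates, future_covariates)
    )


def retrain_on_long_history(counter, pred_time, train_series,
                            past_covariates, future_covariates):
    """Example callable: only retrain once at least 4 training points exist."""
    return len(train_series) >= 4


every_second = [should_retrain(2, i) for i in range(5)]
with_callable = [
    should_retrain(retrain_on_long_history, i, pred_time=i, train_series=list(range(i + 2)))
    for i in range(4)
]
```

A callable like this is handy when retraining is expensive and should be tied to a data-dependent condition rather than a fixed schedule.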

Return type

Union[TimeSeries, list[TimeSeries], list[list[TimeSeries]]]


  • TimeSeries – Residual TimeSeries for a single series and historical_forecasts generated with last_points_only=True.

  • List[TimeSeries] – A list of residual TimeSeries for a sequence (list) of series with last_points_only=True. The residual list has length len(series).

  • List[List[TimeSeries]] – A list of lists of residual TimeSeries for a sequence of series with last_points_only=False. The outer residual list has length len(series). The inner lists consist of the residuals from all possible series-specific historical forecasts.

save(path=None, clean=False)

Saves the model under a given path.

Creates two files under path (model object) and path.ckpt (checkpoint).

Note: Pickle errors may occur when saving models with custom classes. In this case, consider using the clean flag to strip the saved model from training related attributes.

Example for saving and loading an RNNModel:

from darts.models import RNNModel

model = RNNModel(input_chunk_length=4)

model.save("my_model.pt")
model_loaded = RNNModel.load("my_model.pt")

  • path (Optional[str, None]) – Path under which to save the model at its current state. Please avoid paths starting with “last-” or “best-” to avoid collision with PyTorch-Lightning checkpoints. If no path is specified, the model is automatically saved under "{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pt". E.g., "RNNModel_2020-01-01_12_00_00.pt".

  • clean (bool) –

    Whether to store a cleaned version of the model. If True, the training series and covariates are removed. Additionally, removes all Lightning Trainer-related parameters (passed with pl_trainer_kwargs at model creation).

    Note: After loading a model stored with clean=True, a series must be passed to predict(), historical_forecasts() and other forecasting methods.

Return type


property supports_future_covariates: bool

Whether model supports future covariates

Return type


property supports_likelihood_parameter_prediction: bool

Whether model instance supports direct prediction of likelihood parameters

Return type


property supports_multivariate: bool

Whether the model considers more than one variate in the time series.

Return type


property supports_optimized_historical_forecasts: bool

Whether the model supports optimized historical forecasts

Return type


property supports_past_covariates: bool

Whether model supports past covariates

Return type


property supports_probabilistic_prediction: bool

Checks if the forecasting model with this configuration supports probabilistic predictions.

By default, returns False. Needs to be overwritten by models that do support probabilistic predictions.

Return type


property supports_sample_weight: bool

Whether model supports sample weight for training.

Return type


property supports_static_covariates: bool

Whether model supports static covariates

Return type


property supports_transferrable_series_prediction: bool

Whether the model supports prediction for any input series.

Return type



Updates the PyTorch Lightning Trainer parameters to move the model to CPU the next time fit() or predict() is called.

to_onnx(path=None, **kwargs)

Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning’s torch.onnx.export() method (official documentation).

Note: requires onnx library (optional dependency) to be installed.

Example for exporting a DLinearModel:

from darts.datasets import AirPassengersDataset
from darts.models import DLinearModel

series = AirPassengersDataset().load()
model = DLinearModel(input_chunk_length=4, output_chunk_length=1)
model.fit(series, epochs=1)
model.to_onnx("my_model.onnx")

  • path (Optional[str, None]) – Path under which to save the model at its current state. If no path is specified, the model is automatically saved under "{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.onnx".

  • **kwargs

    Additional kwargs for PyTorch’s torch.onnx.export() method (except parameters file_path, input_sample, input_name). For more information, read the official documentation.


Returns a new (untrained) model instance created with the same parameters.

property uses_future_covariates: bool

Whether the model uses future covariates, once fitted.

Return type


property uses_past_covariates: bool

Whether the model uses past covariates, once fitted.

Return type


property uses_static_covariates: bool

Whether the model uses static covariates, once fitted.

Return type


class darts.models.forecasting.tsmixer_model.TimeBatchNorm2d(*args, **kwargs)[source]

Bases: BatchNorm2d

A batch normalization layer that normalizes over the last two dimensions of a Tensor.
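As a rough illustration of normalizing over the last two dimensions (time and feature), the sketch below rescales each sample of a nested [batch][time][feature] list to zero mean and unit variance. It is a plain Python simplification, not the layer itself: the function name is hypothetical, and the batch-level statistics, running averages and learnable affine parameters of a real batch normalization layer are deliberately left out.

```python
import math


def normalize_last_two_dims(x, eps=1e-5):
    """Normalize every [time][feature] slab of x to zero mean, unit variance (sketch)."""
    out = []
    for sample in x:
        flat = [v for row in sample for v in row]
        mean = sum(flat) / len(flat)
        var = sum((v - mean) ** 2 for v in flat) / len(flat)
        scale = 1.0 / math.sqrt(var + eps)  # eps avoids division by zero
        out.append([[(v - mean) * scale for v in row] for row in sample])
    return out


# One sample with 2 time steps and 2 features.
normalized = normalize_last_two_dims([[[1.0, 2.0], [3.0, 4.0]]])
```

Because the statistics are shared across time steps and features, relative differences between features within a sample are preserved, unlike per-feature normalization.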


add_module(name, module)

Add a child module to the current module.


Apply fn recursively to every submodule (as returned by .children()) as well as self.


Casts all floating point parameters and buffers to bfloat16 datatype.


Return an iterator over module buffers.


Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().


Move all model parameters and buffers to the CPU.


Move all model parameters and buffers to the GPU.


Casts all floating point parameters and buffers to double datatype.


Set the module in evaluation mode.


Return the extra representation of the module.


Casts all floating point parameters and buffers to float datatype.


Define the computation performed at every call.


Return the buffer given by target if it exists, otherwise throw an error.


Return any extra state to include in the module's state_dict.


Return the parameter given by target if it exists, otherwise throw an error.


Return the submodule given by target if it exists, otherwise throw an error.


Casts all floating point parameters and buffers to half datatype.


Move all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict, assign])

Copy parameters and buffers from state_dict into this module and its descendants.


Return an iterator over all modules in the network.


Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.


Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.


Return an iterator over module parameters.


Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.


Register a post-hook to be run after module's load_state_dict() is called.


Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.


Register a post-hook for the state_dict() method.


Register a pre-hook for the state_dict() method.


Change if autograd should record operations on parameters in this module.


Set extra state contained in the loaded state_dict.

set_submodule(target, module)

Set the submodule given by target if it exists, otherwise throw an error.


See torch.Tensor.share_memory_().

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.


Set the module in training mode.


Casts all parameters and buffers to dst_type.


Move all model parameters and buffers to the XPU.


Reset gradients of all model parameters.





alias of TypeVar(‘T_destination’, bound=Dict[str, Any])

add_module(name, module)

Add a child module to the current module.

The module can be accessed as an attribute using the given name.

  • name (str) – name of the child module. The child module can be accessed from this module using the given name

  • module (Module) – child module to be added to the module.
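A minimal sketch of add_module() on a bare nn.Module; the child name "proj" and the layer sizes are arbitrary choices for illustration.

```python
from torch import nn

block = nn.Module()
block.add_module("proj", nn.Linear(4, 2))

# the child is now reachable as an attribute and listed among the children
child_names = [name for name, _ in block.named_children()]
same_object = block.proj is block.get_submodule("proj")
```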

Return type


affine: bool

Apply fn recursively to every submodule (as returned by .children()) as well as self.

Typical use includes initializing the parameters of a model (see also nn-init-doc).


fn (Module -> None) – function to be applied to each submodule



Return type



>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)

Casts all floating point parameters and buffers to bfloat16 datatype.


This method modifies the module in-place.



Return type



Return an iterator over module buffers.


recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.


torch.Tensor – module buffer


>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
Return type


call_super_init: bool = False

Return an iterator over immediate children modules.


Module – a child module

Return type


compile(*args, **kwargs)

Compile this Module’s forward using torch.compile().

This Module’s __call__ method is compiled and all arguments are passed as-is to torch.compile().

See torch.compile() for details on the arguments for this function.


Move all model parameters and buffers to the CPU.


This method modifies the module in-place.



Return type



Move all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on GPU while being optimized.


This method modifies the module in-place.


device (int, optional) – if specified, all parameters will be copied to that device



Return type



Casts all floating point parameters and buffers to double datatype.


This method modifies the module in-place.



Return type


dump_patches: bool = False
eps: float

Set the module in evaluation mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e. whether they are affected, e.g. Dropout, BatchNorm, etc.

This is equivalent to self.train(False).

See locally-disable-grad-doc for a comparison between .eval() and several similar mechanisms that may be confused with it.
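The effect of the two modes can be seen on a Dropout layer. This is only a small sketch; the dropout probability and input size are arbitrary.

```python
import torch
from torch import nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()  # training mode: roughly half of the entries are zeroed
zeros_in_train = int((drop(x) == 0).sum())

drop.eval()  # evaluation mode: dropout acts as the identity
unchanged_in_eval = torch.equal(drop(x), x)
```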



Return type



Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.


Casts all floating point parameters and buffers to float datatype.


This method modifies the module in-place.



Return type



Define the computation performed at every call.

Should be overridden by all subclasses.


Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type



Return the buffer given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.


target (str) – The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)


The buffer referenced by target

Return type



AttributeError – If the target string references an invalid path or resolves to something that is not a buffer


Return any extra state to include in the module’s state_dict.

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.


Any extra state to store in the module’s state_dict
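The following sketch shows one way the get_extra_state() / set_extra_state() pair could be used. The Scaler class and its unit attribute are hypothetical and exist only for this example.

```python
import torch
from torch import nn


class Scaler(nn.Module):
    """Hypothetical module that keeps a plain-Python setting as extra state."""

    def __init__(self, unit: str = "m"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1))
        self.unit = unit

    def get_extra_state(self):
        # must be picklable; a small dict is enough here
        return {"unit": self.unit}

    def set_extra_state(self, state):
        self.unit = state["unit"]


src = Scaler(unit="km")
dst = Scaler()
dst.load_state_dict(src.state_dict())  # also restores the extra state
```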

Return type



Return the parameter given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.


target (str) – The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)


The Parameter referenced by target

Return type



AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Parameter


Return the submodule given by target if it exists, otherwise throw an error.

For example, let’s say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.


target (str) – The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)


The submodule referenced by target

Return type



AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Module
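The nested layout described above can be reproduced in a few lines. The class names NetC, NetB and A are illustrative stand-ins for the modules in the diagram.

```python
from torch import nn


class NetC(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(16, 33, kernel_size=3, stride=2)


class NetB(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_c = NetC()
        self.linear = nn.Linear(100, 200)


class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.net_b = NetB()


a = A()
conv = a.get_submodule("net_b.net_c.conv")

# an invalid path raises AttributeError
try:
    a.get_submodule("net_b.missing")
    found_missing = True
except AttributeError:
    found_missing = False
```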


Casts all floating point parameters and buffers to half datatype.


This method modifies the module in-place.



Return type



Move all model parameters and buffers to the IPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on IPU while being optimized.


This method modifies the module in-place.


device (int, optional) – if specified, all parameters will be copied to that device



Return type


load_state_dict(state_dict, strict=True, assign=False)

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.


If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.

  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved, whereas setting it to True preserves the properties of the tensors in the state dict. The only exception is the requires_grad field of Parameter, for which the value from the module is preserved. Default: False


  • missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.

  • unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Return type

NamedTuple with missing_keys and unexpected_keys fields


If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
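A short sketch of how strict=False reports mismatches instead of raising. The "extra" key is made up for the example.

```python
import torch
from torch import nn

src = nn.Linear(2, 2, bias=False)  # state_dict has only "weight"
dst = nn.Linear(2, 2, bias=True)   # expects "weight" and "bias"

state = src.state_dict()
state["extra"] = torch.zeros(1)    # a key that dst does not know about

result = dst.load_state_dict(state, strict=False)
```

Here result.missing_keys lists "bias" and result.unexpected_keys lists "extra"; with strict=True the same call would raise a RuntimeError.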


Return an iterator over all modules in the network.


Module – a module in the network


Duplicate modules are returned only once. In the following example, l will be returned only once.


>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
Return type


momentum: Optional[float]

Move all model parameters and buffers to the MTIA.

This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on MTIA while being optimized.


This method modifies the module in-place.


device (int, optional) – if specified, all parameters will be copied to that device



Return type


named_buffers(prefix='', recurse=True, remove_duplicate=True)

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

  • prefix (str) – prefix to prepend to all buffer names.

  • recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

  • remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.


(str, torch.Tensor) – Tuple containing the name and buffer


>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
Return type

Iterator[Tuple[str, Tensor]]


Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.


(str, Module) – Tuple containing a name and child module


>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)
Return type

Iterator[Tuple[str, Module]]

named_modules(memo=None, prefix='', remove_duplicate=True)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

  • memo (Optional[Set[Module], None]) – a memo to store the set of modules already added to the result

  • prefix (str) – a prefix that will be added to the name of the module

  • remove_duplicate (bool) – whether to remove the duplicated module instances in the result or not


(str, Module) – Tuple of name and module


Duplicate modules are returned only once. In the following example, l will be returned only once.


>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))

named_parameters(prefix='', recurse=True, remove_duplicate=True)

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

  • prefix (str) – prefix to prepend to all parameter names.

  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

  • remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.


(str, Parameter) – Tuple containing the name and parameter


>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
Return type

Iterator[Tuple[str, Parameter]]

num_features: int

Return an iterator over module parameters.

This is typically passed to an optimizer.


recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.


Parameter – module parameter


>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
Return type



Register a backward hook on the module.

This function is deprecated in favor of register_full_backward_hook() and the behavior of this function will change in future versions.


a handle that can be used to remove the added hook by calling handle.remove()

Return type


register_buffer(name, tensor, persistent=True)

Add a buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict.

Buffers can be accessed as attributes using given names.

  • name (str) – name of the buffer. The buffer can be accessed from this module using the given name

  • tensor (Tensor or None) – buffer to be registered. If None, then operations that run on buffers, such as cuda, are ignored. If None, the buffer is not included in the module’s state_dict.

  • persistent (bool) – whether the buffer is part of this module’s state_dict.
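The difference between persistent and non-persistent buffers can be sketched as follows. The Counter class and the buffer names are hypothetical.

```python
import torch
from torch import nn


class Counter(nn.Module):
    """Hypothetical module with one persistent and one non-persistent buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("steps", torch.zeros(1))
        self.register_buffer("scratch", torch.zeros(3), persistent=False)


m = Counter()
saved_keys = list(m.state_dict().keys())                 # only the persistent buffer
buffer_names = [name for name, _ in m.named_buffers()]   # both buffers
```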


>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))
Return type


register_forward_hook(hook, *, prepend=False, with_kwargs=False, always_call=False)

Register a forward hook on the module.

The hook will be called every time after forward() has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after forward() is called. The hook should have the following signature:

hook(module, args, output) -> None or modified output

If with_kwargs is True, the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:

hook(module, args, kwargs, output) -> None or modified output
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided hook will be fired before all existing forward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward hooks on this torch.nn.modules.Module. Note that global forward hooks registered with register_module_forward_hook() will fire before all hooks registered by this method. Default: False

  • with_kwargs (bool) – If True, the hook will be passed the kwargs given to the forward function. Default: False

  • always_call (bool) – If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False


a handle that can be used to remove the added hook by calling handle.remove()
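A minimal sketch of a forward hook that records output shapes and is then removed through its handle. The hook name record_shape is arbitrary.

```python
import torch
from torch import nn

shapes = []


def record_shape(module, args, output):
    # returning None leaves the output unchanged
    shapes.append(tuple(output.shape))


layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(record_shape)

layer(torch.randn(3, 4))  # the hook fires
handle.remove()
layer(torch.randn(5, 4))  # the hook no longer fires
```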

Return type


register_forward_pre_hook(hook, *, prepend=False, with_kwargs=False)

Register a forward pre-hook on the module.

The hook will be called every time before forward() is invoked.

If with_kwargs is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature:

hook(module, args) -> None or modified input

If with_kwargs is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If true, the provided hook will be fired before all existing forward_pre hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward_pre hooks on this torch.nn.modules.Module. Note that global forward_pre hooks registered with register_module_forward_pre_hook() will fire before all hooks registered by this method. Default: False

  • with_kwargs (bool) – If true, the hook will be passed the kwargs given to the forward function. Default: False


a handle that can be used to remove the added hook by calling handle.remove()
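A small sketch of a forward pre-hook that modifies the positional input before forward() runs. nn.Identity is used so that the effect of the hook is directly visible; the hook name double_input is arbitrary.

```python
import torch
from torch import nn


def double_input(module, args):
    # args is the tuple of positional inputs; return a modified tuple
    return (args[0] * 2,)


layer = nn.Identity()
handle = layer.register_forward_pre_hook(double_input)

with_hook = layer(torch.ones(2))
handle.remove()
without_hook = layer(torch.ones(2))
```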

Return type


register_full_backward_hook(hook, prepend=False)

Register a backward hook on the module.

The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.


Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.

  • hook (Callable) – The user-defined hook to be registered.

  • prepend (bool) – If true, the provided hook will be fired before all existing backward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward hooks on this torch.nn.modules.Module. Note that global backward hooks registered with register_module_full_backward_hook() will fire before all hooks registered by this method.


a handle that can be used to remove the added hook by calling handle.remove()
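The hook signature above can be exercised with a sketch like the following, which only inspects the gradients and returns None. The hook name record_grads and the tensor shapes are assumptions for illustration.

```python
import torch
from torch import nn

grad_shapes = []


def record_grads(module, grad_input, grad_output):
    # both arguments are tuples; returning None keeps grad_input as computed
    grad_shapes.append((tuple(grad_input[0].shape), tuple(grad_output[0].shape)))


layer = nn.Linear(4, 2)
handle = layer.register_full_backward_hook(record_grads)

x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()
handle.remove()
```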

Return type


register_full_backward_pre_hook(hook, prepend=False)

Register a backward pre-hook on the module.

The hook will be called every time the gradients for the module are computed. The hook should have the following signature:

hook(module, grad_output) -> tuple[Tensor] or None

The grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of grad_output in subsequent computations. Entries in grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.


Modifying inputs inplace is not allowed when using backward hooks and will raise an error.

  • hook (Callable) – The user-defined hook to be registered.

  • prepend (bool) – If true, the provided hook will be fired before all existing backward_pre hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward_pre hooks on this torch.nn.modules.Module. Note that global backward_pre hooks registered with register_module_full_backward_pre_hook() will fire before all hooks registered by this method.


a handle that can be used to remove the added hook by calling handle.remove()

Return type



Register a post-hook to be run after module’s load_state_dict() is called.

It should have the following signature:

hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

The given incompatible_keys can be modified inplace if needed.

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.


a handle that can be used to remove the added hook by calling handle.remove()
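As a sketch of modifying incompatible_keys in place, the hook below removes a known-missing key so that a strict load succeeds. It assumes the method is nn.Module.register_load_state_dict_post_hook(), as in PyTorch's nn.Module API; the hook name ignore_missing_bias is hypothetical.

```python
from torch import nn


def ignore_missing_bias(module, incompatible_keys):
    # edit the list in place so the strict check no longer sees "bias"
    if "bias" in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove("bias")


src = nn.Linear(2, 2, bias=False)
dst = nn.Linear(2, 2, bias=True)
handle = dst.register_load_state_dict_post_hook(ignore_missing_bias)

# without the hook this strict load would raise because "bias" is missing
result = dst.load_state_dict(src.state_dict(), strict=True)
handle.remove()
```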

Return type



Register a pre-hook to be run before module’s load_state_dict() is called.

It should have the following signature:

hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None


hook (Callable) – Callable hook that will be invoked before loading the state dict.

register_module(name, module)

Alias for add_module().

Return type


register_parameter(name, param)

Add a parameter to the module.

The parameter can be accessed as an attribute using given name.

  • name (str) – name of the parameter. The parameter can be accessed from this module using the given name

  • param (Parameter or None) – parameter to be added to the module. If None, then operations that run on parameters, such as cuda, are ignored. If None, the parameter is not included in the module’s state_dict.

Return type



Register a post-hook for the state_dict() method.

It should have the following signature:

hook(module, state_dict, prefix, local_metadata) -> None

The registered hooks can modify the state_dict inplace.


Register a pre-hook for the state_dict() method.

It should have the following signature:

hook(module, prefix, keep_vars) -> None

The registered hooks can be used to perform pre-processing before the state_dict call is made.


Change if autograd should record operations on parameters in this module.

This method sets the parameters’ requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

See locally-disable-grad-doc for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.


requires_grad (bool) – whether autograd should record operations on parameters in this module. Default: True.
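Freezing part of a model for fine-tuning might look like the sketch below. The network layout is arbitrary.

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# freeze the first layer; only the last layer stays trainable
model[0].requires_grad_(False)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
```

Only the parameters of the last layer ("2.weight" and "2.bias") remain in trainable.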



Return type



Set extra state contained in the loaded state_dict.

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.


state (dict) – Extra state from the state_dict

Return type


set_submodule(target, module)

Set the submodule given by target if it exists, otherwise throw an error.

For example, let’s say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To override the Conv2d with a new submodule Linear, you would call set_submodule("net_b.net_c.conv", nn.Linear(33, 16)).

  • target (str) – The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)

  • module (Module) – The module to set the submodule to.

  • ValueError – If the target string is empty

  • AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Module

Return type



See torch.Tensor.share_memory_().

Return type


state_dict(*args, destination=None, prefix='', keep_vars=False)

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.


The returned object is a shallow copy. It contains references to the module’s parameters and buffers.


Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.


Please avoid the use of argument destination as it is not designed for end-users.

  • destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  • keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.


a dictionary containing a whole state of the module

Return type



>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
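
The effect of keep_vars and the shallow-copy behaviour can be checked with a short sketch. The layer size is arbitrary.

```python
import torch
from torch import nn

layer = nn.Linear(2, 2)

detached = layer.state_dict()                # default: tensors detached from autograd
attached = layer.state_dict(keep_vars=True)  # the Parameter objects themselves

# the detached tensor still refers to the same underlying storage
shares_storage = detached["weight"].data_ptr() == layer.weight.data_ptr()
```
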
to(*args, **kwargs)

Move and/or cast the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.


This method modifies the module in-place.

  • device (torch.device) – the desired device of the parameters and buffers in this module

  • dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

  • memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)



Return type



>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
to_empty(*, device, recurse=True)

Move the parameters and buffers to the specified device without copying storage.

  • device (torch.device) – The desired device of the parameters and buffers in this module.

  • recurse (bool) – Whether parameters and buffers of submodules should be recursively moved to the specified device.
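One common use, sketched here under the assumption that the module was created on the "meta" device, is to allocate real but uninitialized storage and then fill in the values explicitly.

```python
import torch
from torch import nn

# parameters on the "meta" device carry shape and dtype but no data
layer = nn.Linear(3, 2, device="meta")

# allocate uninitialized storage on the CPU; values must be set afterwards
layer = layer.to_empty(device="cpu")
nn.init.zeros_(layer.weight)
nn.init.zeros_(layer.bias)
```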



Return type


track_running_stats: bool

Set the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g. Dropout, BatchNorm, etc.


mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.



Return type


training: bool

Casts all parameters and buffers to dst_type.


This method modifies the module in-place.


dst_type (type or string) – the desired type



Return type



Move all model parameters and buffers to the XPU.

This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on XPU while being optimized.


This method modifies the module in-place.


device (int, optional) – if specified, all parameters will be copied to that device



Return type



Reset gradients of all model parameters.

See similar function under torch.optim.Optimizer for more context.


set_to_none (bool) – instead of setting to zero, set the grads to None. See torch.optim.Optimizer.zero_grad() for details.
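The two behaviours of set_to_none can be compared with a small sketch. The layer and input are arbitrary.

```python
import torch
from torch import nn

layer = nn.Linear(2, 1)
layer(torch.ones(1, 2)).sum().backward()
had_grad = layer.weight.grad is not None

layer.zero_grad(set_to_none=False)  # keep the gradient tensors, fill them with zeros
grad_is_zero = bool((layer.weight.grad == 0).all())

layer.zero_grad(set_to_none=True)   # drop the gradient tensors entirely
grad_is_none = layer.weight.grad is None
```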

Return type
