Source code for darts.models.forecasting.exponential_smoothing

"""
Exponential Smoothing
---------------------
"""

from typing import Any, Dict, Optional

import numpy as np
import statsmodels.tsa.holtwinters as hw

from darts.logging import get_logger
from darts.models.forecasting.forecasting_model import LocalForecastingModel
from darts.timeseries import TimeSeries
from darts.utils.utils import ModelMode, SeasonalityMode

logger = get_logger(__name__)


[docs]class ExponentialSmoothing(LocalForecastingModel): def __init__( self, trend: Optional[ModelMode] = ModelMode.ADDITIVE, damped: Optional[bool] = False, seasonal: Optional[SeasonalityMode] = SeasonalityMode.ADDITIVE, seasonal_periods: Optional[int] = None, random_state: int = 0, kwargs: Optional[Dict[str, Any]] = None, **fit_kwargs, ): """Exponential Smoothing This is a wrapper around `statsmodels Holt-Winters' Exponential Smoothing <https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.html>`_; we refer to this link for the original and more complete documentation of the parameters. `trend` must be a ``ModelMode`` Enum member. You can access the Enum with ``from darts.utils.utils import ModelMode``. `seasonal` must be a ``SeasonalityMode`` Enum member. You can access the Enum with ``from darts.utils.utils import SeasonalityMode``. ``ExponentialSmoothing(trend=ModelMode.NONE, seasonal=SeasonalityMode.NONE)`` corresponds to a single exponential smoothing. ``ExponentialSmoothing(trend=ModelMode.ADDITIVE, seasonal=SeasonalityMode.NONE)`` corresponds to a Holt's exponential smoothing. Please note that automatic `seasonal_period` selection (setting the `seasonal_periods` parameter equal to `None`) can sometimes lead to errors if the input time series is too short. In these cases we suggest to manually set the `seasonal_periods` parameter to a positive integer. Parameters ---------- trend Type of trend component. Either ``ModelMode.ADDITIVE``, ``ModelMode.MULTIPLICATIVE``, ``ModelMode.NONE``, or ``None``. Defaults to ``ModelMode.ADDITIVE``. damped Should the trend component be damped. Defaults to False. seasonal Type of seasonal component. Either ``SeasonalityMode.ADDITIVE``, ``SeasonalityMode.MULTIPLICATIVE``, ``SeasonalityMode.NONE``, or ``None``. Defaults to ``SeasonalityMode.ADDITIVE``. seasonal_periods The number of periods in a complete seasonal cycle, e.g., 4 for quarterly data or 7 for daily data with a weekly cycle. If not set, inferred from frequency of the series. kwargs Some optional keyword arguments that will be used to call :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing()`. See `the documentation <https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.html>`_. fit_kwargs Some optional keyword arguments that will be used to call :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`. See `the documentation <https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.ExponentialSmoothing.fit.html>`_. Examples -------- >>> from darts.datasets import AirPassengersDataset >>> from darts.models import ExponentialSmoothing >>> from darts.utils.utils import ModelMode, SeasonalityMode >>> series = AirPassengersDataset().load() >>> # using Holt's exponential smoothing >>> model = ExponentialSmoothing(trend=ModelMode.ADDITIVE, seasonal=SeasonalityMode.MULTIPLICATIVE) >>> model.fit(series) >>> pred = model.predict(6) >>> pred.values() array([[445.24283838], [418.22618932], [465.31305075], [494.95129261], [505.4770514 ], [573.31519186]]) """ super().__init__() self.trend = trend self.damped = damped self.seasonal = seasonal self.infer_seasonal_periods = seasonal_periods is None self.seasonal_periods = seasonal_periods self.constructor_kwargs = dict() if kwargs is None else kwargs self.fit_kwargs = fit_kwargs self.model = None np.random.seed(random_state)
[docs] def fit(self, series: TimeSeries): super().fit(series) self._assert_univariate(series) series = self.training_series # if the model was initially created with `self.seasonal_periods=None`, make sure that # the model will try to automatically infer the index, otherwise it should use the # provided `seasonal_periods` value seasonal_periods_param = ( None if self.infer_seasonal_periods else self.seasonal_periods ) # set the seasonal periods parameter to a default value if it was not provided explicitly # and if it cannot be inferred due to the lack of a datetime index if self.seasonal_periods is None and series.has_range_index: seasonal_periods_param = 12 hw_model = hw.ExponentialSmoothing( series.values(copy=False), trend=self.trend if self.trend is None else self.trend.value, damped_trend=self.damped, seasonal=self.seasonal if self.seasonal is None else self.seasonal.value, seasonal_periods=seasonal_periods_param, freq=series.freq if series.has_datetime_index else None, dates=series.time_index if series.has_datetime_index else None, **self.constructor_kwargs, ) hw_results = hw_model.fit(**self.fit_kwargs) self.model = hw_results if self.infer_seasonal_periods: self.seasonal_periods = hw_model.seasonal_periods return self
[docs] def predict( self, n: int, num_samples: int = 1, verbose: bool = False, show_warnings: bool = True, ): super().predict(n, num_samples) if num_samples == 1: forecast = self.model.forecast(n) else: forecast = np.expand_dims( self.model.simulate(n, repetitions=num_samples), axis=1 ) return self._build_forecast_series(forecast)
@property def supports_multivariate(self) -> bool: return False @property def supports_probabilistic_prediction(self) -> bool: return True @property def min_train_series_length(self) -> int: if self.seasonal_periods is not None and self.seasonal_periods > 1: return 2 * self.seasonal_periods return 3