"""
Fast Fourier Transform
----------------------
"""
from typing import Callable, Optional
import numpy as np
import pandas as pd
from statsmodels.tsa.stattools import acf
from darts.logging import get_logger
from darts.models.forecasting.forecasting_model import LocalForecastingModel
from darts.timeseries import TimeSeries
from darts.utils.missing_values import fill_missing_values
logger = get_logger(__name__)
def _check_approximate_seasonality(
    series: TimeSeries,
    seasonality_period: int,
    period_error_margin: int,
    max_seasonality_order: int,
) -> bool:
    """Tests a series for an (approximate) seasonality of a given period.

    The autocorrelation function (ACF) of `series` is evaluated and the AC values
    lying within `period_error_margin` lags of `seasonality_period` are averaged,
    which makes the test tolerant to noise in the ACF. The seasonality is accepted
    if at most `max_seasonality_order` of the other inspected lags have an AC value
    strictly larger than that average.

    Parameters
    ----------
    series
        The TimeSeries instance to be analyzed.
    seasonality_period
        The (approximate) period to be checked for seasonality.
    period_error_margin
        The radius, in lags, of the window around `seasonality_period` whose AC values are averaged.
    max_seasonality_order
        The maximum number of lags outside of that window whose AC value may exceed the window average.
        The lower this number, the stricter the criterion for seasonality.

    Returns
    -------
    bool
        True if the seasonality is considered significant for the given parameters, False otherwise
        (including when the series is too short to be analyzed).
    """
    # Share of `seasonality_period` that is ignored at the low-lag end of the ACF, because
    # autocorrelation is naturally high for small lags.
    skipped_fraction = 1 / 4
    highest_lag = seasonality_period * (1 + skipped_fraction)

    # not enough observations to evaluate the ACF up to the highest lag of interest
    if len(series) < highest_lag:
        return False

    autocorr = acf(
        series.univariate_values(),
        nlags=int(highest_lag),
        fft=False,
    )

    # average AC value over the window centered on the candidate period
    window_first = seasonality_period - period_error_margin
    window_last = seasonality_period + period_error_margin
    window_mean = np.mean(autocorr[range(window_first, window_last + 1)])

    # all inspected lags outside of the window: below it (minus the skipped small lags) and above it
    outside_lags = [
        *range(int(skipped_fraction * seasonality_period), window_first),
        *range(window_last + 1, len(autocorr)),
    ]

    # number of outside lags that beat the window average
    exceeding = sum(1 for value in autocorr[outside_lags] if value > window_mean)
    return exceeding <= max_seasonality_order
def _find_relevant_timestamp_attributes(series: TimeSeries) -> set:
    """Infers which pd.Timestamp attributes describe the seasonality of a series.

    Depending on the frequency of `series`, a list of candidate seasonal periods is
    tested (longest period first) with `_check_approximate_seasonality`. The first
    candidate that passes determines the returned attributes; later candidates are
    not evaluated.

    Parameters
    ----------
    series
        The TimeSeries instance to be analyzed.

    Returns
    -------
    set
        The pd.Timestamp attribute names associated with the first seasonality detected in `series`.
        Empty if no candidate passes or if the frequency of `series` is not one of the handled ones.
    """
    offsets = pd.tseries.offsets

    # Candidates per frequency, in evaluation order:
    # (seasonality_period, period_error_margin, max_seasonality_order, attributes)
    month_candidates = (
        (12, 1, 0, ("month",)),  # yearly seasonality
    )
    candidates_by_freq = {
        offsets.MonthBegin: month_candidates,
        offsets.MonthEnd: month_candidates,
        offsets.Day: (
            (365, 5, 20, ("month", "day")),  # yearly seasonality
            (30, 2, 2, ("day",)),  # monthly seasonality
            (7, 0, 0, ("weekday",)),  # weekly seasonality
        ),
        offsets.Hour: (
            (8760, 100, 100, ("month", "day", "hour")),  # yearly seasonality
            (730, 10, 30, ("day", "hour")),  # monthly seasonality
            (168, 3, 10, ("weekday", "hour")),  # weekly seasonality
            (24, 1, 1, ("hour",)),  # daily seasonality
        ),
        offsets.Minute: (
            (1440, 20, 50, ("hour", "minute")),  # daily seasonality
            (60, 4, 3, ("minute",)),  # hourly seasonality
        ),
    }

    # the lookup is on the exact offset type, subclasses are deliberately not matched
    candidates = candidates_by_freq.get(type(series.freq), ())
    for period, margin, max_order, attributes in candidates:
        if _check_approximate_seasonality(series, period, margin, max_order):
            return set(attributes)
    return set()
def _compare_timestamps_on_attributes(
ts_1: pd.Timestamp, ts_2: pd.Timestamp, required_matches: set
) -> bool:
"""Compares pd.Timestamp instances on attributes.
Compares two timestamps according two a given set of attributes (such as minute, hour, day, etc.).
It returns True if and only if the two timestamps are matching in all given attributes.
Parameters
----------
ts_1
First timestamp that will be compared.
ts_2
Second timestamp that will be compared.
required_matches
A set of pd.Timestamp attributes which ts_1 and ts_2 will be checked on.
Returns
-------
bool
True if and only if `ts_1` and `ts_2` match in all attributes given in `required_matches`.
"""
return all(
map(lambda attr: getattr(ts_1, attr) == getattr(ts_2, attr), required_matches)
)
def _crop_to_match_seasons(
    series: TimeSeries, required_matches: Optional[set]
) -> TimeSeries:
    """Crops a training series so that it lines up with the first forecast point.

    Starting one step after the first timestamp of `series`, timestamps are visited in
    chronological order until one is found that agrees with the first timestamp following
    the end of `series` on every attribute in `required_matches`. The series is then
    cropped at that timestamp with `TimeSeries.drop_before`. The search stops a few steps
    before the end of the series.

    If no such timestamp can be found, a warning is logged and the original TimeSeries
    instance is returned. If `required_matches` is `None` or empty, the original TimeSeries
    instance is returned right away.

    Parameters
    ----------
    series
        TimeSeries instance to be cropped.
    required_matches
        A set of pd.Timestamp attributes which will be used to choose the cropping point.

    Returns
    -------
    TimeSeries
        New TimeSeries instance that is cropped as described above, or `series` itself when no cropping happens.
    """
    # nothing to align on
    if not required_matches:
        return series

    step = series.freq
    # timestamp of the first point that will be predicted
    pred_ts = series.time_index[-1] + step
    # the search advances only while the current timestamp lies before this bound
    search_bound = pred_ts - 4 * step

    candidate = series.time_index[0]
    while candidate < search_bound:
        candidate = candidate + step
        if _compare_timestamps_on_attributes(pred_ts, candidate, required_matches):
            # NOTE(review): whether `candidate` itself is kept depends on `drop_before` -- confirm
            return series.drop_before(candidate)

    logger.warning(
        "No matching timestamp could be found, returning original TimeSeries."
    )
    return series
class FFT(LocalForecastingModel):
    def __init__(
        self,
        nr_freqs_to_keep: Optional[int] = 10,
        required_matches: Optional[set] = None,
        trend: Optional[str] = None,
        trend_poly_degree: int = 3,
    ):
        """Fast Fourier Transform Model

        This model performs forecasting on a TimeSeries instance using FFT, subsequent frequency filtering
        (controlled by the `nr_freqs_to_keep` argument) and inverse FFT, combined with the option to detrend
        the data (controlled by the `trend` argument) and to crop the training sequence to full seasonal periods.

        Note that if the training series contains any NaNs (missing values), these will be filled using
        :func:`darts.utils.missing_values.fill_missing_values()`.

        Parameters
        ----------
        nr_freqs_to_keep
            The total number of frequencies that will be used for forecasting. If `None`, smaller than 1 or
            larger than the number of available frequencies, all frequencies are kept.
        required_matches
            The attributes of pd.Timestamp that will be used to create a training sequence that is cropped at the
            beginning such that the first timestamp of the training sequence and the first prediction point have
            matching phases. If the series has a yearly seasonality, include `month`, if it has a monthly
            seasonality, include `day`, etc. If not set, or explicitly set to None, the model tries to find the
            pd.Timestamp attributes that are relevant for the seasonality automatically.
        trend
            If set, indicates what kind of detrending will be applied before performing DFT.
            Possible values: 'poly', 'exp' or None, for polynomial trend, exponential trend or no trend, respectively.
            Any other value is treated like None.
        trend_poly_degree
            The degree of the polynomial that will be used for detrending, if `trend='poly'`.

        Examples
        --------
        Automatically detect the seasonal periods, use the 10 most significant frequencies for
        forecasting and expect no global trend to be present in the data:

        >>> FFT(nr_freqs_to_keep=10)

        Assume the provided TimeSeries instances will have a monthly seasonality and an exponential
        global trend, and do not perform any frequency filtering:

        >>> FFT(required_matches={'month'}, trend='exp')

        Simple usage example, using one of the datasets available in darts

        >>> from darts.datasets import AirPassengersDataset
        >>> from darts.models import FFT
        >>> series = AirPassengersDataset().load()
        >>> # increase the number of frequency and use a polynomial trend of degree 2
        >>> model = FFT(
        >>>     nr_freqs_to_keep=20,
        >>>     trend= "poly",
        >>>     trend_poly_degree=2
        >>> )
        >>> model.fit(series)
        >>> pred = model.predict(6)
        >>> pred.values()
        array([[471.79323146],
               [494.6381425 ],
               [504.5659999 ],
               [515.82463265],
               [520.59404623],
               [547.26720705]])

        .. note::
            `FFT example notebook <https://unit8co.github.io/darts/examples/03-FFT-examples.html>`_ presents techniques
            that can be used to improve the forecasts quality compared to this simple usage example.
        """
        super().__init__()
        self.nr_freqs_to_keep = nr_freqs_to_keep
        self.required_matches = required_matches
        self.trend = trend
        self.trend_poly_degree = trend_poly_degree

    @property
    def supports_multivariate(self) -> bool:
        return False

    def _exp_trend(self, x):
        """Evaluates the fitted exponential trend at `x` (a bound method, to keep the model picklable)."""
        # the coefficients come from a degree-1 fit on log(values): log(y) ~ c[0] * x + c[1]
        return np.exp(self.trend_coefficients[1]) * np.exp(
            self.trend_coefficients[0] * x
        )

    def _poly_trend(self, trend_coefficients) -> Callable:
        """Builds the polynomial trend function from its coefficients, for consistency with the other trends."""
        return np.poly1d(trend_coefficients)

    def _null_trend(self, x):
        """Evaluates the absent trend, i.e. always 0 (a bound method, to keep the model picklable)."""
        return 0

    def fit(self, series: TimeSeries):
        """Fits the model: detrends the series, crops it to seasonal periods and filters its spectrum.

        Parameters
        ----------
        series
            The (univariate) training series. Missing values are filled before fitting.

        Returns
        -------
        FFT
            The fitted model (`self`).
        """
        series = fill_missing_values(series)
        super().fit(series)
        self._assert_univariate(series)
        series = self.training_series

        time_steps = np.arange(len(series))

        # determine trend
        if self.trend == "poly":
            self.trend_coefficients = np.polyfit(
                time_steps, series.univariate_values(), self.trend_poly_degree
            )
            self.trend_function = self._poly_trend(self.trend_coefficients)
        elif self.trend == "exp":
            # linear fit in log-space; np.log is only finite for strictly positive values
            self.trend_coefficients = np.polyfit(
                time_steps, np.log(series.univariate_values()), 1
            )
            self.trend_function = self._exp_trend
        else:
            self.trend_coefficients = None
            self.trend_function = self._null_trend

        # subtract trend
        detrended_values = series.univariate_values() - self.trend_function(
            time_steps
        )
        detrended_series = TimeSeries.from_times_and_values(
            series.time_index, detrended_values
        )

        # crop training set to match the seasonality of the first prediction point
        if self.required_matches is None:
            curr_required_matches = _find_relevant_timestamp_attributes(
                detrended_series
            )
        else:
            curr_required_matches = self.required_matches
        cropped_series = _crop_to_match_seasons(
            detrended_series, required_matches=curr_required_matches
        )

        # perform dft
        self.fft_values = np.fft.fft(cropped_series.univariate_values())

        # get indices of `nr_freqs_to_keep` (if a correct value was provided) frequencies with the highest amplitudes
        # by partitioning around the element with sorted index -nr_freqs_to_keep instead of sorting the whole array
        first_n = self.nr_freqs_to_keep
        if first_n is None or first_n < 1 or first_n > len(self.fft_values):
            first_n = len(self.fft_values)
        self.filtered_indices = np.argpartition(abs(self.fft_values), -first_n)[
            -first_n:
        ]

        # set all other values in the frequency domain to 0
        # (np.complex128 instead of the `np.complex_` alias, which NumPy 2.0 removed)
        self.fft_values_filtered = np.zeros(len(self.fft_values), dtype=np.complex128)
        self.fft_values_filtered[self.filtered_indices] = self.fft_values[
            self.filtered_indices
        ]

        # precompute all possible predicted values using inverse dft
        self.predicted_values = np.fft.ifft(self.fft_values_filtered).real

        return self

    def predict(
        self,
        n: int,
        num_samples: int = 1,
        verbose: bool = False,
        show_warnings: bool = True,
    ):
        """Forecasts `n` steps after the end of the training series.

        The forecast is the sum of the extrapolated trend and of the periodic signal obtained
        during `fit()`, repeated cyclically.

        Parameters
        ----------
        n
            The number of time steps to forecast.
        num_samples
            Forwarded to the parent class `predict()`.
        verbose
            Not used by this model.
        show_warnings
            Not used by this model.

        Returns
        -------
        The forecast, as built by `_build_forecast_series()`.
        """
        super().predict(n, num_samples)
        steps = np.arange(n)
        # the trend is extrapolated past the last training index (a scalar 0 when there is no trend)
        trend_forecast = self.trend_function(steps + len(self.training_series))
        # the filtered signal is periodic, hence the wrap-around indexing
        periodic_forecast = self.predicted_values[steps % len(self.predicted_values)]
        return self._build_forecast_series(periodic_forecast + trend_forecast)