TFT Explainer for Temporal Fusion Transformer (TFTModel)

The TFTExplainer uses a trained TFTModel and extracts explainability information from the model.

  • plot_variable_selection() plots the variable selection weights for each of the input features:
      - encoder importance: the historic part of the target, past covariates, and the historic part of future covariates
      - decoder importance: the future part of future covariates
      - static covariates importance: the numeric and categorical static covariates

  • plot_attention() plots the transformer attention that the TFTModel applies on the given past and future input. The attention is aggregated over all attention heads.

The attention and feature importance values can be extracted using the TFTExplainabilityResult returned by explain(). An example of this is shown in the method description.

We also show how to use the TFTExplainer in the TFTModel example notebook.

class darts.explainability.tft_explainer.TFTExplainer(model, background_series=None, background_past_covariates=None, background_future_covariates=None)[source]

Bases: _ForecastingModelExplainer

Explainer class for the TFTModel.

Definitions

  • A background series is a TimeSeries that is used as a default for generating the explainability result (if no foreground is passed to explain()).

  • A foreground series is a TimeSeries that can be passed to explain() to use instead of the background for generating the explainability result.
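
As an illustration of the two terms, a minimal sketch (model is a fitted TFTModel; train_series and new_series are hypothetical placeholder names):

>>> explainer = TFTExplainer(model, background_series=train_series)
>>> # no foreground passed: explain() falls back to the background series
>>> background_results = explainer.explain()
>>> # foreground passed: the new series is explained instead of the background
>>> foreground_results = explainer.explain(foreground_series=new_series)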

Parameters
  • model (darts.models.forecasting.tft_model.TFTModel) – The fitted TFTModel to be explained.

  • background_series (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a series or list of series to use as a default target series for the explanations. Optional if model was trained on a single target series. By default, it is the series used at fitting time. Mandatory if model was trained on multiple (sequence of) target series.

  • background_past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a past covariates series or list of series to use as a default past covariates series for the explanations. The same requirements apply as for background_series .

  • background_future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, a future covariates series or list of series to use as a default future covariates series for the explanations. The same requirements apply as for background_series.
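
The background becomes mandatory as soon as the model was trained on multiple series, since the model then no longer stores a single default series. A minimal sketch, with series_a and series_b as hypothetical placeholder series:

>>> # hypothetical: the model was fitted on several target series
>>> model.fit([series_a, series_b])
>>> # a default (background) target must now be provided explicitly
>>> explainer = TFTExplainer(model, background_series=series_a)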

Examples

>>> from darts.datasets import AirPassengersDataset
>>> from darts.explainability.tft_explainer import TFTExplainer
>>> from darts.models import TFTModel
>>> series = AirPassengersDataset().load()
>>> model = TFTModel(
>>>     input_chunk_length=12,
>>>     output_chunk_length=6,
>>>     # cyclic encoding of the month matches the monthly series frequency
>>>     add_encoders={"cyclic": {"future": ["month"]}},
>>> )
>>> model.fit(series)
>>> # create the explainer and generate explanations
>>> explainer = TFTExplainer(model)
>>> results = explainer.explain()
>>> # plot the results
>>> explainer.plot_attention(results, plot_type="all")
>>> explainer.plot_variable_selection(results)

Methods

  • explain([foreground_series, ...]): Returns the TFTExplainabilityResult for all series in foreground_series.

  • plot_attention(expl_result[, plot_type, ...]): Plots the attention heads of the TFTModel.

  • plot_variable_selection(expl_result[, ...]): Plots the variable selection / feature importances of the TFTModel based on the input.

explain(foreground_series=None, foreground_past_covariates=None, foreground_future_covariates=None, horizons=None, target_components=None)[source]

Returns the TFTExplainabilityResult for all series in foreground_series. If foreground_series is None, the background input from TFTExplainer creation is used instead (either the background passed at creation, or the series stored in the TFTModel if it was trained on a single series only). For each series, the results contain the attention heads, encoder variable importances, decoder variable importances, and static covariates importances.

Parameters
  • foreground_series (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, one or a sequence of target TimeSeries to be explained. Can be multivariate. If not provided, the background TimeSeries will be explained instead.

  • foreground_past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, one or a sequence of past covariates TimeSeries if required by the forecasting model.

  • foreground_future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – Optionally, one or a sequence of future covariates TimeSeries if required by the forecasting model.

  • horizons (Optional[Sequence[int]]) – This parameter is not used by the TFTExplainer.

  • target_components (Optional[Sequence[str]]) – This parameter is not used by the TFTExplainer.

Returns

The explainability result containing the attention heads, encoder variable importances, decoder variable importances, and static covariates importances.

Return type

TFTExplainabilityResult

Examples

>>> explainer = TFTExplainer(model)  # requires `background` if model was trained on multiple series

Optionally, pass a foreground input to generate the explanation on new data. Otherwise, leave it empty to compute the explanation on the background from the TFTExplainer creation:

>>> explain_results = explainer.explain(
>>>     foreground_series=foreground_series,
>>>     foreground_past_covariates=foreground_past_covariates,
>>>     foreground_future_covariates=foreground_future_covariates,
>>> )
>>> attn = explain_results.get_attention()
>>> importances = explain_results.get_feature_importances()
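
The returned values can be inspected directly. A rough sketch (assuming, as the importances listed above suggest, that get_feature_importances() returns a mapping from importance name to a per-feature weight table):

>>> for name, table in importances.items():
>>>     print(name)   # e.g. encoder / decoder / static covariates importance
>>>     print(table)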

model: TFTModel
plot_attention(expl_result, plot_type='all', show_index_as='relative', ax=None, max_nr_series=5, show_plot=True)[source]

Plots the attention heads of the TFTModel.

Parameters
  • expl_result (TFTExplainabilityResult) – A TFTExplainabilityResult object. Corresponds to the output of explain().

  • plot_type (Optional[Literal['all', 'time', 'heatmap']]) – The type of attention head plot. One of ("all", "time", "heatmap"). If "all", will plot the attention per horizon (given the horizons in the TFTExplainabilityResult). The maximum horizon corresponds to the output_chunk_length of the trained TFTModel. If "time", will plot the mean attention over all horizons. If "heatmap", will plot the attention per horizon on a heat map. The horizons are shown on the y-axis, and times / relative indices on the x-axis.

  • show_index_as (Literal['relative', 'time']) – The type of index to be shown. One of ("relative", "time"). If "relative", will plot the x-axis from (-input_chunk_length, output_chunk_length - 1). 0 corresponds to the first prediction point. If "time", will plot the x-axis with the actual time index (or range index) of the corresponding TFTExplainabilityResult.

  • ax (Optional[Axes]) – Optionally, an axis to plot on. Only effective on a single expl_result.

  • max_nr_series (int) – The maximum number of plots to show in case expl_result was computed on multiple series.

  • show_plot (bool) – Whether to show the plot.

Return type

Axes
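
Since plot_attention() returns a matplotlib Axes, the plot can be customized after the call. A small sketch reusing the explainer and results from the examples above (the Axes handling is standard matplotlib usage, not a documented recipe):

>>> ax = explainer.plot_attention(
>>>     results,
>>>     plot_type="heatmap",
>>>     show_index_as="time",
>>>     show_plot=False,  # defer display so the Axes can be customized
>>> )
>>> ax.set_title("TFT attention per horizon")
>>> ax.figure.savefig("tft_attention.png")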

plot_variable_selection(expl_result, fig_size=None, max_nr_series=5)[source]

Plots the variable selection / feature importances of the TFTModel based on the input. The figure includes three subplots:

  • encoder importances: the importance of the past target, past covariates, and historic future covariates on the encoder (input chunk)

  • decoder importances: the importance of the future covariates on the decoder (output chunk)

  • static covariates importances: the importance of the numeric and / or categorical static covariates

Parameters
  • expl_result (TFTExplainabilityResult) – A TFTExplainabilityResult object. Corresponds to the output of explain().

  • fig_size – The size of the figure to be plotted.

  • max_nr_series (int) – The maximum number of plots to show in case expl_result was computed on multiple series.
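
A short usage sketch, reusing the results from explain() in the examples above (passing a (width, height) tuple for fig_size is an assumption following the usual matplotlib figsize convention):

>>> # plots encoder, decoder, and static covariates importances as subplots
>>> explainer.plot_variable_selection(results, fig_size=(12, 8))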