Explaining Forecasting Models with SHAP#

Why did the model predict a spike tomorrow morning? Which lagged values mattered most? How do covariates shape the forecast?

Darts’ ShapExplainer answers these questions using SHAP (SHapley Additive exPlanations) [1], a game-theoretic framework that attributes each input feature’s contribution to a prediction. It works with any of Darts’ scikit-learn-like [2] and PyTorch [3] models through a single, unified API.

What you will learn:

We use the ElectricityConsumptionZurichDataset throughout, keeping the data small for fast computation.

  • [1] Lundberg & Lee, A Unified Approach to Interpreting Model Predictions, NeurIPS 2017.

  • [2] Models such as SKLearnModel, LinearRegressionModel, LightGBMModel, …

  • [3] Regular neural network models such as TiDEModel and foundation models such as Chronos2Model, …
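
To make the idea of a game-theoretic attribution concrete, here is a minimal sketch that computes exact Shapley values for a tiny, hypothetical three-feature model by averaging marginal contributions over all feature orderings. The `toy_model`, its coefficients, and the input values are invented for illustration; real explainers approximate this computation because enumerating orderings is infeasible for dozens of features.

```python
from itertools import permutations


def shapley_values(predict, x, baseline):
    """Exact Shapley values for one input `x` relative to `baseline`.

    A feature that is "absent" from a coalition keeps its baseline value.
    """
    n = len(x)
    phi = [0.0] * n
    orderings = list(permutations(range(n)))
    for order in orderings:
        current = list(baseline)
        prev = predict(current)
        for i in order:
            current[i] = x[i]        # add feature i to the coalition
            new = predict(current)
            phi[i] += new - prev     # marginal contribution of feature i
            prev = new
    return [p / len(orderings) for p in phi]


# Hypothetical toy "forecaster": two lags plus a calendar feature that interacts with lag_1.
def toy_model(features):
    lag_1, lag_24, hour = features
    return 0.5 * lag_1 + 0.3 * lag_24 + 0.1 * lag_1 * hour


x = [10.0, 20.0, 3.0]
baseline = [8.0, 18.0, 0.0]

phi = shapley_values(toy_model, x, baseline)
reconstructed = toy_model(baseline) + sum(phi)
print("SHAP values:", phi)
print("baseline + sum(SHAP):", reconstructed, "| prediction:", toy_model(x))
```

The last line illustrates the additive property that the plots in this notebook rely on: the baseline prediction plus all contributions recovers the model output.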

1. Setup and Data Preparation#

[1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
[2]:
import logging
import warnings

import numpy as np
import plotly
import shap

from darts import concatenate, set_option
from darts.datasets import ElectricityConsumptionZurichDataset
from darts.explainability import ShapExplainer
from darts.metrics import mae, mic, miw
from darts.models import DLinearModel, LinearRegressionModel
from darts.utils.likelihood_models import QuantileRegression
from darts.utils.timeseries_generation import datetime_attribute_timeseries

warnings.filterwarnings("ignore")
logging.disable(logging.CRITICAL)
set_option("plotting.use_darts_style", True)

plotly.offline.init_notebook_mode()
shap.initjs()

We load hourly electricity consumption for Zurich households & SMEs and keep the last four weeks (three for training, one for testing). We also use the calendar features “hour” and “dayofweek” as future covariates.

[3]:
data = ElectricityConsumptionZurichDataset().load().astype("float32")
data = data.resample("h", method="sum")[-28 * 24 - 1 : -1]

series = data["Value_NE5"].with_columns_renamed("Value_NE5", "consumption")
future_covariates = (
    datetime_attribute_timeseries(
        time_index=series,
        attribute="hour",
        add_length=24,
    )
    .add_datetime_attribute("dayofweek")
    .astype("float32")
)

train, test = series[: -24 * 7], series[-24 * 7 :]
[4]:
fig = train.plotly(label="train")
test.plotly(label="test", fig=fig)
fig.update_layout(yaxis_title="Consumption (MWh)", autosize=True)
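
The split and the calendar covariates can be sketched with the standard library alone. This is an illustration of the shapes involved, not of the Darts API; the start timestamp is an assumption chosen so that the test set ends at the time printed later in the notebook.

```python
from datetime import datetime, timedelta

HOURS_PER_DAY = 24
n_hours = 28 * HOURS_PER_DAY            # four weeks of hourly data
start = datetime(2022, 8, 3, 0)         # assumed start, for illustration only

timestamps = [start + timedelta(hours=h) for h in range(n_hours)]

# Covariates extend 24 hours past the target so the final forecast has inputs.
cov_timestamps = [
    start + timedelta(hours=h) for h in range(n_hours + HOURS_PER_DAY)
]
hour = [ts.hour for ts in cov_timestamps]
dayofweek = [ts.weekday() for ts in cov_timestamps]   # Monday == 0

# Last week for testing, everything before it for training.
test_len = 7 * HOURS_PER_DAY
train_ts, test_ts = timestamps[:-test_len], timestamps[-test_len:]

print(len(train_ts), "training hours,", len(test_ts), "test hours")
print("test range:", test_ts[0], "to", test_ts[-1])
```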

2. Train Two Models#

We train a LinearRegressionModel (scikit-learn) and a DLinearModel (PyTorch) to forecast the next 24 hours. Both use the same inputs:

  • Target lags: past 24 hours of consumption.

  • Future covariates: hour-of-day and day-of-week for the next 24 hours.

[5]:
sklearn_model = LinearRegressionModel(
    lags=24,
    lags_future_covariates=(0, 24),
    output_chunk_length=24,
    random_state=42,
).fit(train, future_covariates=future_covariates)

torch_model = DLinearModel(
    input_chunk_length=24,
    output_chunk_length=24,
    random_state=42,
).fit(train, future_covariates=future_covariates)
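
The input layout described above can be pictured as a sliding window over the series. The sketch below builds such windows by hand; it is a simplified illustration with a hypothetical `make_windows` helper and synthetic data, not Darts’ internal tabularization.

```python
def make_windows(target, hour, dayofweek, n_lags=24, horizon=24):
    """Build (features, labels) rows from aligned hourly lists."""
    rows, labels = [], []
    for t in range(n_lags, len(target) - horizon + 1):
        past = target[t - n_lags : t]              # target lags -24 ... -1
        future = []
        for k in range(horizon):                   # covariate lags 0 ... 23
            future += [hour[t + k], dayofweek[t + k]]
        rows.append(past + future)
        labels.append(target[t : t + horizon])     # the next 24 target values
    return rows, labels


# Synthetic three-week series with a pure daily pattern.
n = 21 * 24
target = [float(i % 24) for i in range(n)]
hour = [i % 24 for i in range(n)]
dayofweek = [(i // 24) % 7 for i in range(n)]

X, Y = make_windows(target, hour, dayofweek)
print(len(X), "samples,", len(X[0]), "features,", len(Y[0]), "outputs per sample")
```

Each row holds 24 past target values followed by 24 (hour, dayofweek) pairs, which is where the 72 input features explained later come from.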

Quick Evaluation#

A rolling 24-hour-ahead historical forecast over the test set confirms both models produce reasonable predictions. This notebook focuses on explainability, so we keep evaluation brief.

[6]:
hfc_kwargs = dict(
    series=series,
    future_covariates=future_covariates,
    start=test.start_time(),
    forecast_horizon=24,
    stride=24,
    retrain=False,
    last_points_only=False,
)
sklearn_pred = concatenate(sklearn_model.historical_forecasts(**hfc_kwargs))
torch_pred = concatenate(torch_model.historical_forecasts(**hfc_kwargs))

fig = test.plotly(label="actual")
sklearn_pred.plotly(
    label=f"LinearRegression (MAE: {mae(test, sklearn_pred):.2f})", fig=fig
)
torch_pred.plotly(label=f"DLinear (MAE: {mae(test, torch_pred):.2f})", fig=fig)
fig.update_layout(yaxis_title="Consumption (MWh)", autosize=True)
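
With a 24-hour horizon and a stride of 24, the rolling evaluation amounts to seven back-to-back daily forecasts over the test week. A rough sketch of that bookkeeping and of the MAE, using hypothetical helper names and toy numbers rather than the notebook’s data:

```python
def rolling_windows(n_test, horizon=24, stride=24):
    """(start, end) offsets within the test set for each forecast block."""
    return [(s, s + horizon) for s in range(0, n_test - horizon + 1, stride)]


def mean_absolute_error(actual, predicted):
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)


windows = rolling_windows(n_test=7 * 24)

# Toy values: a forecast that is off by a constant 2 units everywhere.
actual = [100.0 + (h % 24) for h in range(7 * 24)]
predicted = [a + 2.0 for a in actual]

print(len(windows), "forecast blocks, last one covers", windows[-1])
print("MAE:", mean_absolute_error(actual, predicted))
```

Because the blocks do not overlap, concatenating them yields exactly one prediction per test timestamp.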

3. Global Explainability#

Global explanations reveal which features are most important overall across many forecasts. We initialize a ShapExplainer with:

  • model: the fitted forecasting model.

  • background_series: reference data used to compute the SHAP baseline prediction. Defaults to the training data if omitted.

By default, ShapExplainer selects the most appropriate SHAP method for the model type (e.g. "linear" for linear regression, "permutation" for PyTorch models). You can override this via the shap_method parameter.

[7]:
sklearn_explainer = ShapExplainer(
    model=sklearn_model,
    background_series=test,
    background_future_covariates=future_covariates,
)

3.1 Summary Beeswarm Plot#

The beeswarm (or “dot”) plot shows every forecast instance as a dot, colored by the raw feature value. Features are sorted by mean absolute SHAP value (most important at top).

We inspect the 12-hour-ahead horizon:

[8]:
sklearn_explainer.summary_plot(horizons=[12], plot_kwargs=dict(max_display=10));
../_images/examples_28-Explainability-examples_14_0.png

Target lags dominate, meaning that past consumption is the strongest predictor of future consumption. The color gradient shows how feature values correlate with their SHAP contributions: for example, higher values of recent consumption lags push the forecast up (red dots on the right).

3.2 Summary Bar Plot#

A bar plot gives a cleaner ranking of mean |SHAP| importance. Let’s compare two horizons, 12 and 24 hours ahead:

[9]:
sklearn_explainer.summary_plot(
    horizons=[12, 24], plot_type="bar", plot_kwargs=dict(max_display=5)
);
../_images/examples_28-Explainability-examples_17_0.png
../_images/examples_28-Explainability-examples_17_1.png

Notice how the ranking shifts between horizons: some lags matter more for shorter horizons, while others gain importance at 24h. This is a natural consequence of multi-horizon forecasting.
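
The ranking behind a bar plot of this kind can be reproduced by hand: average the absolute SHAP values per feature and sort. A minimal sketch with invented SHAP values (only the feature-name pattern is taken from the notebook’s output; `rank_by_mean_abs` is a hypothetical helper):

```python
def rank_by_mean_abs(shap_rows, feature_names, top_k=5):
    """Rank features by mean absolute SHAP value across forecast instances."""
    n = len(shap_rows)
    importance = {
        name: sum(abs(row[j]) for row in shap_rows) / n
        for j, name in enumerate(feature_names)
    }
    return sorted(importance.items(), key=lambda kv: kv[1], reverse=True)[:top_k]


feature_names = [
    "consumption_target_lag-1",
    "consumption_target_lag-24",
    "hour_futcov_lag11",
]
# One row per forecast instance, one column per feature (made-up numbers).
shap_rows = [
    [-4.0, 1.0, 0.5],
    [6.0, -1.0, 0.0],
    [-2.0, 3.0, -0.5],
]

ranking = rank_by_mean_abs(shap_rows, feature_names)
for name, score in ranking:
    print(f"{name}: {score:.3f}")
```

Running the same ranking separately for each horizon is what makes the order shift between the 12h and 24h plots.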

3.3 Dependence Plot#

To see how a single feature’s value relates to its SHAP contribution, we extract the shap.Explanation object and use shap.plots.scatter. Because LinearRegressionModel is linear, we expect a clean linear relationship:

[10]:
result = sklearn_explainer.explain()
shap_object = result.get_shap_explanation_object(horizon=12)
shap.plots.scatter(shap_object[:, "consumption_target_lag-13"])
../_images/examples_28-Explainability-examples_20_0.png

The linear relationship confirms the model’s structure. Horizon 12 corresponds to lag 11 (the first forecasted step is lag 0), so lag -13 lies exactly 24 hours before it, and the model learns a strong daily-seasonality coefficient.
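
Why the scatter is a straight line: for a linear model whose features are treated as independent, the SHAP value of a feature has the closed form coefficient × (value − background mean). The sketch below applies that formula to made-up coefficients and background data; whether Darts’ "linear" method uses exactly this independence setting is an assumption here.

```python
def linear_shap(coefs, x, background):
    """Closed-form SHAP values for a linear model with independent features."""
    means = [sum(col) / len(col) for col in zip(*background)]
    return [w * (xi - m) for w, xi, m in zip(coefs, x, means)]


coefs = [0.8, -0.1]                       # hypothetical model coefficients
background = [[10.0, 1.0], [20.0, 3.0], [30.0, 5.0]]

phi_a = linear_shap(coefs, [25.0, 3.0], background)
phi_b = linear_shap(coefs, [30.0, 3.0], background)

# The slope of SHAP value against feature value equals the coefficient.
slope = (phi_b[0] - phi_a[0]) / (30.0 - 25.0)
print(phi_a, phi_b, slope)
```

A feature sitting exactly at its background mean contributes nothing, which is why the line crosses zero at the mean of the plotted feature.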

3.4 Same API for PyTorch Models#

The ShapExplainer API is identical for PyTorch models. Let’s create an explainer for DLinearModel and produce a summary plot at the same horizon:

[11]:
torch_explainer = ShapExplainer(
    model=torch_model,
    background_series=test,
    background_future_covariates=future_covariates,
    batch_size=4096,
)

torch_explainer.summary_plot(horizons=[12], plot_kwargs=dict(max_display=5));
../_images/examples_28-Explainability-examples_23_0.png

The DLinear model also relies heavily on target lags, but its SHAP values are distributed differently from those of the linear model, reflecting the different architecture. The API calls are identical: just swap the model.

4. Local Explainability#

Local explanations answer: why did the model make this specific prediction?

We use the sklearn model for local explanations (the API is the same for torch models).

4.1 Computing SHAP Values#

The .explain() method computes SHAP values for all forecastable timestamps in the foreground series, returning a ShapExplainabilityResult:

[12]:
result = sklearn_explainer.explain(
    foreground_series=test,
    foreground_future_covariates=future_covariates,
)

The result object provides three accessors:

| Method | Returns |
| --- | --- |
| get_explanation(horizon) | TimeSeries of SHAP values (features as components) |
| get_feature_values(horizon) | TimeSeries of raw input feature values |
| get_shap_explanation_object(horizon) | Raw shap.Explanation for use with any SHAP plot |

Let’s inspect the SHAP values at the 12-hour horizon:

[13]:
result.get_explanation(horizon=12)
[13]:

consumption_target_lag-24 consumption_target_lag-23 consumption_target_lag-22 consumption_target_lag-21 consumption_target_lag-20 ... dayofweek_futcov_lag21 hour_futcov_lag22 dayofweek_futcov_lag22 hour_futcov_lag23 dayofweek_futcov_lag23
2022-08-25 00:00:00 -10379.894531 2148.450439 -2.689489 532.226379 -281.067688 ... -14.348060 153.506866 -3.887789 -442.714478 -8.456299
2022-08-25 01:00:00 -12328.749023 2048.898438 -2.099411 58.013622 1131.847046 ... -14.348060 167.893631 -3.887789 432.817169 -177.582428
2022-08-25 02:00:00 -11750.551758 1589.196167 -0.198098 -288.357758 1926.117310 ... -14.348060 -163.002121 -198.277420 394.750580 -177.582428
2022-08-25 03:00:00 -9080.606445 107.973320 1.190647 -483.070465 2247.932861 ... -219.320526 -148.615356 -198.277420 356.683960 -177.582428
2022-08-25 04:00:00 -477.680511 -973.931763 1.971330 -561.962463 2723.159912 ... -219.320526 -134.228577 -198.277420 318.617371 -177.582428
... ... ... ... ... ... ... ... ... ... ... ...
2022-08-30 20:00:00 -3280.172363 899.971313 -1.236710 467.921997 -2239.091064 ... 190.624420 95.959770 190.501846 -290.448120 160.669830
2022-08-30 21:00:00 -5077.596191 917.106873 -1.841588 538.016541 -2383.853027 ... 190.624420 110.346542 190.501846 -328.514709 160.669830
2022-08-30 22:00:00 -5177.119141 1388.338867 -2.122626 573.504517 -2409.126221 ... 190.624420 124.733315 190.501846 -366.581299 160.669830
2022-08-30 23:00:00 -7914.029297 1607.282104 -2.264912 579.700073 -1911.726196 ... 190.624420 139.120087 190.501846 -404.647888 160.669830
2022-08-31 00:00:00 -9185.649414 1718.130127 -2.289752 457.764160 -212.922882 ... 190.624420 153.506866 190.501846 -442.714478 160.669830

shape: (145, 72, 1), freq: h, size: 40.78 KB

Each row is a forecast start time; each column is a feature’s SHAP contribution to the 12-hour-ahead prediction. There are 72 input features in total: 24 target lags, plus 24 lags each for the two future covariate components (hour, dayofweek).
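
The feature count can be checked by rebuilding the column names. The naming pattern and ordering below are read off the output above and are otherwise an assumption about how the names are generated:

```python
# 24 target lags: -24 ... -1
target_names = [f"consumption_target_lag{lag}" for lag in range(-24, 0)]

# 24 future-covariate lags (0 ... 23) for each of the two components,
# interleaved per lag as in the printed columns.
futcov_names = [
    f"{component}_futcov_lag{lag}"
    for lag in range(0, 24)
    for component in ("hour", "dayofweek")
]

feature_names = target_names + futcov_names
print(len(feature_names))
print(feature_names[:2], "...", feature_names[-2:])
```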

4.2 Waterfall Plot#

A waterfall plot is the most intuitive way to see why a single prediction was made. Starting from the baseline (average prediction), each feature adds or subtracts from the final predicted value:

[14]:
shap_object = result.get_shap_explanation_object(horizon=12)
shap.plots.waterfall(shap_object[0])
../_images/examples_28-Explainability-examples_31_0.png

The model starts at E[f(x)] (baseline), then each feature pushes the prediction up (red) or down (blue), arriving at f(x), the actual predicted value. Target lags dominate this particular prediction.
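
The arithmetic a waterfall plot visualizes is a running sum. A small sketch with invented numbers and a hypothetical `waterfall_steps` helper; contributions are applied from smallest to largest magnitude purely for illustration:

```python
def waterfall_steps(baseline, contributions):
    """Return (feature, contribution, running_total) tuples, smallest |value| first."""
    ordered = sorted(contributions.items(), key=lambda kv: abs(kv[1]))
    total, steps = baseline, []
    for name, value in ordered:
        total += value
        steps.append((name, value, total))
    return steps


baseline = 30000.0                        # made-up E[f(x)]
contributions = {                         # made-up SHAP values
    "consumption_target_lag-1": 2500.0,
    "consumption_target_lag-24": -1200.0,
    "hour_futcov_lag11": 150.0,
}

steps = waterfall_steps(baseline, contributions)
for name, value, total in steps:
    print(f"{name:28s} {value:+9.1f} -> {total:9.1f}")
```

The final running total is the prediction f(x) for that instance.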

4.3 Force Plot#

The force plot is an interactive alternative that shows the same additive decomposition in a compact horizontal layout. Select “original sample ordering” in the interactive toolbar on top to show forecast timestamps chronologically:

[15]:
sklearn_explainer.force_plot(
    foreground_series=test,
    foreground_future_covariates=future_covariates,
    horizon=12,
)

Red regions push the prediction above the baseline; blue regions pull it below. Hovering over a region reveals the contributing features and their values.

4.4 Explaining a Single Forecast#

When you only care about one specific prediction (e.g. the actual model forecast), .explain_single() is more efficient than .explain(), as it computes SHAP values for just that one forecast across all horizons. This explains the forecast generated by calling model.predict(n=24, series=...).

[16]:
result_single = sklearn_explainer.explain_single(
    foreground_series=test,
    foreground_future_covariates=future_covariates,
)
print("Test set end time:", test.end_time())
result_single.get_explanation()
Test set end time: 2022-08-30 23:00:00
[16]:

consumption_target_lag-24 consumption_target_lag-23 consumption_target_lag-22 consumption_target_lag-21 consumption_target_lag-20 ... dayofweek_futcov_lag21 hour_futcov_lag22 dayofweek_futcov_lag22 hour_futcov_lag23 dayofweek_futcov_lag23
2022-08-31 00:00:00 -106.176033 -2581.511963 1692.701782 2070.556885 -270.833435 ... 13.682087 -943.732361 12.682373 -569.508179 9.638122
2022-08-31 01:00:00 3936.019287 -9164.561523 588.513000 3805.863281 -82.457878 ... 16.302183 -2153.395020 18.264908 -1933.486938 9.143848
2022-08-31 02:00:00 907.882080 -594.693359 -7568.899902 3586.885986 55.325130 ... 18.259106 -2947.022949 24.863527 -3413.346924 10.468850
2022-08-31 03:00:00 -3093.964111 1685.322266 -144.382187 -2340.815430 -41.063385 ... 25.609144 -2990.720459 35.415722 -4281.487793 19.503969
2022-08-31 04:00:00 -4849.722168 102.794174 1404.632446 3677.335938 -750.013672 ... 36.911377 -2338.827148 46.718651 -4185.157715 32.783611
... ... ... ... ... ... ... ... ... ... ... ...
2022-08-31 19:00:00 4591.011719 731.986877 -259.273438 -1857.692993 118.767563 ... 186.558899 1713.976440 176.009476 2282.935059 151.773544
2022-08-31 20:00:00 5599.786133 -248.086975 1380.527588 -2292.551025 -23.829096 ... 174.183365 1322.323608 164.254425 2094.097168 142.380859
2022-08-31 21:00:00 6806.327637 -1419.354736 1527.216797 -1237.050293 -96.863815 ... 168.595703 1044.313354 158.792480 1799.267090 139.620987
2022-08-31 22:00:00 7623.354980 -1417.829590 788.297424 -1198.392090 21.118551 ... 167.128998 1050.255005 157.198898 1600.380859 141.915787
2022-08-31 23:00:00 6147.047852 1602.511841 -190.753265 -1160.363892 -24.619184 ... 168.154938 871.813843 157.211441 1456.708130 145.764175

shape: (24, 72, 1), freq: h, size: 6.75 KB

The result is a TimeSeries where each row is a forecast horizon (t+1 through t+24) and each column is a feature’s SHAP contribution at that horizon.

The first row is one hour after the end of the test set, confirming that the explanation covers the forecast made from the end of the test series.

4.5 Heatmap Across Horizons#

The heatmap is a powerful visualization unique to multi-horizon forecasting: it shows how feature importance evolves across the forecast horizon. The top panel shows the predicted values; the bottom panel shows SHAP contributions of the top features at each horizon:

[17]:
shap_object = result_single.get_shap_explanation_object()
ax = shap.plots.heatmap(
    shap_object,
    instance_order=np.arange(shap_object.shape[0]),
    show=False,
)
ax.set_xlabel("Horizon lag (0 is the first forecasted step)");
../_images/examples_28-Explainability-examples_40_0.png

A clear pattern emerges: many target lags have peak influence at horizon lag = lag + 24, revealing that the model has learned a daily seasonality pattern. For instance, lag-10 matters most at horizon lag 14 (-10 + 24 = 14), because 24 hours ago is the same time of day.
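
The lag arithmetic is easy to tabulate. The mapping below lists, for each target lag, the horizon lag that falls exactly 24 hours later; it describes the pattern expected under pure daily seasonality and is not computed from the model.

```python
DAY = 24

# Target lags run from -24 to -1; horizon lags run from 0 to 23.
peak_horizon_lag = {lag: lag + DAY for lag in range(-24, 0)}

for lag in (-24, -13, -10, -1):
    h = peak_horizon_lag[lag]
    print(f"lag {lag:>3} -> horizon lag {h:>2} (horizon {h + 1})")
```

Lag -13 maps to horizon lag 11, that is horizon 12, which matches the feature chosen for the dependence plot earlier.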

5. Probabilistic Explainability#

ShapExplainer can also explain the likelihood parameters of probabilistic forecasts.

  • For PyTorch models: All likelihoods are supported (e.g., QuantileRegression, GaussianLikelihood, …).

  • For scikit-learn-like models: Quantile and Poisson regression are supported.

For example, with quantile regression, SHAP values are computed for each predicted quantile, letting you understand what drives not just the median forecast, but also the model’s uncertainty bounds.
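
As background on what a predicted quantile is: quantile regression is commonly trained with the pinball (quantile) loss, whose minimizer is the requested quantile. The sketch below shows this on a toy sample with constant predictions; it illustrates the general idea and is not a description of Darts’ exact implementation.

```python
def pinball_loss(y_true, y_pred, q):
    """Asymmetric loss: under-prediction costs q, over-prediction costs 1 - q."""
    diff = y_true - y_pred
    return max(q * diff, (q - 1) * diff)


def best_constant(values, q, candidates):
    """Constant prediction with the lowest total pinball loss."""
    return min(candidates, key=lambda c: sum(pinball_loss(v, c, q) for v in values))


values = [float(v) for v in range(1, 100)]     # toy sample: 1, 2, ..., 99

for q in (0.1, 0.5, 0.9):
    print(f"q={q}: best constant = {best_constant(values, q, values)}")
```

Each quantile therefore answers a different question about the same inputs, which is why its SHAP values can differ from those of the median.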

5.1 Train a Probabilistic Model#

[18]:
prob_model = DLinearModel(
    input_chunk_length=24,
    output_chunk_length=24,
    likelihood=QuantileRegression(quantiles=[0.1, 0.5, 0.9]),
    random_state=42,
).fit(train, future_covariates=future_covariates)

5.2 Quick Evaluation#

We evaluate the prediction intervals using Mean Interval Coverage (MIC) and Mean Interval Width (MIW):

[19]:
prob_preds = prob_model.historical_forecasts(
    predict_likelihood_parameters=True, **hfc_kwargs
)
prob_pred = concatenate(prob_preds)

mic_val = mic(test, prob_pred, q_interval=(0.1, 0.9))
miw_val = miw(test, prob_pred, q_interval=(0.1, 0.9))
print(f"MIC (target ~80%): {mic_val:.1%} | MIW: {miw_val:.2f}")

fig = test.plotly(label="actual")
prob_pred.plotly(fig=fig)
fig.update_layout(yaxis_title="Consumption (MWh)", autosize=True)
MIC (target ~80%): 79.2% | MIW: 25663.69

The three forecasted components correspond to the quantiles: consumption_q0.100, consumption_q0.500, and consumption_q0.900.
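
The two interval metrics can be pictured as follows: coverage is the fraction of actual values that fall inside the predicted interval, and width is the average distance between its bounds. A plain-Python sketch with invented numbers and hypothetical helper names; Darts’ `mic` and `miw` handle TimeSeries inputs and further options not shown here.

```python
def interval_coverage(actual, lower, upper):
    """Fraction of actual values inside [lower, upper]."""
    inside = sum(lo <= a <= hi for a, lo, hi in zip(actual, lower, upper))
    return inside / len(actual)


def interval_width(lower, upper):
    """Average width of the prediction interval."""
    return sum(hi - lo for lo, hi in zip(lower, upper)) / len(lower)


# Toy values standing in for the actuals and the q0.1 / q0.9 forecasts.
actual = [10.0, 12.0, 15.0, 9.0, 20.0]
lower = [8.0, 11.0, 10.0, 10.0, 15.0]
upper = [12.0, 14.0, 14.0, 13.0, 22.0]

print(f"coverage: {interval_coverage(actual, lower, upper):.1%}")
print(f"width: {interval_width(lower, upper):.2f}")
```

For a well-calibrated 0.1 to 0.9 interval, coverage should land near 80%, which is the target printed in the evaluation above.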

5.3 Explaining Quantile Forecasts#

We create an explainer and use .explain_single() to compute SHAP values for the last day’s forecast. We can then inspect any quantile by specifying the component:

[20]:
prob_explainer = ShapExplainer(
    model=prob_model,
    background_series=test,
    background_future_covariates=future_covariates,
    batch_size=4096,
)

result_single = prob_explainer.explain_single(
    foreground_series=test,
    foreground_future_covariates=future_covariates,
)

5.4 Heatmap for the Upper Bound (q=0.9)#

Let’s visualize what features drive the 90th percentile - the upper bound of the prediction interval:

[21]:
shap_object = result_single.get_shap_explanation_object(component="consumption_q0.900")
ax = shap.plots.heatmap(
    shap_object,
    instance_order=np.arange(shap_object.shape[0]),
    show=False,
)
ax.set_xlabel("Horizon lag (0 is the first forecasted step)");
../_images/examples_28-Explainability-examples_50_0.png

The upper-bound forecast is driven by many of the same target lags, but with different magnitudes and patterns compared to the point forecast. Lags like lag-14 and lag-13 show strong influence at the lag + 24 horizon, confirming the model learns daily seasonality for the uncertainty bounds too.

Conclusion#

In this notebook we demonstrated Darts’ ShapExplainer to:

  • Explain any of Darts’ scikit-learn-like and PyTorch models through a single, unified API.

  • Global explanations: identify which features matter most across many forecasts (beeswarm, bar, dependence plots).

  • Local explanations: understand individual predictions (waterfall, force plot) and how feature importance evolves across the forecast horizon (heatmap).

  • Probabilistic explanations: explain predicted quantiles to understand what drives uncertainty bounds.

Further reading:
