{ "cells": [ { "cell_type": "markdown", "id": "da55dd6c", "metadata": {}, "source": [ "# Chronos-2 Foundation Model\n", "In this notebook, we will show how to use Chronos-2 in Darts. If you are new to Darts, please check out the [Quickstart Guide](https://unit8co.github.io/darts/quickstart/00-quickstart.html) before proceeding.\n", "\n", "Chronos-2 is a time series foundation model for zero-shot forecasting. That means that it can be used for forecasting **without any training or fine-tuning** since it has already been pre-trained on large-scale time series data. Chronos-2 supports multivariate time series forecasting with [covariates](https://unit8co.github.io/darts/userguide/covariates.html) (exogenous variables) and can produce probabilistic forecasts.\n", "\n", "Check out the [Amazon Science Blog](https://www.amazon.science/blog/introducing-chronos-2-from-univariate-to-universal-forecasting) and the [original paper](https://arxiv.org/abs/2510.15821) for technical details." ] }, { "cell_type": "markdown", "id": "9ad51937", "metadata": {}, "source": [ "
\n", " Fine-tuning Chronos-2 on your own data is not yet supported in Darts, but may be added in the future.\n", "
" ] }, { "cell_type": "code", "execution_count": 1, "id": "310fa52a", "metadata": {}, "outputs": [], "source": [ "# fix python path if working locally\n", "from utils import fix_pythonpath_if_working_locally\n", "\n", "fix_pythonpath_if_working_locally()\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "id": "bfa59f65", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 3, "id": "d510b54b", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "import numpy as np\n", "\n", "from darts.datasets import ElectricityConsumptionZurichDataset\n", "from darts.metrics import mae, mic, miw\n", "from darts.models import Chronos2Model\n", "from darts.utils.likelihood_models import QuantileRegression\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "import logging\n", "\n", "logging.disable(logging.CRITICAL)" ] }, { "cell_type": "markdown", "id": "6b82a07a", "metadata": {}, "source": [ "## Data Preparation" ] }, { "cell_type": "markdown", "id": "70d7e392", "metadata": {}, "source": [ "Here, we will use the [Electricity Consumption Zurich Dataset](https://unit8co.github.io/darts/generated_api/darts.datasets.html#darts.datasets.ElectricityConsumptionZurichDataset), which records the electricity consumption of households & SMEs (`\"Value_NE5\"` column) and business & services (`\"Value_NE7\"`) in Zurich, Switzerland, along with weather covariates such as temperature (`\"T [°C]\"`) and humidity (`\"Hr [%Hr]\"`).\n", "Values are recorded every 15 minutes between January 2015 and August 2022.\n", "\n", "
\n", "\n", "Train-Test Split\n", "\n", "Even though Chronos-2 is already pre-trained, we still split the data into training and validation sets. This is because `Chronos2Model` follows the unified Darts interface, which requires calling the `fit()` method before forecasting. However, no training or fine-tuning is performed during the `fit()` call.\n", "\n", "
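The split itself happens in the data-loading cell below. Here is a sketch of the index arithmetic. It assumes that `split_before()` with an integer position behaves like positional slicing, so the first part stops just before that position and the second part starts at it. The same 7-day cut is applied to a hypothetical NumPy array that stands in for the consumption series:

```python
import numpy as np

# 15-minute resolution: 4 steps per hour, 24 hours per day
STEPS_PER_DAY = 24 * 4
VAL_STEPS = 7 * STEPS_PER_DAY  # 7 days of 15-minute steps

# hypothetical stand-in for the consumption values (one year of 15-minute data)
values = np.arange(365 * STEPS_PER_DAY, dtype=np.float32)

# split position, mirroring len(ts_energy) - 7 * 24 * 4
split_point = len(values) - VAL_STEPS

# positional analogue of split_before(): the training part ends just before split_point
train, val = values[:split_point], values[split_point:]

print(len(train), len(val))
```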
\n", "\n", "
\n", "\n", "Data Scaling\n", "\n", "Unlike other deep learning models in Darts, Chronos-2 does not require data scaling since it has its own internal data normalization mechanism. Therefore, we will skip the scaling step in this notebook.\n", "\n", "
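To show why external scaling adds little here, the sketch below applies a generic per-series standardization. The lookback window is normalized with its own mean and standard deviation, a forecast is made on that scale, and the result is mapped back to the original units. This scheme is an assumed simplification and not necessarily the normalization Chronos-2 uses internally. `naive_forecast` is a hypothetical stand-in for the model.

```python
import numpy as np


def standardize(context):
    # per-series location and scale, computed on the lookback window only
    loc = context.mean()
    scale = context.std()
    if scale == 0:
        scale = 1.0  # guard against a constant series
    return (context - loc) / scale, loc, scale


def naive_forecast(normalized_context, horizon):
    # hypothetical stand-in for the model: repeat the last normalized value
    return np.full(horizon, normalized_context[-1])


def forecast_in_original_units(context, horizon):
    normalized, loc, scale = standardize(context)
    return naive_forecast(normalized, horizon) * scale + loc  # undo the normalization


context = np.array([1200.0, 1500.0, 1800.0, 1600.0])
forecast = forecast_in_original_units(context, horizon=3)
# multiplying the input by a positive constant multiplies the forecast by the same constant
forecast_scaled = forecast_in_original_units(context * 10, horizon=3)
print(forecast, forecast_scaled)
```

The statistics are recomputed from each input window, so multiplying the series by a constant only rescales the forecast. Any scaling applied beforehand would therefore be absorbed by this step anyway.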
" ] }, { "cell_type": "code", "execution_count": 4, "id": "2f87bcc5", "metadata": {}, "outputs": [], "source": [ "# convert to float32 as Chronos-2 works with float32 input\n", "data = ElectricityConsumptionZurichDataset().load().astype(np.float32)\n", "# extract households energy consumption\n", "ts_energy = data[\"Value_NE5\"]\n", "# extract temperature, solar irradiation and rain duration\n", "ts_weather = data[[\"T [°C]\", \"StrGlo [W/m2]\", \"RainDur [min]\"]]\n", "# split into train and validation sets by last 7 days\n", "train_energy, val_energy = ts_energy.split_before(len(ts_energy) - 7 * 24 * 4)" ] }, { "cell_type": "markdown", "id": "a3887f37", "metadata": {}, "source": [ "Let's quickly visualize the last 7 days of the electricity consumption data." ] }, { "cell_type": "code", "execution_count": 5, "id": "3b43a60a", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "val_energy.plot(label=\"consumption\");" ] }, { "cell_type": "markdown", "id": "46c1c07f", "metadata": {}, "source": [ "## Model Creation" ] }, { "cell_type": "markdown", "id": "30825d45", "metadata": {}, "source": [ "Chronos-2 supports two types of forecasting outputs:\n", "- **Deterministic** forecasts (**default**): single point estimates for each future time step.\n", "- **Probabilistic** forecasts: multiple samples for each future time step, which can be used to estimate prediction intervals. To enable probabilistic forecasting, set `likelihood=QuantileRegression([...])` when creating the model. The list of quantiles used here must be a subset of Chronos-2 supported quantiles: `[0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]`.\n", "\n", "
\n", "\n", "Lookback and Forward Windows\n", "\n", "Under the hood, Chronos-2 is no different from other [Torch Forecasting Models (TFMs)](https://unit8co.github.io/darts/userguide/torch_forecasting_models.html) in Darts, and most TFM hyperparameters apply here as well. In particular, you can control the length of the lookback window and the forward window using the `input_chunk_length` and `output_chunk_length` parameters, respectively.\n", "\n", "- `input_chunk_length`: the number of time steps of history the model takes as input when making a forecast. The maximum is **8192** for Chronos-2.\n", "- `output_chunk_length`: the number of time steps the model outputs in one forward pass. If the forecast horizon is longer than this value, the model consumes its own previous predictions to produce further forecasts. This is known as autoregressive forecasting. The maximum is **1024** for Chronos-2.\n", "\n", "![figure0](https://unit8co.github.io/darts/_images/tfm.png)\n", "\n", "See the [Torch Forecasting Models User Guide](https://unit8co.github.io/darts/userguide/torch_forecasting_models.html) for more details.\n", "\n", "
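The autoregressive behaviour described above can be sketched in plain Python. `predict_chunk` is a hypothetical stand-in for one forward pass of the model, and here it simply extends the last slope. The loop feeds the most recent `input_chunk_length` values back into the model, whether observed or predicted, until `n` steps are covered. It then trims the overshoot. This illustrates the idea only and is not the internal Darts implementation.

```python
import math

# limits quoted in the text above for Chronos-2
MAX_INPUT_CHUNK = 8192
MAX_OUTPUT_CHUNK = 1024


def predict_chunk(window, output_chunk_length):
    # hypothetical one-pass model: extend the last observed slope (needs at least 2 inputs)
    slope = window[-1] - window[-2]
    return [window[-1] + slope * (k + 1) for k in range(output_chunk_length)]


def autoregressive_forecast(history, n, input_chunk_length, output_chunk_length):
    if input_chunk_length > MAX_INPUT_CHUNK or output_chunk_length > MAX_OUTPUT_CHUNK:
        raise ValueError('chunk lengths exceed the Chronos-2 limits')
    context = list(history)
    forecast = []
    passes = 0
    while len(forecast) < n:
        window = context[-input_chunk_length:]  # most recent lookback window
        chunk = predict_chunk(window, output_chunk_length)
        forecast.extend(chunk)
        context.extend(chunk)  # predictions become inputs for the next pass
        passes += 1
    return forecast[:n], passes  # drop the overshoot of the last chunk


preds, passes = autoregressive_forecast([1.0, 2.0, 3.0, 4.0], n=5,
                                        input_chunk_length=2, output_chunk_length=2)
print(preds, passes, math.ceil(5 / 2))
```

With `n=5` and `output_chunk_length=2`, three forward passes are needed. This matches `ceil(n / output_chunk_length)`.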
\n", "\n", "
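The quantiles given to `QuantileRegression` must come from the Chronos-2 list in the probabilistic bullet under Model Creation. A small pre-check can flag unsupported values before a model is created. `check_quantiles` is a hypothetical helper for illustration and is not part of Darts. It compares values with a tolerance to avoid floating-point surprises.

```python
# quantiles supported by Chronos-2, as listed in the Model Creation section
SUPPORTED_QUANTILES = [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
                       0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]


def check_quantiles(quantiles, supported=SUPPORTED_QUANTILES, tol=1e-9):
    # return the requested quantiles that have no numerically matching supported value
    return [q for q in quantiles if not any(abs(q - s) <= tol for s in supported)]


print(check_quantiles([0.1, 0.5, 0.9]))    # the quantiles used later in this notebook
print(check_quantiles([0.1, 0.5, 0.975]))  # 0.975 is not in the supported list
```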
\n", "\n", "Model Downloading and Caching\n", "\n", "When you create a `Chronos2Model` instance for the first time, the pre-trained model checkpoint is automatically downloaded from [amazon/chronos-2](https://huggingface.co/amazon/chronos-2) on the Hugging Face Hub and cached locally. Subsequent uses of `Chronos2Model` will NOT re-download the files but use the cached version instead.\n", "\n", "If you would like to download the model checkpoint to, or load it from, a custom directory, set the `local_dir` argument when creating the model. For example:\n", "\n", "```python\n", "model = Chronos2Model(\n", " input_chunk_length=168,\n", " output_chunk_length=24,\n", " local_dir=\"path/to/your/directory\",\n", ")\n", "```\n", "\n", "
\n", "\n", "
\n", "\n", "Using Other Checkpoints\n", "\n", "Other Chronos-2 checkpoints might be available in the future. You can specify a different checkpoint on Hugging Face Hub by setting the `hub_model_name` and `hub_model_revision` (optional) arguments when creating the model. For example:\n", "\n", "```python\n", "model = Chronos2Model(\n", " input_chunk_length=168,\n", " output_chunk_length=24,\n", " hub_model_name=\"amazon/chronos-2-some-other-checkpoint\",\n", " hub_model_revision=None, # e.g., branch, tag, or commit ID\n", ")\n", "```\n", "\n", "
" ] }, { "cell_type": "code", "execution_count": 6, "id": "045aa20c", "metadata": {}, "outputs": [], "source": [ "# use last 30 days of data to predict next 7 days\n", "model = Chronos2Model(\n", " input_chunk_length=30 * 24 * 4,\n", " output_chunk_length=7 * 24 * 4,\n", ")" ] }, { "cell_type": "markdown", "id": "a6607bf5", "metadata": {}, "source": [ "## Model Training\n", "Here, we will call the `fit()` method to \"train\" the model on the training set. Note that no actual training or fine-tuning will be performed since Chronos-2 is already pre-trained." ] }, { "cell_type": "code", "execution_count": 7, "id": "a447aec7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Chronos2Model(output_chunk_shift=0, likelihood=None, hub_model_name=amazon/chronos-2, hub_model_revision=None, local_dir=None, input_chunk_length=2880, output_chunk_length=672)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(\n", " series=train_energy,\n", " verbose=True,\n", ")" ] }, { "cell_type": "markdown", "id": "ea528f34", "metadata": {}, "source": [ "## Forecasting\n", "We now perform a one-shot forecast for the next 7 days using Chronos-2. We then compare the forecast against the actual values from the validation set." ] }, { "cell_type": "code", "execution_count": 8, "id": "b076cafe", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "32756a350f8942ee884acebddebbfdf3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: | | 0/? [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pred = model.predict(\n", " n=7 * 24 * 4,\n", " series=train_energy,\n", ")\n", "val_energy.plot(label=\"actual\")\n", "pred.plot(label=\"forecast\");" ] }, { "cell_type": "markdown", "id": "c0a0144c", "metadata": {}, "source": [ "You can see that Chronos-2 is able to produce qualitatively accurate forecasts without any training or fine-tuning! 
Let's evaluate the forecast accuracy using the [Mean Absolute Error (MAE)](https://unit8co.github.io/darts/generated_api/darts.metrics.metrics.html#darts.metrics.metrics.mae) metric." ] }, { "cell_type": "code", "execution_count": 9, "id": "c4e7e695", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MAE on validation set: 653.57\n" ] } ], "source": [ "mae_val = mae(val_energy, pred)\n", "print(f\"MAE on validation set: {mae_val:.2f}\")" ] }, { "cell_type": "markdown", "id": "d511f0d0", "metadata": {}, "source": [ "## Forecasting with Covariates\n", "Recall that Chronos-2 supports forecasting with covariates (exogenous variables). Since no training is required, we do not need to worry about tuning hyperparameters for covariates. Forecasting with covariates is as simple as passing the covariate series to the `fit()` and `predict()` methods!\n", "\n", "We use weather variables as future covariates to help forecast the electricity consumption. We then compare the forecasts (with and without covariates) against the actual values from the validation set.\n", "\n", "
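Before you call `predict()` with future covariates, it can help to check that the covariate series spans both the lookback window and the forecast horizon. The sketch below assumes that requirement, including a full `output_chunk_length` of future values when `n` is shorter. It uses Python datetimes at the notebook's 15-minute frequency and a hypothetical end timestamp rather than the real data. `required_covariate_span` and `covers` are illustrative helpers, not Darts functions.

```python
from datetime import datetime, timedelta


def required_covariate_span(target_end, freq, input_chunk_length, n, output_chunk_length):
    # assumption: covariates must cover the lookback window and a full output chunk
    horizon = max(n, output_chunk_length)
    start = target_end - (input_chunk_length - 1) * freq
    end = target_end + horizon * freq
    return start, end


def covers(cov_start, cov_end, span):
    # True if the covariate time range contains the required span
    start, end = span
    return cov_start <= start and cov_end >= end


freq = timedelta(minutes=15)
target_end = datetime(2022, 8, 24, 23, 45)  # hypothetical end of the training series
span = required_covariate_span(target_end, freq, input_chunk_length=30 * 24 * 4,
                               n=7 * 24 * 4, output_chunk_length=7 * 24 * 4)
print(span)
```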
\n", " The weather variables here are actual measurements from a weather station and not forecasts. The results shown here are optimistic and for demonstration purposes only. In practice, you should supply weather forecasts as future covariates to get realistic results.\n", "
" ] }, { "cell_type": "code", "execution_count": 10, "id": "addc0b08", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "407106836a2146c28fe8d7f7544cffd2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: | | 0/? [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = Chronos2Model(\n", " input_chunk_length=30 * 24 * 4,\n", " output_chunk_length=7 * 24 * 4,\n", ")\n", "model.fit(\n", " series=train_energy,\n", " future_covariates=ts_weather,\n", " verbose=True,\n", ")\n", "pred_cov = model.predict(\n", " n=7 * 24 * 4,\n", " series=train_energy,\n", " future_covariates=ts_weather,\n", ")\n", "val_energy.plot(label=\"actual\")\n", "pred_cov.plot(label=\"forecast with covariates\")\n", "pred.plot(label=\"forecast without covariates\");" ] }, { "cell_type": "markdown", "id": "5e3d38ca", "metadata": {}, "source": [ "With future covariates such as weather, we see that the forecast accuracy has improved on the 7-day horizon! Covariate support from Chronos-2 can be very useful when exogenous variables have a strong influence on the target series." ] }, { "cell_type": "code", "execution_count": 11, "id": "c8fe6715", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MAE on validation set with covariates: 466.05\n" ] } ], "source": [ "mae_cov = mae(val_energy, pred_cov)\n", "print(f\"MAE on validation set with covariates: {mae_cov:.2f}\")" ] }, { "cell_type": "markdown", "id": "9591d946", "metadata": {}, "source": [ "## Probabilistic Forecasting\n", "Here, we show how to perform probabilistic forecasting with Chronos-2 by using the [`QuantileRegression`](https://unit8co.github.io/darts/generated_api/darts.utils.likelihood_models.torch.html#darts.utils.likelihood_models.torch.QuantileRegression) likelihood. 
The quantiles passed to `QuantileRegression` must be a subset of the pre-trained quantiles supported by Chronos-2 (see the \"Model Creation\" section above).\n", "\n", "Because sampling with large models like Chronos-2 can be computationally expensive, here we call `predict()` with `predict_likelihood_parameters=True` to obtain the quantile estimates directly, without sampling. However, if the forecast horizon is longer than `output_chunk_length` (i.e., autoregressive forecasting is required), you must instead call `predict()` with a sufficiently large `num_samples` value (e.g., 1000) to generate probabilistic forecasts via Monte Carlo sampling." ] }, { "cell_type": "code", "execution_count": 12, "id": "14a424fc", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fff6dd69ddea481097232ea71176c65e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: | | 0/? [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = Chronos2Model(\n", " input_chunk_length=30 * 24 * 4,\n", " output_chunk_length=7 * 24 * 4,\n", " likelihood=QuantileRegression(quantiles=[0.1, 0.5, 0.9]),\n", ")\n", "model.fit(\n", " series=train_energy,\n", " future_covariates=ts_weather,\n", " verbose=True,\n", ")\n", "pred_prob = model.predict(\n", " n=7 * 24 * 4,\n", " series=train_energy,\n", " future_covariates=ts_weather,\n", " predict_likelihood_parameters=True,\n", ")\n", "val_energy.plot(label=\"actual\")\n", "pred_prob.plot(label=\"forecast\");" ] }, { "cell_type": "markdown", "id": "a570c81c", "metadata": {}, "source": [ "For probabilistic forecasts, we can evaluate the forecast quality by computing the [Mean Interval Coverage (MIC)](https://unit8co.github.io/darts/generated_api/darts.metrics.metrics.html#darts.metrics.metrics.mic) (the share of actuals inside the prediction intervals) and [Mean Interval Width (MIW)](https://unit8co.github.io/darts/generated_api/darts.metrics.metrics.html#darts.metrics.metrics.miw) (the 
width of the prediction intervals) metrics.\n", "\n", "For MIC, we expect a value close to the nominal coverage of the prediction interval (i.e., 80% for the (0.1, 0.9) interval). For MIW, lower values indicate narrower prediction intervals, which are preferable as long as MIC stays close to its nominal value." ] }, { "cell_type": "code", "execution_count": 13, "id": "70e28256", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MIC on validation set with covariates: 82.74%\n", "MIW on validation set with covariates: 1719.57\n" ] } ], "source": [ "mic_prob = mic(val_energy, pred_prob, q_interval=(0.1, 0.9))\n", "miw_prob = miw(val_energy, pred_prob, q_interval=(0.1, 0.9))\n", "print(f\"MIC on validation set with covariates: {mic_prob:.2%}\")\n", "print(f\"MIW on validation set with covariates: {miw_prob:.2f}\")" ] }, { "cell_type": "markdown", "id": "d8f08fb4", "metadata": {}, "source": [ "## Final Remarks\n", "Just like other Torch Forecasting Models in Darts, Chronos-2 supports historical forecasting (`historical_forecasts()`), backtesting (`backtest()`), residual computation (`residuals()`), custom PyTorch Lightning arguments (`pl_trainer_kwargs`), and more. 
Check out the following resources to learn more about those topics:\n", "- [Backtesting: simulate historical forecasting](https://unit8co.github.io/darts/quickstart/00-quickstart.html#Backtesting:-simulate-historical-forecasting)\n", "- [Torch Forecasting Models User Guide](https://unit8co.github.io/darts/userguide/torch_forecasting_models.html)\n", "- [Using Torch Models with GPUs and TPUs](https://unit8co.github.io/darts/userguide/gpu_and_tpu_usage.html)\n", "- [Chronos-2 Model API](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.chronos2_model.html)" ] } ], "metadata": { "kernelspec": { "display_name": "darts", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 5 }