{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Transformer Model\n", "In this notebook, we show an example of how Transformer can be used with darts.\n", "If you are new to darts, we recommend you first follow the [quick start](https://unit8co.github.io/darts/quickstart/00-quickstart.html) notebook." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# fix python path if working locally\n", "from utils import fix_pythonpath_if_working_locally\n", "\n", "fix_pythonpath_if_working_locally()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [], "source": [ "import warnings\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from darts.dataprocessing.transformers import Scaler\n", "from darts.datasets import AirPassengersDataset, SunspotsDataset\n", "from darts.metrics import mape\n", "from darts.models import ExponentialSmoothing, TransformerModel\n", "from darts.utils.statistics import check_seasonality\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "import logging\n", "\n", "logging.disable(logging.CRITICAL)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Air Passengers Example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we will test the performance of the transformer architecture on the 'air passengers' dataset." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Read data:\n", "series = AirPassengersDataset().load().astype(np.float32)\n", "\n", "# Create training and validation sets:\n", "train, val = series.split_after(pd.Timestamp(\"19590101\"))\n", "\n", "# Normalize the time series (note: we fit the scaler on the training set only, never on the validation set)\n", "scaler = Scaler()\n", "train_scaled = scaler.fit_transform(train)\n", "val_scaled = scaler.transform(val)\n", "series_scaled = scaler.transform(series)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"the 'air passengers' dataset has 144 data points\"" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f\"the 'air passengers' dataset has {len(series)} data points\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We train a standard transformer architecture, paying particular attention to two hyperparameters:\n", "\n", "* _d\\_model_, the input dimensionality of the transformer architecture (*after* performing time series embedding). We set it to 16, far below the 512 used by default in PyTorch's `nn.Transformer`, since it is hard to learn such a high-dimensional representation from a univariate time series.\n", "* _nhead_, the number of heads in the multi-head attention mechanism. We use 8 heads, so each head has size _d\\_model_/_nhead_=16/8=2. This way, we obtain low-dimensional heads that are hopefully suitable for learning from a univariate time series.\n", "\n", "The goal is to perform one-step forecasting." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [], "source": [ "my_model = TransformerModel(\n", "    input_chunk_length=12,\n", "    output_chunk_length=1,\n", "    batch_size=32,\n", "    n_epochs=200,\n", "    model_name=\"air_transformer\",\n", "    nr_epochs_val_period=10,\n", "    d_model=16,\n", "    nhead=8,\n", "    num_encoder_layers=2,\n", "    num_decoder_layers=2,\n", "    dim_feedforward=128,\n", "    dropout=0.1,\n", "    activation=\"relu\",\n", "    random_state=42,\n", "    save_checkpoints=True,\n", "    force_reset=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false, "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2f9a2a1d3b8c4b33a99279fb2373d78b", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/200 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# assumed training call: fit on the scaled training set, tracking loss on the validation set\n", "my_model.fit(series=train_scaled, val_series=val_scaled)\n", "\n", "\n", "# this function evaluates a model on a given validation set for n time-steps\n", "def eval_model(model, n, series, val_series):\n", "    pred_series = model.predict(n=n)\n", "    plt.figure(figsize=(8, 5))\n", "    series.plot(label=\"actual\")\n", "    pred_series.plot(label=\"forecast\")\n", "    # mape expects the actual series first, then the prediction\n", "    plt.title(f\"MAPE: {mape(val_series, pred_series):.2f}%\")\n", "    plt.legend()\n", "\n", "\n", "eval_model(my_model, 26, series_scaled, val_scaled)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we evaluate the best model obtained during training, according to validation loss:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "tags": [] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "best_model = TransformerModel.load_from_checkpoint(\n", " model_name=\"air_transformer\", best=True\n", ")\n", "eval_model(best_model, 26, series_scaled, val_scaled)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's backtest our `Transformer` model to evaluates its performance at a forecast horizon of 6 months:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2fb9d97d905246dba0903dfa6b2e19b9", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/19 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(8, 5))\n", "series_scaled.plot(label=\"actual\", lw=2)\n", "backtest_series.plot(label=\"backtest\", lw=2)\n", "plt.legend()\n", "plt.title(\"Backtest, starting Jan 1959, with a 6-months horizon\")\n", "print(\n", " \"MAPE: {:.2f}%\".format(\n", " mape(\n", " scaler.inverse_transform(series_scaled),\n", " scaler.inverse_transform(backtest_series),\n", " )\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Monthly Sun spots Example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's test the transformer architecture on a more complex dataset, the 'monthly sunspots'. " ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "series_sunspot = SunspotsDataset().load().astype(np.float32)\n", "\n", "series_sunspot.plot()\n", "check_seasonality(series_sunspot, max_lag=240)\n", "\n", "train_sp, val_sp = series_sunspot.split_after(pd.Timestamp(\"19401001\"))\n", "\n", "scaler_sunspot = Scaler()\n", "train_sp_scaled = scaler_sunspot.fit_transform(train_sp)\n", "val_sp_scaled = scaler_sunspot.transform(val_sp)\n", "series_sp_scaled = scaler_sunspot.transform(series_sunspot)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"the 'monthly sun spots' dataset has 2820 data points\"" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f\"the 'monthly sun spots' dataset has {len(series_sunspot)} data points\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's perform one-step ahead forecasting." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "my_model_sp = TransformerModel(\n", " batch_size=32,\n", " input_chunk_length=125,\n", " output_chunk_length=36,\n", " n_epochs=20,\n", " model_name=\"sun_spots_transformer\",\n", " nr_epochs_val_period=5,\n", " d_model=16,\n", " nhead=4,\n", " num_encoder_layers=2,\n", " num_decoder_layers=2,\n", " dim_feedforward=128,\n", " dropout=0.1,\n", " random_state=42,\n", " optimizer_kwargs={\"lr\": 1e-3},\n", " save_checkpoints=True,\n", " force_reset=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ccd2a30ded334bc2a68338f7a77bc6d0", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/20 [00:00