{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Time Series Mixer (TSMixer)\n", "This notebook walks through how to use Darts' `TSMixerModel` and benchmarks it against `TiDEModel`.\n", "\n", "TSMixer (Time-series Mixer) is an all-MLP architecture for time series forecasting.\n", "\n", "It forecasts by integrating historical time series data, known future inputs, and static contextual information. The architecture combines conditional feature mixing with mixer layers, which alternate between mixing along the time dimension and mixing across features, to process and combine these different types of data.\n", "\n", "In Darts, this means the model supports all types of covariates (past, future, and/or static).\n", "\n", "See the original paper and model description [here](https://arxiv.org/abs/2303.06053).\n", "\n", "According to the authors, the model outperforms several state-of-the-art models on multivariate forecasting tasks.\n", "\n", "Let's see how it performs against `TiDEModel` on the ETTh1 and ETTh2 datasets."
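, "\n",
"\n",
"The mixer layers described above can be illustrated with a minimal NumPy sketch. This is a simplified toy under stated assumptions, not Darts' `TSMixerModel` implementation. It uses a two-layer ReLU MLP for both mixing steps and random, untrained weights, and it omits the normalization and dropout of the original architecture. The helpers `init_mlp`, `mlp`, `mixer_layer` and `conditional_feature_mixing` are hypothetical names. Time mixing applies one MLP along the time axis of every feature. Feature mixing applies another MLP across the features of every time step. Conditional feature mixing appends a projected static vector (here a hypothetical one-hot `ETTh1`/`ETTh2` indicator) to each time step before mixing.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"\n",
"def init_mlp(d_in, d_hidden, d_out):\n",
"    # random, untrained weights for a two-layer MLP (illustrative only)\n",
"    return (rng.normal(scale=0.1, size=(d_in, d_hidden)), np.zeros(d_hidden),\n",
"            rng.normal(scale=0.1, size=(d_hidden, d_out)), np.zeros(d_out))\n",
"\n",
"\n",
"def mlp(x, w1, b1, w2, b2):\n",
"    # two-layer MLP with ReLU, applied over the last axis of x\n",
"    return np.maximum(x @ w1 + b1, 0.0) @ w2 + b2\n",
"\n",
"\n",
"def mixer_layer(x, time_params, feat_params):\n",
"    # x has shape (L, C): L time steps, C features (components)\n",
"    # time mixing: one MLP shared by all features, applied along the time axis\n",
"    x = x + mlp(x.T, *time_params).T\n",
"    # feature mixing: one MLP shared by all time steps, applied across features\n",
"    return x + mlp(x, *feat_params)\n",
"\n",
"\n",
"def conditional_feature_mixing(x, static, proj, feat_params):\n",
"    # project the static vector, repeat it for every time step and append it\n",
"    # to the features before the feature-mixing MLP (residual keeps shape (L, C))\n",
"    s = np.tile(static @ proj, (x.shape[0], 1))\n",
"    return x + mlp(np.concatenate([x, s], axis=1), *feat_params)\n",
"\n",
"\n",
"L, C, S, P, H = 24, 7, 2, 4, 16  # lookback, features, static dims, projection, hidden\n",
"x = rng.normal(size=(L, C))\n",
"static = np.array([1.0, 0.0])  # hypothetical one-hot: 'ETTh1' vs 'ETTh2'\n",
"y = mixer_layer(x, init_mlp(L, H, L), init_mlp(C, H, C))\n",
"z = conditional_feature_mixing(y, static, rng.normal(size=(S, P)), init_mlp(C + P, H, C))\n",
"print(y.shape, z.shape)  # (24, 7) (24, 7)\n",
"```\n",
"\n",
"In the paper's design, several such layers are stacked and followed by a projection from the lookback window to the forecast horizon. In practice, all weights are learned during training."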
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# fix python path if working locally\n", "from utils import fix_pythonpath_if_working_locally\n", "\n", "fix_pythonpath_if_working_locally()\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "import logging\n", "\n", "logging.disable(logging.CRITICAL)\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", "\n", "from darts import concatenate\n", "from darts.dataprocessing.transformers.scaler import Scaler\n", "from darts.datasets import ETTh1Dataset, ETTh2Dataset\n", "from darts.metrics import mql\n", "from darts.models import TiDEModel, TSMixerModel\n", "from darts.utils.callbacks import TFMProgressBar\n", "from darts.utils.likelihood_models.torch import QuantileRegression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data Loading and Preparation\n", "We use the ETTh1 and ETTh2 datasets. Each contains hourly multivariate data from one electricity transformer (load, oil temperature, ...).\n", "You can find more information [here](https://unit8co.github.io/darts/generated_api/darts.datasets.html#darts.datasets.ETTh1Dataset).\n", "\n", "We will add static information to each transformer time series that identifies whether it comes from the `ETTh1` or the `ETTh2` transformer.\n", "Both TSMixer and TiDE can leverage this information." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| date | HUFL | HULL | MUFL | MULL | LUFL | LULL | OT |
|---|---|---|---|---|---|---|---|
| 2016-07-01 00:00:00 | 5.827 | 2.009 | 1.599 | 0.462 | 4.203 | 1.340 | 30.531000 |
| 2016-07-01 01:00:00 | 5.693 | 2.076 | 1.492 | 0.426 | 4.142 | 1.371 | 27.787001 |
| 2016-07-01 02:00:00 | 5.157 | 1.741 | 1.279 | 0.355 | 3.777 | 1.218 | 27.787001 |
| 2016-07-01 03:00:00 | 5.090 | 1.942 | 1.279 | 0.391 | 3.807 | 1.279 | 25.044001 |
| 2016-07-01 04:00:00 | 5.358 | 1.942 | 1.492 | 0.462 | 3.868 | 1.279 | 21.948000 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 2018-06-26 15:00:00 | -1.674 | 3.550 | -5.615 | 2.132 | 3.472 | 1.523 | 10.904000 |
| 2018-06-26 16:00:00 | -5.492 | 4.287 | -9.132 | 2.274 | 3.533 | 1.675 | 11.044000 |
| 2018-06-26 17:00:00 | 2.813 | 3.818 | -0.817 | 2.097 | 3.716 | 1.523 | 10.271000 |
| 2018-06-26 18:00:00 | 9.243 | 3.818 | 5.472 | 2.097 | 3.655 | 1.432 | 9.778000 |
| 2018-06-26 19:00:00 | 10.114 | 3.550 | 6.183 | 1.564 | 3.716 | 1.462 | 9.567000 |

17420 rows × 7 columns
| | q_0.05 | q_0.1 | q_0.2 | q_0.5 | q_0.8 | q_0.9 | q_0.95 |
|---|---|---|---|---|---|---|---|
| ETTh1_TSM | 0.501772 | 0.769545 | 1.136141 | 1.568439 | 1.098847 | 0.721835 | 0.442062 |
| ETTh1_TiDE | 0.573716 | 0.885452 | 1.298672 | 1.671870 | 1.151501 | 0.727515 | 0.446724 |
| ETTh2_TSM | 0.659187 | 1.030655 | 1.508628 | 1.932923 | 1.317960 | 0.857147 | 0.524620 |
| ETTh2_TiDE | 0.627251 | 0.982114 | 1.450893 | 1.897117 | 1.323661 | 0.862239 | 0.528638 |