## Create accurate and interpretable predictions

According to [2], the **Temporal Fusion Transformer outperforms all prominent Deep Learning models for time series forecasting**, including a featured *Gradient Boosting Tree* model for tabular time series data.

But what is **Temporal Fusion Transformer (TFT)[3]** and why is it so interesting?

In this article, we briefly explain the novelties of *Temporal Fusion Transformer* and build an end-to-end project on **Energy Demand Forecasting**. Specifically, we will cover:

- How to prepare our data for the TFT format.
- How to build, train, and evaluate the TFT model.
- How to get predictions on validation data and out-of-sample predictions.
- How to calculate **feature importances**, **seasonality patterns**, and **extreme events robustness** using the model’s built-in *interpretable attention* mechanism.

Let’s dive in!

For an in-depth analysis of the Temporal Fusion Transformer architecture, check my previous article.

TemporalFusionTransformer (TFT) is a Transformer-based model that leverages self-attention to capture the complex temporal dynamics of multiple time sequences.

TFT supports:

- **Multiple time series:** We can train a TFT model on thousands of univariate or multivariate time series.
- **Multi-Horizon Forecasting:** The model outputs multi-step predictions of one or more target variables, including prediction intervals.
- **Heterogeneous features:** TFT supports many types of features, including time-variant and static exogenous variables.
- **Interpretable predictions:** Predictions can be interpreted in terms of variable importance and seasonality.

One of those traits is unique to *Temporal Fusion Transformer*. We will cover this in the next section.

Among notable DL time-series models (e.g., *DeepAR*[4]), TFT stands out because it supports various types of features. These are:

- **Time-varying** *known*
- **Time-varying** *unknown*
- **Time-invariant** *real*
- **Time-invariant** *categorical*

For example, imagine we have a **sales forecasting case**:

Let’s say we have to predict the sales of 3 products. The `num sales` is the target variable. The `CPI index` or the `number of visitors` are *time-varying unknown* features because they are only known up to prediction time. However, `holidays` and `special days` are *time-varying known* events.

The `product id` is a *time-invariant (static) categorical* feature. Other features which are numerical and not time-dependent, such as `yearly_revenue`, can be categorized as *time-invariant real*.
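As a rough sketch, the roles above can be written down as a plain mapping before any modeling. The keys mirror the argument names that PyTorch Forecasting's `TimeSeriesDataSet` uses for these categories; the snake_case column names (`num_sales`, `cpi_index`, and so on) are hypothetical spellings of the features in this example, and listing the target among the unknown reals is an assumed convention.

```python
# Hypothetical column names for the sales example above.
feature_roles = {
    # Only known up to prediction time (the target is listed here by convention).
    "time_varying_unknown_reals": ["num_sales", "cpi_index", "number_of_visitors"],
    # Known in advance for every future time step.
    "time_varying_known_categoricals": ["holidays", "special_days"],
    # Constant per time series.
    "static_categoricals": ["product_id"],
    "static_reals": ["yearly_revenue"],
}


def role_of(column):
    """Return the role under which a column is registered."""
    for role, columns in feature_roles.items():
        if column in columns:
            return role
    raise KeyError(f"{column} has no assigned role")


print(role_of("holidays"))
print(role_of("yearly_revenue"))
```

Writing the roles down this way makes it easy to spot a feature that was placed in the wrong bucket, for instance a covariate marked as known that is in fact unavailable at prediction time.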

Before moving to our project, we will first show a mini-tutorial on how to convert your data to the **extended time-series format**.

Note: All images and figures in this article are created by the author.

For this tutorial, we use the **TemporalFusionTransformer** model from the PyTorch Forecasting library and PyTorch Lightning:

`pip install torch pytorch-lightning pytorch_forecasting`

The whole process involves 3 things:

- Create a pandas dataframe with our time-series data.
- Wrap our dataframe into a *TimeSeriesDataSet* instance.
- Pass our *TimeSeriesDataSet* instance to **TemporalFusionTransformer**.

The *TimeSeriesDataSet* is very useful because it helps us specify whether features are time-varying or static. Plus, it’s the only format that **TemporalFusionTransformer** accepts.

Let’s create a minimal training dataset to show how *TimeSeriesDataSet* works:

We should format our data in the following way: Each colored box represents a different time series, represented by its `group` value.

The most important column of our dataframe is the `time_idx`: it determines the sequence of samples. If there are no missing observations, the values should increase by *+1* **for each time-series**.
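The `time_idx` rule can be illustrated without any library. The sketch below uses made-up rows and hypothetical helper names, and it assumes the rows of each group are already sorted by time; it attaches a per-group counter and then checks that the counter grows by exactly +1 within every group.

```python
from collections import defaultdict

# Made-up long-format rows: (group, target), sorted by time within each group.
rows = [(0, 0), (0, 1), (0, 2), (1, 20), (1, 21), (1, 22)]


def add_time_idx(rows):
    """Attach a per-group counter that starts at 0 and grows by +1."""
    counters = defaultdict(int)
    records = []
    for group, target in rows:
        records.append({"group": group, "time_idx": counters[group], "target": target})
        counters[group] += 1
    return records


def is_contiguous(records):
    """Check that time_idx increases by exactly +1 inside every group."""
    last_seen = {}
    for record in records:
        group, time_idx = record["group"], record["time_idx"]
        if group in last_seen and time_idx != last_seen[group] + 1:
            return False
        last_seen[group] = time_idx
    return True


records = add_time_idx(rows)
print([r["time_idx"] for r in records])  # [0, 1, 2, 0, 1, 2]
print(is_contiguous(records))            # True
```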

Next, we wrap our dataframe into a *TimeSeriesDataSet* instance:

All arguments are self-explanatory: The `max_encoder_length` defines the lookback period and `max_prediction_length` specifies how many datapoints will be predicted. In our case, we look back 3 time steps in the past to output 2 predictions.
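To make the two lengths concrete, here is a minimal plain-Python sketch of how a lookback of 3 and a horizon of 2 cut one series into encoder and decoder parts. The function `make_windows` is a hypothetical helper, not part of PyTorch Forecasting, and it only produces full-length windows; the library can also produce shorter encoder windows depending on its minimum-length settings.

```python
def make_windows(values, encoder_length, prediction_length):
    """Slide over one series and return (encoder, decoder) pairs."""
    total = encoder_length + prediction_length
    windows = []
    for start in range(len(values) - total + 1):
        encoder = values[start:start + encoder_length]
        decoder = values[start + encoder_length:start + total]
        windows.append((encoder, decoder))
    return windows


series = [0, 1, 2, 3, 4, 5]  # made-up target values of a single group
for encoder, decoder in make_windows(series, encoder_length=3, prediction_length=2):
    print(encoder, "->", decoder)
# [0, 1, 2] -> [3, 4]
# [1, 2, 3] -> [4, 5]
```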

The *TimeSeriesDataSet* instance now serves as a dataloader. Let’s print a batch and check how our data will be passed to TFT:

This batch contains the training values `[0,1]` from the first time-series (`group 0`) and the testing values `[2,3,4]`. If you rerun this code, you will get different values because the data are shuffled by default.

Our project will use the **ElectricityLoadDiagrams20112014** [5] dataset from UCI. The notebook for this example can be downloaded from here:

This dataset contains the power usage (in kW) of 370 clients/consumers with a 15-minute frequency. The data span 4 years (2011–2014).

Some consumers were created after 2011, so their power usage initially is zero.

We do data preprocessing according to [3]:

- Aggregate our target variable `power_usage` by hour.
- Find the earliest date for every time-series where power is non-zero.
- Create new features: `month`, `day`, `hour` and `day_of_week`.
- Select all days between `2014-01-01` and `2014-09-07`.
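The first three steps can be sketched in plain Python on a handful of made-up readings. This is only an illustration of the logic, not the notebook's code: the helper names are hypothetical, and averaging the four readings of each hour is an assumed aggregation choice.

```python
from collections import defaultdict
from datetime import datetime

# Made-up 15-minute readings for one consumer: (timestamp, power_usage).
readings = [
    (datetime(2014, 1, 1, 0, 0), 0.0),
    (datetime(2014, 1, 1, 0, 15), 0.0),
    (datetime(2014, 1, 1, 0, 30), 0.0),
    (datetime(2014, 1, 1, 0, 45), 0.0),
    (datetime(2014, 1, 1, 1, 0), 2.0),
    (datetime(2014, 1, 1, 1, 15), 4.0),
    (datetime(2014, 1, 1, 1, 30), 6.0),
    (datetime(2014, 1, 1, 1, 45), 8.0),
]


def aggregate_hourly(readings):
    """Average the readings that fall into the same hour (assumed aggregation)."""
    buckets = defaultdict(list)
    for timestamp, value in readings:
        buckets[timestamp.replace(minute=0, second=0, microsecond=0)].append(value)
    return {hour: sum(vals) / len(vals) for hour, vals in sorted(buckets.items())}


def first_nonzero(hourly):
    """Earliest hour with non-zero usage, or None if the series is all zeros."""
    for hour, value in hourly.items():
        if value != 0:
            return hour
    return None


def calendar_features(timestamp):
    """Derive the calendar features listed above from a timestamp."""
    return {
        "month": timestamp.month,
        "day": timestamp.day,
        "hour": timestamp.hour,
        "day_of_week": timestamp.weekday(),  # Monday is 0
    }


hourly = aggregate_hourly(readings)
start = first_nonzero(hourly)
print(start, hourly[start], calendar_features(start))
```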

Let’s start:

## Download Data

`!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00321/LD2011_2014.txt.zip`

`!unzip LD2011_2014.txt.zip`

## Data Preprocessing

Each column represents a consumer. Most initial `power_usage` values are 0.

Next, we aggregate to hourly data. Due to the model’s size and complexity, we train our model on 5 consumers only (for those with non-zero values).

Now, we prepare our dataset for the *TimeSeriesDataSet* format. Notice that each column represents a different time-series. Hence, we ‘melt’ our dataframe, so that all time-series are stacked vertically instead of horizontally. In the process, we create our new features.
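The 'melt' step is easy to picture with a toy example. The sketch below uses made-up values and a hypothetical `melt` helper written in plain Python; it turns a wide layout (one column per consumer) into long-format rows in which every row carries its consumer identifier and time index.

```python
# Wide layout: one column per consumer, one entry per hour (made-up values).
hours = [0, 1, 2]
wide = {
    "MT_001": [1.0, 2.0, 3.0],
    "MT_002": [10.0, 20.0, 30.0],
}


def melt(wide, hours):
    """Stack every consumer column vertically into long-format rows."""
    long_rows = []
    for consumer_id, values in wide.items():
        for hours_from_start, power_usage in zip(hours, values):
            long_rows.append({
                "consumer_id": consumer_id,
                "hours_from_start": hours_from_start,
                "power_usage": power_usage,
            })
    return long_rows


long_rows = melt(wide, hours)
print(len(long_rows))  # 6
print(long_rows[3])
```

Two columns of three values become six rows, and the fourth row is the first observation of the second consumer.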

The final preprocessed dataframe is called `time_df`. Let’s print its contents:

The `time_df` is now in the proper format for the *TimeSeriesDataSet*. As you have guessed by now, since the granularity is hourly, the `hours_from_start` variable will be the **time index**.

## Exploratory Data Analysis

The choice of 5 consumers/time-series is not random. The `power_usage` of each time-series has different properties, such as the mean value:

`time_df[['consumer_id','power_usage']].groupby('consumer_id').mean()`

Let’s plot the first month of every time-series:

There is no noticeable trend, but each time-series has slightly different seasonality and amplitude. We can further experiment and check stationarity, signal decompositions, and so on, but in our case, we focus on the model-building aspect only.

Also, notice that other time-series forecasting methods like **ARIMA** must satisfy a few requirements (for instance, the time-series must first be made stationary). With TFT, we can leave our data as-is.

## Create DataLoaders

In this step, we pass our `time_df` to the *TimeSeriesDataSet* format, which is immensely useful because:

- It spares us from writing our own Dataloader.
- We can specify how TFT will handle the dataset’s features.
- We can normalize our dataset with ease. In our case, normalization is mandatory because all time sequences differ in magnitude. Thus, we use the **GroupNormalizer** to normalize each time-series individually.
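The idea behind normalizing each time-series individually can be shown with a simplified sketch. The code below is plain per-group standardization on made-up values, with a hypothetical helper name; it is not the exact computation of **GroupNormalizer**, which may apply additional transformations.

```python
from statistics import mean, pstdev

# Made-up hourly usage for two consumers with very different magnitudes.
usage = {
    "MT_001": [2.0, 4.0, 6.0],
    "MT_002": [200.0, 400.0, 600.0],
}


def normalize_per_group(usage):
    """Standardize each series with its own mean and standard deviation."""
    scaled, params = {}, {}
    for consumer_id, values in usage.items():
        center, scale = mean(values), pstdev(values)
        if scale == 0:
            raise ValueError(f"{consumer_id} is constant and cannot be scaled")
        params[consumer_id] = (center, scale)
        scaled[consumer_id] = [(v - center) / scale for v in values]
    return scaled, params


scaled, params = normalize_per_group(usage)
print(scaled["MT_001"])
print(scaled["MT_002"])
```

After scaling, both series live on the same range even though their raw magnitudes differ by a factor of 100. The stored parameters are what allows predictions to be mapped back to the original units.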

Our model uses a lookback window of one week (7*24) to predict the power usage of the next 24 hours.

Also, notice that the `hours_from_start` is both the time index and a time-varying feature. The `power_usage` is our target variable. For the sake of demonstration, our validation set is the last day:

## Baseline Model

Next, the step that almost everyone forgets: A baseline model. Especially in time-series forecasting, you will be surprised at how often a naive predictor outperforms even a fancier model!

As a naive baseline, we predict the power usage curve of the previous day:
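A library-free sketch of this baseline is shown below. The helper names and the data are made up for illustration: the forecast simply repeats the last 24 observed hours, and the mean absolute error measures how far it lands from the actual next day.

```python
def naive_previous_day(history, horizon=24):
    """Forecast the next `horizon` hours by repeating the last `horizon` hours."""
    if len(history) < horizon:
        raise ValueError("need at least one full day of history")
    return history[-horizon:]


def mean_absolute_error(actual, predicted):
    """Average absolute difference between two equally long sequences."""
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)


# Made-up data: two identical days, followed by a day that is 1.0 higher.
day = [float(hour % 12) for hour in range(24)]
history = day + day
actual_next_day = [value + 1.0 for value in day]

forecast = naive_previous_day(history)
print(mean_absolute_error(actual_next_day, forecast))  # 1.0
```

Any model worth deploying should beat this number on the same validation window.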

## Training the Temporal Fusion Transformer Model

We can train our TFT model using the familiar *Trainer* interface from PyTorch Lightning.

Notice the following things:

- We use the **EarlyStopping** callback to monitor the validation loss.
- We use **Tensorboard** to log our training and validation metrics.
- Our model uses *Quantile Loss*, a special type of loss that helps us output the prediction intervals. For more on the Quantile Loss function, check this article.
- We use 4 *attention heads*, like the original paper.
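For readers unfamiliar with Quantile Loss, here is a minimal sketch of the standard pinball formulation for a single observation. The quantile levels and the numbers are illustrative only, and library implementations may differ in scaling or in how they aggregate over quantiles and time steps.

```python
def quantile_loss(actual, predicted, q):
    """Pinball loss of one prediction for quantile level q in (0, 1)."""
    error = actual - predicted
    return max(q * error, (q - 1) * error)


actual = 10.0
# Made-up predictions for three quantile levels.
predictions = {0.1: 8.0, 0.5: 10.5, 0.9: 13.0}

for q, predicted in predictions.items():
    print(q, round(quantile_loss(actual, predicted, q), 4))
```

The loss is asymmetric: for a high quantile such as 0.9, predicting too low is penalized more heavily than predicting too high, which pushes that output toward the upper end of the distribution. Training on several quantiles at once is what yields the prediction intervals.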

We are now ready to build and train our model:

That’s it! After 6 epochs, EarlyStopping kicks in and halts training.

## Load and Save the Best Model

Don’t forget to save your model. Although we can pickle it, the safest option is to save the best epoch directly:

`!zip -r model.zip lightning_logs/lightning_logs/version_1/*`

To load the model again, unzip *model.zip* and execute the following (just remember the best model path):

**Check Tensorboard**

Take a closer look at training and validation curves with Tensorboard:

## Model Evaluation

Get predictions on the validation set and calculate the average **P50** (quantile median) **loss**:

The last 2 time-series have a bit higher loss because their relative magnitude is also high.

## Plot Predictions on Validation Data

If we pass `mode="raw"` to the *predict()* method, we get more information, including predictions for all seven quantiles. We also have access to the attention values (more about that later).

Take a closer look at the `raw_predictions` variable:

We use *plot_prediction()* to create our plots. Of course, you could make your own custom plot, but *plot_prediction()* has the extra benefit of adding the attention values.

Note: Our model predicts the next 24 datapoints in one go. This is not a rolling forecasting scenario where a model predicts a single value each time and ‘stitches’ all predictions together.

We create one plot for each consumer (5 in total).

The results are quite impressive.

Our *Temporal Fusion Transformer* model was able to capture the behaviour of all 5 time-series, in terms of both seasonality and magnitude!

Also, notice that:

- We did not perform any hyperparameter tuning.
- We did not implement any fancy feature engineering technique.

In a subsequent section, we show how to improve our model with hyperparameter optimization.

## Plot Predictions For A Specific Time Series

Previously, we plotted predictions on the validation data using the `idx` argument, which iterates over all time-series in our dataset. We can be more specific and output predictions on a specific time-series:

In **Figure 7**, we plot the day-ahead forecast of the **MT_004** consumer for time index=26512.

Remember, our time-indexing column `hours_from_start` starts from 26304 and we can get predictions from 26388 onwards, because we earlier set `min_encoder_length=max_encoder_length // 2`, which gives `26304 + 168//2=26388`.
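The arithmetic can be double-checked in a few lines. The sketch assumes that `hours_from_start` counts hours since 2011-01-01 00:00 (the first year of the raw data), an assumption that is consistent with the starting value of 26304 quoted above for 2014-01-01.

```python
from datetime import datetime

# Assumption: hours_from_start counts hours since 2011-01-01 00:00.
dataset_start = datetime(2011, 1, 1)
subset_start = datetime(2014, 1, 1)
first_time_idx = int((subset_start - dataset_start).total_seconds() // 3600)

max_encoder_length = 7 * 24                    # one week of hourly lookback
min_encoder_length = max_encoder_length // 2   # 84 hours

# The earliest index for which a minimum-length encoder window exists.
first_predictable_idx = first_time_idx + min_encoder_length

print(first_time_idx, min_encoder_length, first_predictable_idx)  # 26304 84 26388
```

The time index 26512 used for the plot lies after 26388, so a prediction is available for it.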

## Out-of-Sample Forecasts

Let’s create out-of-sample predictions, beyond the final datapoint of validation data, which is `2014-09-07 23:00:00`.

All we have to do is to create a new dataframe that contains:

- The number of `N`=`max_encoder_length` past dates, which act as the lookback window: the **encoder data** in TFT terminology.
- The future dates of size `max_prediction_length` for which we want to compute our predictions: the **decoder data**.
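The decoder part can be sketched without the library. The helper below is hypothetical: it generates the next 24 hourly rows after the final datapoint, continues the time index, and fills in the calendar features, which are known in advance. Two assumptions are labeled in the code: the time index counts hours since 2011-01-01 00:00, and the unknown target is filled with a placeholder value.

```python
from datetime import datetime, timedelta

dataset_start = datetime(2011, 1, 1)        # assumed origin of hours_from_start
last_timestamp = datetime(2014, 9, 7, 23)   # final datapoint of the validation data
last_time_idx = int((last_timestamp - dataset_start).total_seconds() // 3600)
max_prediction_length = 24


def build_decoder_rows(last_timestamp, last_time_idx, horizon, consumer_id):
    """Create one row per future hour with its known calendar features."""
    rows = []
    for step in range(1, horizon + 1):
        timestamp = last_timestamp + timedelta(hours=step)
        rows.append({
            "consumer_id": consumer_id,
            "hours_from_start": last_time_idx + step,
            "month": timestamp.month,
            "day": timestamp.day,
            "hour": timestamp.hour,
            "day_of_week": timestamp.weekday(),
            "power_usage": 0.0,  # placeholder (assumption): the true value is unknown
        })
    return rows


decoder_rows = build_decoder_rows(last_timestamp, last_time_idx,
                                  max_prediction_length, "MT_002")
print(decoder_rows[0])
print(decoder_rows[-1]["hour"])  # 23
```

Appending these rows to the last `max_encoder_length` observed rows of the same consumer gives the combined frame described above.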

We can create predictions for all 5 of our time-series, or just one. **Figure 7** shows the out-of-sample predictions for consumer **MT_002**:

Accurate forecasting is one thing, but explainability also matters a lot nowadays.

This is especially challenging for Deep Learning models, which are considered black boxes. Methods such as **LIME** and **SHAP** can provide explainability (to some extent) but don’t work well for time-series. Plus, they are external post-hoc methods and are not tied to a particular model.

*Temporal Fusion Transformer* provides three types of interpretability:

- **Seasonality-wise:** TFT leverages its novel **Interpretable Multi-Head Attention** mechanism to calculate the importance of past time steps.
- **Feature-wise:** TFT leverages its **Variable Selection Network** module to calculate the importance of every feature.
- **Extreme events robustness:** We can investigate how time series behave during rare events.

If you want to learn in-depth about the inner workings of **Interpretable Multi-Head Attention** and **Variable Selection Network**, check my previous article.

## Seasonality-wise Interpretability

TFT explores the attention weights to understand the temporal patterns across past time steps.

The gray lines in all previous plots represent the attention scores. Look at those plots again — do you notice anything? **Figure 8** shows the findings of **Figure 7** and also accounts for the attention scores:

The attention scores reveal how impactful those time steps are when the model outputs its prediction. The small peaks reflect the daily seasonality, while the higher peak towards the end probably implies the weekly seasonality.

If we average the attention curves across all timesteps and time-series (not just the 5 ones we used in this tutorial), we will get the symmetric-looking shape in **Figure 9** from the TFT paper.
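The averaging step itself is simple. The sketch below uses made-up attention weights over a 6-step lookback and hypothetical helper names; it averages the curves position by position and then lists the local peaks, whose spacing hints at the seasonal period.

```python
def average_attention(curves):
    """Average several attention curves position by position."""
    length = len(curves[0])
    if any(len(curve) != length for curve in curves):
        raise ValueError("all curves must cover the same lookback window")
    return [sum(curve[i] for curve in curves) / len(curves) for i in range(length)]


def local_peaks(curve):
    """Indices whose value is strictly higher than both neighbours."""
    return [i for i in range(1, len(curve) - 1)
            if curve[i - 1] < curve[i] > curve[i + 1]]


# Made-up attention weights for three time-series.
curves = [
    [0.1, 0.3, 0.1, 0.1, 0.3, 0.1],
    [0.2, 0.2, 0.1, 0.1, 0.3, 0.1],
    [0.0, 0.4, 0.1, 0.1, 0.3, 0.1],
]

mean_curve = average_attention(curves)
print(local_peaks(mean_curve))  # [1, 4]
```

In this toy case the peaks sit 3 steps apart. With hourly data and a one-week lookback, peaks spaced 24 steps apart would correspond to the daily seasonality mentioned above.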

Question: What good is this? Can’t we simply estimate seasonality patterns with methods such as ACF plots, time signal decomposition, etc.?

**Answer:** True. However, studying the attention weights of TFT has extra advantages:

- We can confirm our model captures the apparent seasonal dynamics of our sequences.
- Our model may also reveal hidden patterns because the attention weights of the current input windows consider all past inputs.
- The attention weights plot is not the same as an autocorrelation plot: The autocorrelation plot refers to a particular sequence, while the attention weights here focus on the impact of each timestep by looking across all covariates and time series.

## Feature-wise Interpretability

The **Variable Selection Network** component of TFT can easily estimate the **feature importances:**

In **Figure 10**, we notice the following:

- The `hour` and `day_of_week` have strong scores, both as past observations and future covariates. The benchmark in the original paper shares the same conclusion.
- The `power_usage` is obviously the most impactful observed covariate.
- The `consumer_id` is not very significant here because we use only 5 consumers. In the TFT paper, where the authors use all 370 consumers, this variable is more significant.

Note: If your grouping static variable is not important, it is very likely your dataset can also be modeled equally well by a single distribution model (like ARIMA).

## Extreme Event Detection

Time series are notorious for being susceptible to sudden changes in their properties during rare events (also referred to as **shocks**).

Even worse, those events are very elusive. Imagine if your target variable becomes volatile for a brief period because a covariate silently changes behavior:

Is this some random noise or a hidden persistent pattern that escapes our model?

With TFT, we can analyze the robustness of each individual feature across their range of values. Unfortunately, the current dataset does not exhibit volatility or rare events — those are more likely to be found in financial, sales data and so on. Still, we will show how to calculate them:

Some features do not have all their values present in the validation dataset, so we only show the `hour` and `consumer_id`:

In both Figures, the results are encouraging. In **Figure 12**, we notice that consumer **MT_004** slightly underperforms compared to other consumers. We could verify this if we normalize the P50 loss of every consumer with their average power usage that we calculated previously.
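The normalization suggested above takes only a few lines. The numbers below are made up and are not results from the model in this article; the sketch divides each consumer's P50 loss by that consumer's average power usage so that series of different magnitude become comparable.

```python
def normalized_loss(p50_loss, mean_usage):
    """Divide each consumer's P50 loss by that consumer's average usage."""
    return {cid: p50_loss[cid] / mean_usage[cid] for cid in p50_loss}


# Made-up values for three hypothetical consumers.
p50_loss = {"A": 4.0, "B": 30.0, "C": 45.0}
mean_usage = {"A": 20.0, "B": 300.0, "C": 300.0}

relative = normalized_loss(p50_loss, mean_usage)
worst = max(relative, key=relative.get)
print(relative)
print(worst)  # A
```

In this toy case the consumer with the smallest absolute loss turns out to have the largest relative loss, which is why the raw P50 values alone can be misleading.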

The gray bars denote the distribution of each variable. One thing I always do is find which values have a low frequency. Then, I check how the model performs in those areas. Hence, you can easily detect if your model captures the behavior of rare events.
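That check can be sketched as follows, using made-up data and hypothetical helper names. The code groups absolute errors by feature value, records how often each value occurs, and then reports the error for the values that appear only rarely.

```python
from collections import defaultdict


def error_by_value(feature_values, actual, predicted):
    """Group absolute errors by feature value: count and mean absolute error."""
    buckets = defaultdict(list)
    for value, a, p in zip(feature_values, actual, predicted):
        buckets[value].append(abs(a - p))
    return {value: {"count": len(errors), "mean_abs_error": sum(errors) / len(errors)}
            for value, errors in buckets.items()}


def rare_values(summary, max_count):
    """Feature values that occur at most `max_count` times."""
    return sorted(value for value, stats in summary.items()
                  if stats["count"] <= max_count)


# Made-up data: hour 12 appears three times, hour 3 only once.
hours = [12, 12, 12, 3]
actual = [10.0, 11.0, 12.0, 5.0]
predicted = [10.5, 10.5, 12.5, 9.0]

summary = error_by_value(hours, actual, predicted)
for hour in rare_values(summary, max_count=1):
    print(hour, summary[hour])
```

Here the rare value (hour 3) has a much larger error than the frequent one, which is the kind of weakness this analysis is meant to surface.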

In general, you can use this TFT feature to probe your model for weaknesses and proceed to further investigation.

We can seamlessly use *Temporal Fusion Transformer* with **Optuna** to perform hyperparameter tuning:

The problem is that since TFT is a Transformer-based model, you will need significant hardware resources!

*Temporal Fusion Transformer* is undoubtedly a milestone for the Time-Series community.

Not only does the model achieve SOTA results, but it also provides a framework for the interpretability of predictions. The model is also available in the Darts Python library, whose TFT implementation is adapted from the PyTorch Forecasting library.

Finally, if you are curious to learn about the architecture of the *Temporal Fusion Transformer* in detail, check the companion article on the original paper.

Temporal Fusion Transformer: Time Series Forecasting with Deep Learning — Complete Tutorial Republished from Source https://towardsdatascience.com/temporal-fusion-transformer-time-series-forecasting-with-deep-learning-complete-tutorial-d32c1e51cd91?source=rss—-7f60cf5620c9—4 via https://towardsdatascience.com/feed
