Create accurate and interpretable predictions
According to , Temporal Fusion Transformer outperforms all prominent Deep Learning models for time series forecasting, including a featured Gradient Boosting Tree model for tabular time-series data.
But what is Temporal Fusion Transformer (TFT) and why is it so interesting?
In this article, we briefly explain the novelties of Temporal Fusion Transformer and build an end-to-end project on Energy Demand Forecasting. Specifically, we will cover:
- How to prepare our data for the TFT format.
- How to build, train, and evaluate the TFT model.
- How to get predictions on validation data and out-of-sample predictions.
- How to calculate feature importances, seasonality patterns, and robustness to extreme events using the model’s built-in interpretable attention mechanism.
Let’s dive in!
For an in-depth analysis of the Temporal Fusion Transformer architecture, check my previous article.
Temporal Fusion Transformer (TFT) is a Transformer-based model that leverages self-attention to capture the complex temporal dynamics of multiple time sequences.
- Multiple time series: We can train a TFT model on thousands of univariate or multivariate time series.
- Multi-Horizon Forecasting: The model outputs multi-step predictions of one or more target variables — including prediction intervals.
- Heterogeneous features: TFT supports many types of features, including time-varying and static exogenous variables.
- Interpretable predictions: Predictions can be interpreted in terms of variable importance and seasonality.
One of those traits is unique to Temporal Fusion Transformer. We will cover this in the next section.
Among notable DL time-series models (e.g., DeepAR), TFT stands out because it supports various types of features. These are:
- Time-varying known
- Time-varying unknown
- Time-invariant real
- Time-invariant categorical
For example, imagine we have a sales forecasting case:
Let’s say we have to predict the sales of 3 products. The num sales is the target variable. The CPI index or the number of visitors are time-varying unknown features because they are only known up to prediction time. However, special days are time-varying known events. The product id is a time-invariant (static) categorical feature. Other features which are numerical and not time-dependent, such as yearly_revenue, can be categorized as time-invariant real.
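To make the four categories concrete, here is a small sketch that maps the columns of this sales example onto the keyword arguments that PyTorch Forecasting’s TimeSeriesDataSet uses for each feature type (static_categoricals, static_reals, time_varying_known_categoricals, time_varying_unknown_reals). The snake_case column names are my own illustrative choices, and treating special days as a known categorical is an assumption; the exact split depends on how you encode your data.

```python
# Hypothetical column names for the sales example; each maps to the
# TimeSeriesDataSet argument that would hold it (or to the target).
FEATURE_ROLES = {
    "num_sales": "target",
    "cpi_index": "time_varying_unknown_reals",
    "num_visitors": "time_varying_unknown_reals",
    "special_days": "time_varying_known_categoricals",
    "product_id": "static_categoricals",
    "yearly_revenue": "static_reals",
}


def group_by_role(roles):
    """Invert the mapping: argument name -> list of column names."""
    grouped = {}
    for column, role in roles.items():
        grouped.setdefault(role, []).append(column)
    return grouped


dataset_kwargs = group_by_role(FEATURE_ROLES)
target = dataset_kwargs.pop("target")[0]
print(target, dataset_kwargs)
```

The resulting dictionary could then be unpacked into TimeSeriesDataSet(..., target=target, **dataset_kwargs), alongside the required time_idx and group_ids arguments.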
Before moving to our project, we will first show a mini-tutorial on how to convert your data to the extended time-series format.
Note: All images and figures in this article are created by the author.
For this tutorial, we use the TemporalFusionTransformer model from the PyTorch Forecasting library and PyTorch Lightning:
pip install torch pytorch-lightning pytorch_forecasting
The whole process involves 3 things:
- Create a pandas dataframe with our time-series data.
- Wrap our dataframe into a TimeSeriesDataSet instance.
- Pass our TimeSeriesDataSet instance to TemporalFusionTransformer.
The TimeSeriesDataSet is very useful because it helps us specify whether features are time-varying or static. Plus, it’s the only format that TemporalFusionTransformer accepts.
Let’s create a minimal training dataset to show how TimeSeriesDataSet works:
We should format our data in the following way: each colored box represents a different time series, represented by its group value.
The most important column of our dataframe is time_idx: it determines the sequence of samples. If there are no missing observations, the values should increase by +1 for each time series.
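A minimal sketch of that layout, in plain Python rather than pandas so the structure is explicit: three hypothetical series (group 0, 1, 2) with six steps each, stacked vertically, plus a helper that checks the "+1 per step" rule within every group. The values are made up for illustration; if your real data has gaps, TimeSeriesDataSet offers an allow_missing_timesteps option instead.

```python
# Hypothetical toy data in long format: one row per (group, time step).
rows = [
    {"group": g, "time_idx": t, "target": 20 * g + t}
    for g in range(3)
    for t in range(6)
]


def time_idx_is_contiguous(rows):
    """Return True if time_idx increases by exactly 1 within every group."""
    by_group = {}
    for row in rows:
        by_group.setdefault(row["group"], []).append(row["time_idx"])
    for indices in by_group.values():
        indices.sort()
        if any(b - a != 1 for a, b in zip(indices, indices[1:])):
            return False
    return True


print(time_idx_is_contiguous(rows))  # True
gappy = [r for r in rows if r["time_idx"] != 3]
print(time_idx_is_contiguous(gappy))  # False: step 3 is missing in every group
```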
Next, we wrap our dataframe into a TimeSeriesDataSet instance:
All arguments are self-explanatory: max_encoder_length defines the lookback period, and max_prediction_length specifies how many datapoints will be predicted. In our case, we look back 3 time steps in the past to output 2 predictions.
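Conceptually, each training sample is a window of max_encoder_length past values followed by max_prediction_length values to predict. The sketch below enumerates those windows for a single toy series with the 3-in, 2-out setting above. It is a simplification: the real TimeSeriesDataSet can also sample shorter encoders and decoders (via min_encoder_length and min_prediction_length), so its exact set of samples may differ.

```python
def make_windows(series, max_encoder_length, max_prediction_length):
    """Slide over one series and return (encoder, decoder) pairs."""
    window = max_encoder_length + max_prediction_length
    return [
        (
            series[start:start + max_encoder_length],          # lookback (encoder)
            series[start + max_encoder_length:start + window],  # targets (decoder)
        )
        for start in range(len(series) - window + 1)
    ]


windows = make_windows([0, 1, 2, 3, 4, 5], max_encoder_length=3, max_prediction_length=2)
for encoder, decoder in windows:
    print(encoder, "->", decoder)
# [0, 1, 2] -> [3, 4]
# [1, 2, 3] -> [4, 5]
```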
The TimeSeriesDataSet instance can now produce a dataloader through its to_dataloader() method. Let’s print a batch and check how our data will be passed to TFT:
This batch contains the training values [0,1] from the first time series (group 0) and the testing values [2,3,4]. If you rerun this code, you will get different values because the data are shuffled by default.
Our project will use the ElectricityLoadDiagrams20112014 dataset from UCI. The notebook for this example can be downloaded from here:
This dataset contains the power usage (in kW) of 370 clients/consumers at a 15-minute frequency. The data span 4 years (2011–2014).
Some consumers were created after 2011, so their initial power usage is zero.
We do data preprocessing according to :
- Aggregate our target variable
- Find the earliest date for every time-series where power is non-zero.
- Create new features:
- Select all days between
Each column represents a consumer. Most initial power_usage values are 0.
Next, we aggregate to hourly data. Due to the model’s size and complexity, we train our model on 5 consumers only (for those with non-zero values).
Now, we prepare our dataset for the TimeSeriesDataset format. Notice that each column represents a different time-series. Hence, we ‘melt’ our dataframe, so that all time-series are stacked vertically instead of horizontally. In the process, we create our new features.
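The melt step can be sketched in plain Python as follows (in pandas, DataFrame.melt does the stacking). The wide table maps a consumer column to its hourly values; each output row gets the consumer_id, the power_usage target, and two derived features. Two assumptions here: the hourly values are hypothetical, and hours_from_start is counted from 2011-01-01 00:00. That origin is my inference, but it is consistent with the value 26304 quoted later for the first hour of 2014.

```python
from datetime import datetime, timedelta

ORIGIN = datetime(2011, 1, 1)  # assumed origin of hours_from_start


def melt_hourly(wide, start):
    """Stack {consumer_id: [hourly values]} vertically and add time features."""
    rows = []
    for consumer_id, values in wide.items():
        for offset, power in enumerate(values):
            ts = start + timedelta(hours=offset)
            rows.append({
                "consumer_id": consumer_id,
                "power_usage": power,
                "hours_from_start": int((ts - ORIGIN).total_seconds()) // 3600,
                "day_of_week": ts.weekday(),  # Monday=0 ... Sunday=6
            })
    return rows


# Hypothetical values for two of the UCI consumer columns.
wide = {"MT_001": [1.0, 2.0, 3.0], "MT_002": [10.0, 20.0, 30.0]}
time_rows = melt_hourly(wide, start=datetime(2014, 1, 1))
print(time_rows[0])
```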
The final preprocessed dataframe is called time_df. Let’s print its contents:
time_df is now in the proper format for TimeSeriesDataSet. As you may have guessed by now, since the granularity is hourly, the hours_from_start variable will be the time index.
Exploratory Data Analysis
The choice of 5 consumers/time series is not random. The power usage of each time series has different properties, such as the mean value:
Let’s plot the first month of every time-series:
There is no noticeable trend, but each time-series has slightly different seasonality and amplitude. We can further experiment and check stationarity, signal decompositions, and so on, but in our case, we focus on the model-building aspect only.
Also, notice that other time-series forecasting methods like ARIMA must satisfy a few requirements (for instance, the time series must first be made stationary). With TFT, we can leave our data as-is.
In this step, we pass our time_df to the TimeSeriesDataSet format, which is immensely useful because:
- It spares us from writing our own Dataloader.
- We can specify how TFT will handle the dataset’s features.
- We can normalize our dataset with ease. In our case, normalization is mandatory because all time sequences differ in magnitude. Thus, we use the GroupNormalizer to normalize each time-series individually.
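To see why per-series normalization matters, here is a simplified stand-in for the idea behind GroupNormalizer: each series is standardized with its own mean and standard deviation, so series of very different magnitude end up on a comparable scale, and the stored parameters let us map predictions back. This is not GroupNormalizer’s exact transform; the library supports several methods and transformations (e.g., softplus), so treat this purely as an illustration.

```python
from statistics import mean, pstdev


def normalize_per_group(values_by_group):
    """Standardize each group with its own mean and (population) std."""
    scaled, params = {}, {}
    for group, values in values_by_group.items():
        mu = mean(values)
        sigma = pstdev(values) or 1.0  # avoid dividing by zero for a constant series
        params[group] = (mu, sigma)
        scaled[group] = [(v - mu) / sigma for v in values]
    return scaled, params


def denormalize(group, scaled_values, params):
    """Map scaled values (e.g., predictions) back to the group's original units."""
    mu, sigma = params[group]
    return [v * sigma + mu for v in scaled_values]


# Two hypothetical consumers whose usage differs by a factor of 100.
usage = {"small": [1.0, 2.0, 3.0], "large": [100.0, 200.0, 300.0]}
scaled, params = normalize_per_group(usage)
print(scaled["small"], scaled["large"])  # both roughly [-1.22, 0.0, 1.22]
```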
Our model uses a lookback window of one week (7*24 = 168 hours) to predict the power usage of the next 24 hours.
Also, notice that hours_from_start is both the time index and a time-varying feature. power_usage is our target variable. For the sake of demonstration, our validation set is the last day:
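The split itself is just a cutoff on the time index. A minimal sketch, assuming rows shaped like time_df (one dict per hour and consumer, with an hours_from_start key); the constants mirror the one-week lookback and 24-hour horizon described above. In PyTorch Forecasting, the validation set is then typically built with TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True), which reuses the training dataset’s settings.

```python
MAX_ENCODER_LENGTH = 7 * 24  # one-week lookback: 168 hours
MAX_PREDICTION_LENGTH = 24   # forecast horizon: the next 24 hours


def split_last_day(rows, time_col="hours_from_start"):
    """Keep everything up to the cutoff for training; the final day is held out."""
    training_cutoff = max(r[time_col] for r in rows) - MAX_PREDICTION_LENGTH
    train = [r for r in rows if r[time_col] <= training_cutoff]
    return train, training_cutoff


# Hypothetical single-consumer history of 200 hourly rows.
rows = [{"hours_from_start": 26304 + h, "power_usage": 1.0} for h in range(200)]
train, cutoff = split_last_day(rows)
print(cutoff, len(train))
```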
Next, the step that almost everyone forgets: A baseline model. Especially in time-series forecasting, you will be surprised at how often a naive predictor outperforms even a fancier model!
As a naive baseline, we predict the power usage curve of the previous day:
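A minimal pure-Python sketch of that previous-day (seasonal naive, lag-24) idea, together with the mean absolute error we would compare against TFT. The helper names and data are my own, and this is not necessarily how the original notebook implements its baseline.

```python
def previous_day_baseline(history, horizon=24, season=24):
    """Forecast the next `horizon` hours by repeating the last `season` hours."""
    last_season = history[-season:]
    return [last_season[i % season] for i in range(horizon)]


def mean_absolute_error(actual, predicted):
    """Average absolute difference between two equal-length sequences."""
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)


history = [float(h % 24) for h in range(7 * 24)]   # hypothetical daily pattern
actual_next_day = [v + 0.5 for v in history[-24:]]  # hypothetical "truth"
forecast = previous_day_baseline(history)
print(mean_absolute_error(actual_next_day, forecast))  # 0.5
```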
Training the Temporal Fusion Transformer Model
We can train our TFT model using the familiar Trainer interface from PyTorch Lightning.
Notice the following things:
- We use the EarlyStopping callback to monitor the validation loss.
- We use TensorBoard to log our training and validation metrics.
- Our model uses Quantile Loss, a special type of loss that helps us output prediction intervals. For more on the Quantile Loss function, check this article.
- We use 4 attention heads, as in the original paper.
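Quantile (pinball) loss penalizes under- and over-prediction asymmetrically, which is what pushes each output towards its quantile. Below is a plain-Python sketch for a single observation. I am assuming the seven quantiles are PyTorch Forecasting’s QuantileLoss defaults (0.02 to 0.98); the library averages over batches and horizons internally, so its exact reduction may differ from this sketch.

```python
# Assumed default quantiles (seven of them, median at 0.5).
QUANTILES = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]


def pinball_loss(y_true, y_pred, q):
    """q * error when under-predicting, (q - 1) * error when over-predicting."""
    error = y_true - y_pred
    return max(q * error, (q - 1) * error)


def quantile_loss(y_true, preds_by_quantile, quantiles=QUANTILES):
    """Average pinball loss over all quantile predictions for one observation."""
    losses = [pinball_loss(y_true, p, q) for p, q in zip(preds_by_quantile, quantiles)]
    return sum(losses) / len(losses)


# For q=0.9, under-predicting by 2 costs 1.8; over-predicting by 2 costs only about 0.2.
print(pinball_loss(10.0, 8.0, 0.9), pinball_loss(10.0, 12.0, 0.9))
```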
We are now ready to build and train our model:
That’s it! After 6 epochs, EarlyStopping kicks in and halts training.
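What EarlyStopping does is roughly this: track the best validation loss seen so far and stop once it has failed to improve by at least min_delta for patience consecutive epochs. The sketch below mirrors that logic in plain Python with a hypothetical loss curve; Lightning’s callback has more options (e.g., mode and check frequency), and the patience and min_delta values here are illustrative, not the ones used in this project.

```python
def stopping_epoch(val_losses, patience=2, min_delta=1e-4):
    """Return the 1-based epoch at which training would stop (or the last epoch)."""
    best, wait = float("inf"), 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:  # meaningful improvement: reset the counter
            best, wait = loss, 0
        else:                         # no improvement this epoch
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses)


print(stopping_epoch([1.0, 0.8, 0.7, 0.71, 0.72, 0.5]))  # 5
```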
Load and Save the Best Model
Don’t forget to save your model. Although we can pickle it, the safest option is to save the best epoch directly:
!zip -r model.zip lightning_logs/lightning_logs/version_1/*
To load the model again, unzip model.zip and execute the following; just remember the best model path:
Take a closer look at the training and validation curves with TensorBoard:
Get predictions on the validation set and calculate the average P50 (quantile median) loss:
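I read the "P50 loss" per series as the mean absolute error of the median (0.5-quantile) forecast. A minimal sketch with hypothetical numbers and series names; in the notebook, the actuals and median predictions would come from the model’s predict() output rather than hand-written lists.

```python
def p50_mae_per_series(actuals, p50_preds):
    """Mean absolute error of the median forecast, computed separately per series."""
    return {
        series: sum(abs(a - p) for a, p in zip(actuals[series], p50_preds[series]))
        / len(actuals[series])
        for series in actuals
    }


# Hypothetical 3-hour horizons for two consumers.
actuals = {"consumer_a": [1.0, 2.0, 3.0], "consumer_b": [10.0, 20.0, 30.0]}
p50_preds = {"consumer_a": [1.0, 2.0, 4.0], "consumer_b": [12.0, 20.0, 28.0]}
print(p50_mae_per_series(actuals, p50_preds))
```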
The last 2 time series have a slightly higher loss because their relative magnitude is also higher.
Plot Predictions on Validation Data
If we pass mode="raw" to the predict() method, we get more information, including predictions for all seven quantiles. We also have access to the attention values (more about that later).
Take a closer look at the
We use plot_prediction() to create our plots. Of course, you could make your own custom plot, but plot_prediction() has the extra benefit of adding the attention values.
Note: Our model predicts the next 24 datapoints in one go. This is not a rolling forecasting scenario where a model predicts a single value each time and ‘stitches’ all predictions together.
We create one plot for each consumer (5 in total).
The results are quite impressive.
Our Temporal Fusion Transformer model was able to capture the behaviour of all 5 time-series, in terms of both seasonality and magnitude!
Also, notice that:
- We did not perform any hyperparameter tuning.
- We did not implement any fancy feature engineering technique.
In a subsequent section, we show how to improve our model with hyperparameter optimization.
Plot Predictions For A Specific Time Series
Previously, we plotted predictions on the validation data using the idx argument, which iterates over all time series in our dataset. We can be more specific and output predictions on a specific time series:
In Figure 7, we plot the day-ahead forecast of consumer MT_004 for time index 26512.
Remember, our time-indexing column hours_from_start starts from 26304, and we can get predictions from 26388 onwards, because we earlier set min_encoder_length=max_encoder_length // 2, which equals 84, and 26304 + 168//2 = 26388.
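The same arithmetic as a tiny sketch, using the values from the text (first index 26304, a 168-hour max_encoder_length, and min_encoder_length = max_encoder_length // 2). The helper names are illustrative; the point is that any index at or after 26388 has at least 84 hours of history, and from 26472 onwards the full week is available.

```python
FIRST_TIME_IDX = 26304                        # first value of hours_from_start
MAX_ENCODER_LENGTH = 7 * 24                   # 168 hours
MIN_ENCODER_LENGTH = MAX_ENCODER_LENGTH // 2  # 84 hours


def earliest_prediction_idx():
    """First time index with at least MIN_ENCODER_LENGTH hours of history."""
    return FIRST_TIME_IDX + MIN_ENCODER_LENGTH


def encoder_length_at(idx):
    """How many past hours the encoder can use when predicting from `idx`."""
    return min(idx - FIRST_TIME_IDX, MAX_ENCODER_LENGTH)


print(earliest_prediction_idx())  # 26388
print(encoder_length_at(26512))   # 168: a full week of history is available
```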
Let’s create out-of-sample predictions, beyond the final datapoint of validation data — which is
All we have to do is create a new dataframe that contains:
- max_encoder_length past dates, which act as the lookback window (the encoder data in TFT terminology).
- max_prediction_length future dates for which we want to compute our predictions (the decoder data).
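A minimal sketch of that new dataframe for one consumer, using plain dicts shaped like hypothetical time_df rows (consumer_id, power_usage, hours_from_start, day_of_week) plus a date field I add for computing calendar features. The encoder part is simply the last max_encoder_length known rows; the decoder rows extend the time index and the known calendar features. Filling the unknown future power_usage with the last observed value is only a placeholder because the column still has to exist; since power_usage is an unknown real, my understanding is that the model only consumes it on the encoder side.

```python
from datetime import datetime, timedelta


def build_prediction_frame(history, max_encoder_length, max_prediction_length):
    """Encoder rows = last known rows; decoder rows = future steps to forecast."""
    encoder = [dict(r) for r in history[-max_encoder_length:]]
    last = encoder[-1]
    decoder = []
    for step in range(1, max_prediction_length + 1):
        ts = last["date"] + timedelta(hours=step)
        decoder.append({
            "consumer_id": last["consumer_id"],
            "date": ts,
            "hours_from_start": last["hours_from_start"] + step,  # known in advance
            "day_of_week": ts.weekday(),                           # known in advance
            "power_usage": last["power_usage"],                    # placeholder only
        })
    return encoder + decoder


# Hypothetical 200-hour history for one consumer starting at 2014-01-01 00:00.
start = datetime(2014, 1, 1)
history = [
    {"consumer_id": "MT_002", "date": start + timedelta(hours=h),
     "hours_from_start": 26304 + h,
     "day_of_week": (start + timedelta(hours=h)).weekday(),
     "power_usage": float(h % 24)}
    for h in range(200)
]
frame = build_prediction_frame(history, max_encoder_length=168, max_prediction_length=24)
print(len(frame), frame[168]["hours_from_start"])
```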
We can create predictions for all 5 of our time-series, or just one. Figure 7 shows the out-of-sample predictions for consumer MT_002:
Accurate forecasting is one thing, but explainability also matters a lot nowadays.
And it’s even worse for Deep Learning models, which are considered black boxes. Methods such as LIME and SHAP can provide explainability (to some extent) but don’t work well for time-series. Plus, they are external post-hoc methods and are not tied to a particular model.
Temporal Fusion Transformer provides three types of interpretability:
- Seasonality-wise: TFT leverages its novel Interpretable Multi-Head Attention mechanism to calculate the importance of past time steps.
- Feature-wise: TFT leverages its Variable Selection Network module to calculate the importance of every feature.
- Extreme events robustness: We can investigate how time series behave during rare events.
If you want to learn in-depth about the inner workings of Interpretable Multi-Head Attention and Variable Selection Network, check my previous article.
We can explore TFT’s attention weights to understand the temporal patterns across past time steps.
The gray lines in all previous plots represent the attention scores. Look at those plots again — do you notice anything? Figure 8 shows the findings of Figure 7 and also accounts for the attention scores:
The attention scores reveal how impactful those time steps are when the model outputs its prediction. The small peaks reflect the daily seasonality, while the higher peak towards the end probably implies the weekly seasonality.
If we average the attention curves across all time steps and time series (not just the 5 we used in this tutorial), we get the symmetric-looking shape in Figure 9 from the TFT paper:
Question: What good is this? Can’t we simply estimate seasonality patterns with methods such as ACF plots, time-signal decomposition, etc.?
Answer: True. However, studying the attention weights of TFT has extra advantages:
- We can confirm our model captures the apparent seasonal dynamics of our sequences.
- Our model may also reveal hidden patterns because the attention weights of the current input windows consider all past inputs.
- The attention weights plot is not the same as an autocorrelation plot: The autocorrelation plot refers to a particular sequence, while the attention weights here focus on the impact of each timestep by looking across all covariates and time series.
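For contrast, here is the single-series statistic the answer refers to: a plain-Python autocorrelation function, applied to a synthetic hourly sine wave with a 24-hour cycle (not the electricity data). Searching lags 12 to 48 recovers the daily period; what it cannot do, unlike TFT’s attention weights, is account for other covariates or other series.

```python
import math


def autocorrelation(series, lag):
    """Sample autocorrelation of `series` at a positive `lag`."""
    n = len(series)
    mu = sum(series) / n
    var = sum((x - mu) ** 2 for x in series)
    cov = sum((series[t] - mu) * (series[t - lag] - mu) for t in range(lag, n))
    return cov / var


# Synthetic two-week hourly signal with a 24-hour cycle.
series = [math.sin(2 * math.pi * t / 24) for t in range(24 * 14)]
acf = {lag: autocorrelation(series, lag) for lag in range(12, 49)}
best_lag = max(acf, key=acf.get)
print(best_lag, round(acf[best_lag], 3))
```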
The Variable Selection Network component of TFT can easily estimate the feature importances:
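Under the hood, the importance scores are aggregates of the per-sample weights that the Variable Selection Network assigns to each input. In PyTorch Forecasting, this is exposed through the model’s interpret_output() method (and plotted with plot_interpretation()); the sketch below only illustrates the aggregation idea with made-up weights: sum the weights per feature across samples and rescale them to percentages.

```python
def aggregate_importance(weights_per_sample):
    """Sum each feature's selection weight over samples and rescale to percentages."""
    totals = {}
    for weights in weights_per_sample:
        for name, w in weights.items():
            totals[name] = totals.get(name, 0.0) + w
    grand_total = sum(totals.values())
    ranked = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
    return {name: 100.0 * w / grand_total for name, w in ranked}


# Hypothetical selection weights for two samples (each row sums to 1).
samples = [
    {"power_usage": 0.6, "day_of_week": 0.3, "consumer_id": 0.1},
    {"power_usage": 0.4, "day_of_week": 0.5, "consumer_id": 0.1},
]
print(aggregate_importance(samples))  # roughly 50 / 40 / 10 percent
```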
In Figure 10, we notice the following:
- day_of_week has strong scores, both as a past observation and as a future covariate. The benchmark in the original paper shares the same conclusion.
- power_usage is obviously the most impactful observed covariate.
- consumer_id is not very significant here because we use only 5 consumers. In the TFT paper, where the authors use all 370 consumers, this variable is more significant.
Note: If your grouping static variable is not important, it is very likely your dataset can also be modeled equally well by a single distribution model (like ARIMA).
Extreme Event Detection
Time series are notorious for being susceptible to sudden changes in their properties during rare events (also referred to as shocks).
Even worse, those events are very elusive. Imagine if your target variable becomes volatile for a brief period because a covariate silently changes behavior:
Is this some random noise or a hidden persistent pattern that escapes our model?
With TFT, we can analyze the robustness of each individual feature across their range of values. Unfortunately, the current dataset does not exhibit volatility or rare events — those are more likely to be found in financial, sales data and so on. Still, we will show how to calculate them:
Some features do not have all their values present in the validation dataset, so we only show the
In both figures, the results are encouraging. In Figure 12, we notice that consumer MT_004 slightly underperforms compared to other consumers. We could verify this by normalizing the P50 loss of every consumer by the average power usage we calculated previously.
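That check is a one-liner per consumer: divide each consumer’s P50 loss by its average power usage to get a scale-free error. A sketch with hypothetical numbers (the helper name is my own):

```python
def relative_loss(loss_by_series, usage_by_series):
    """Scale each series' P50 loss by that series' mean power usage."""
    return {
        s: loss_by_series[s] / (sum(usage_by_series[s]) / len(usage_by_series[s]))
        for s in loss_by_series
    }


# Hypothetical: B has a 10x larger raw loss but is 100x larger in magnitude.
losses = {"A": 2.0, "B": 20.0}
usage = {"A": [1.0, 3.0], "B": [100.0, 300.0]}
print(relative_loss(losses, usage))  # {'A': 1.0, 'B': 0.1}
```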
The gray bars denote the distribution of each variable. One thing I always do is find which values have a low frequency. Then, I check how the model performs in those areas. Hence, you can easily detect if your model captures the behavior of rare events.
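One way to automate that first step: bin a feature’s values, as the gray histogram bars do, and flag non-empty bins that hold only a small share of the samples; the model’s error can then be inspected on exactly those rows. A hypothetical helper with an illustrative bin count and threshold:

```python
def rare_bins(values, n_bins=10, threshold=0.05):
    """Return (low, high) edges of non-empty bins holding < threshold of samples."""
    lo, hi = min(values), max(values)
    width = (hi - lo) / n_bins or 1.0  # guard against a constant feature
    counts = [0] * n_bins
    for v in values:
        idx = min(int((v - lo) / width), n_bins - 1)  # max value goes in the last bin
        counts[idx] += 1
    return [
        (lo + i * width, lo + (i + 1) * width)
        for i, c in enumerate(counts)
        if 0 < c / len(values) < threshold
    ]


# Hypothetical feature: two common values and a couple of rare large ones.
values = [0.0] * 50 + [5.0] * 48 + [9.5, 10.0]
print(rare_bins(values))  # [(9.0, 10.0)]
```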
In general, you can use this TFT feature to probe your model for weaknesses and proceed to further investigation.
We can seamlessly use Temporal Fusion Transformer with Optuna to perform hyperparameter tuning:
The problem is that since TFT is a Transformer-based model, you will need significant hardware resources!
Temporal Fusion Transformer is undoubtedly a milestone for the Time-Series community.
Not only does the model achieve SOTA results, but it also provides a framework for the interpretability of predictions. The model is also available in the Darts Python library, which is based on the PyTorch Forecasting library.
Finally, if you are curious to learn about the architecture of the Temporal Fusion Transformer in detail, check the companion article on the original paper.
Temporal Fusion Transformer: Time Series Forecasting with Deep Learning — Complete Tutorial Republished from Source https://towardsdatascience.com/temporal-fusion-transformer-time-series-forecasting-with-deep-learning-complete-tutorial-d32c1e51cd91?source=rss—-7f60cf5620c9—4 via https://towardsdatascience.com/feed