Part 5 of 8 · Tags: trading, tft, time-series

Temporal Fusion Transformer: Gating, Memory, and Sector Learnability

The architecture that finally worked — and a precise explanation of why each component addresses a specific failure mode of the Standard Transformer on financial data.

June 2, 2026·7 min read

Why TFT, and What Makes It Different

The Temporal Fusion Transformer (TFT) was introduced by Google researchers Bryan Lim, Sercan Arık, Nicolas Loeff, and Tomas Pfister in 2021 specifically for multi-horizon forecasting of complex real-world time series. It was designed with a clear acknowledgment that financial and operational time series have properties standard transformers handle poorly:

  • Features have context-dependent importance (not all features matter equally in all regimes)
  • Sequences have persistent state (today's conditions partially come from yesterday's)
  • Deep networks on noisy data suffer from gradient instability
  • Predictions are needed at multiple future horizons simultaneously

TFT addresses each with a specific architectural component. Let's go through them one by one.


Component 1: Gated Variable Selection (Feature Gating)

This is the most important TFT innovation for financial applications.

Feature Gating
A learned mechanism that produces a weight between 0 and 1 for each input feature at each time step. High gate weight = use this feature. Low gate weight = suppress this feature. The weights are learned from data and can vary by context — the model learns not just what features matter, but when they matter.

Mathematically, for each feature i at time step t:

gate_i(t) = sigmoid(W_gate · context(t) + b)

weighted_feature_i(t) = gate_i(t) × feature_i(t)

The context vector is a function of all features at time t — so the gate learns "given this market state, how much should I trust this feature?"
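As a rough numeric illustration of the two formulas above, the sketch below computes one feature's gate under two market contexts. The context vectors, weights, and feature value are invented for illustration and are not learned values from the experiment.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def feature_gate(context: list[float], w_gate: list[float], b: float) -> float:
    """gate_i(t) = sigmoid(W_gate . context(t) + b) for a single feature i."""
    z = sum(w * c for w, c in zip(w_gate, context)) + b
    return sigmoid(z)


# Hypothetical standardized context at time t: [iv_hv_ratio, vol_ratio]
calm_context = [-1.0, -0.5]
stressed_context = [2.0, 1.5]

# Illustrative gate parameters for one feature (assumed, not learned)
w_gate = [1.2, 0.8]
b = 0.0

feature_value = 0.7
calm_gate = feature_gate(calm_context, w_gate, b)
stressed_gate = feature_gate(stressed_context, w_gate, b)

# weighted_feature_i(t) = gate_i(t) * feature_i(t)
calm_weighted = calm_gate * feature_value
stressed_weighted = stressed_gate * feature_value
print(f"calm gate={calm_gate:.3f}, stressed gate={stressed_gate:.3f}")
```

The same feature value is passed through almost unchanged in the stressed context and mostly suppressed in the calm one, which is the "when it matters" behaviour the gate is meant to learn.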

What the gates learned in our experiment

After training on NASDAQ-100 sector models, examining the learned gate weights revealed:

Feature            Average gate (Consumer)   Average gate (Semis)
iv_hv_ratio        0.78                      0.93
rsi_14             0.41                      0.12
sentiment_score    0.61                      0.34
vol_ratio          0.52                      0.88
sma_200_ratio      0.74                      0.45
put_call_ratio     0.38                      0.71

Interpretation: for semiconductor stocks (high gamma, speculative), the model heavily weights volatility-related features (iv_hv_ratio, vol_ratio, put_call_ratio) and nearly ignores RSI (0.12). For consumer stocks (stable, trend-following), the model weights trend (sma_200_ratio) and sentiment more heavily, while iv_hv_ratio stays high in both cohorts.

This matches how an experienced analyst would weight these signals, but the model learned it from data.
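One simple way to read the table is to rank features by the gap between the two sector averages. The sketch below does this with the gate values reported above; the helper names are illustrative.

```python
# Average gate weights from the table: (Consumer, Semis)
avg_gate = {
    "iv_hv_ratio": (0.78, 0.93),
    "rsi_14": (0.41, 0.12),
    "sentiment_score": (0.61, 0.34),
    "vol_ratio": (0.52, 0.88),
    "sma_200_ratio": (0.74, 0.45),
    "put_call_ratio": (0.38, 0.71),
}

# Positive gap: weighted more heavily in Semis. Negative gap: more in Consumer.
gap = {
    name: round(semis - consumer, 2)
    for name, (consumer, semis) in avg_gate.items()
}
ranked = sorted(gap.items(), key=lambda kv: kv[1], reverse=True)

semis_leaning = {name for name, g in gap.items() if g > 0}
consumer_leaning = {name for name, g in gap.items() if g < 0}

for name, g in ranked:
    print(f"{name:16s} {g:+.2f}")
```

The three features with a positive gap are the volatility-related ones, and the trend, sentiment, and RSI features all lean toward the consumer model.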

import torch
import torch.nn as nn


class GatedVariableSelection(nn.Module):
    """Simplified feature gating implementation."""

    def __init__(self, d_features: int, d_model: int):
        super().__init__()
        # Gate network: maps features -> per-feature gate weights in (0, 1)
        self.gate_net = nn.Sequential(
            nn.Linear(d_features, d_features),
            nn.Sigmoid(),
        )
        # Feature transformation: projects gated features to d_model
        self.feature_proj = nn.Linear(d_features, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_features)
        gates = self.gate_net(x)            # (batch, seq_len, d_features)
        gated_x = gates * x                 # element-wise gating
        return self.feature_proj(gated_x)   # (batch, seq_len, d_model)

Component 2: LSTM Temporal Encoder (Temporal Memory)

LSTM — Long Short-Term Memory
A recurrent neural network designed to retain information across long sequences. Unlike vanilla RNNs that suffer from vanishing gradients, LSTMs use gated mechanisms (input gate, forget gate, output gate) to selectively remember or forget information. The "memory cell" persists state across many time steps.

In TFT, the LSTM sits between feature gating and the attention layer. It processes the gated feature sequence and produces a hidden state that encodes the temporal trajectory of the market.

hidden_t, cell_t = LSTM(gated_features_t, hidden_{t-1}, cell_{t-1})
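To make the recurrence above concrete, here is a minimal scalar LSTM step written out by hand. It follows the standard LSTM equations; the parameter values are chosen for illustration (a high forget-gate bias so the cell retains state) and are not taken from the trained model.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x: float, h_prev: float, c_prev: float, p: dict) -> tuple[float, float]:
    """One scalar LSTM step; p maps gate name -> (w_x, w_h, bias)."""
    def pre(name: str) -> float:
        w_x, w_h, bias = p[name]
        return w_x * x + w_h * h_prev + bias

    i = sigmoid(pre("input"))        # how much new information to write
    f = sigmoid(pre("forget"))       # how much old cell state to keep
    o = sigmoid(pre("output"))       # how much of the cell to expose
    g = math.tanh(pre("candidate"))  # candidate value to write
    c = f * c_prev + i * g
    h = o * math.tanh(c)
    return h, c


# Illustrative parameters (assumed): forget bias of 3.0 keeps about 95% of the cell per step
params = {
    "input": (0.0, 0.0, 0.0),
    "forget": (0.0, 0.0, 3.0),
    "output": (0.0, 0.0, 0.0),
    "candidate": (1.0, 0.0, 0.0),
}

# Three days of elevated input followed by two quiet days
inputs = [1.0, 1.0, 1.0, 0.0, 0.0]
h, c = 0.0, 0.0
cells = []
for x in inputs:
    h, c = lstm_step(x, h, c, params)
    cells.append(c)

print([round(v, 3) for v in cells])
```

The cell state builds up over the three active days and decays only slowly once the input returns to zero, which is the persistent state that the hidden and cell vectors carry forward.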

What temporal memory captures that attention can't

Think of the LSTM as watching the market unfold like a movie, vs. the transformer's attention looking at isolated photographs. Financial patterns that require "memory":

  • Volatility clustering: "Volatility has been expanding for 10 days" — the LSTM carries this state
  • Momentum persistence: "This stock has been in an uptrend for 6 weeks" — accumulates in the hidden state
  • Regime continuity: "We're in a risk-off environment" — the LSTM remembers this even when individual day features are noisy
  • Earnings cycle position: "3 weeks until earnings" — timing effects accumulate in hidden state

class TemporalEncoder(nn.Module):
    """Bidirectional LSTM for temporal state encoding."""

    def __init__(self, d_model: int, hidden_size: int, n_layers: int = 2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=d_model,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            dropout=0.1,          # applied between stacked layers (n_layers > 1)
            bidirectional=True,   # bidirectional within the historical window only
        )
        # Project the concatenated forward/backward output back to d_model
        self.proj = nn.Linear(hidden_size * 2, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        output, _ = self.lstm(x)   # (batch, seq_len, hidden_size * 2), all time steps
        return self.proj(output)   # (batch, seq_len, d_model)

Bidirectional LSTM: the leakage concern
A bidirectional LSTM processes the sequence in both forward (past→future) and backward (future→past) directions. For forecasting, the backward pass must only operate within the historical context window, never across the train/validation boundary or past the prediction date. This holds when every input is a fixed-length window that ends on the prediction date (e.g., a 60-day window): the backward LSTM sees day 60 at most, not day 61+. Claude's audit confirmed this was correctly implemented.
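The windowing condition can be checked mechanically. The sketch below builds fixed-length windows over day indices and verifies that every label day lies strictly after its input window. The function name and the 100-day, 5-day-horizon setup are assumptions for illustration, not the author's data pipeline.

```python
def make_windows(n_days: int, window: int, horizon: int) -> list[tuple[list[int], int]]:
    """Return (input_day_indices, target_day_index) pairs with no overlap between them."""
    samples = []
    for end in range(window - 1, n_days - horizon):
        # Days visible to both LSTM directions: the window ends on day `end`
        inputs = list(range(end - window + 1, end + 1))
        # Label day, strictly after the last input day
        target = end + horizon
        samples.append((inputs, target))
    return samples


samples = make_windows(n_days=100, window=60, horizon=5)

# Leakage check: no input index may reach or pass its own target day
leak_free = all(max(inputs) < target for inputs, target in samples)
print(len(samples), leak_free)
```

Because the backward direction only ever runs over `inputs`, a check like this on the index sets is enough to rule out the backward pass touching future days.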

Component 3: Gated Residual Connections

This is a stability mechanism that addresses gradient degradation in deep networks trained on noisy data.

Standard Residual Connection

output = layer(x) + x — preserves the original signal by adding it back after transformation. Prevents information loss as depth increases. Used in ResNets, standard transformers.

Gated Residual Connection (TFT)
output = gate × transformed(x) + (1 − gate) × x — the model learns how much to "trust" the transformation vs. preserve the original. When the transformation would introduce noise, the gate can set itself to near-zero and pass the original signal through untouched.

gate = sigmoid(W_g · [x, transformed(x)])

output = gate · transformed(x) + (1 − gate) · x

class GatedResidual(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.transform = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ELU(),
            nn.Linear(d_model, d_model),
            nn.Dropout(dropout),
        )
        # Gate depends on both original and transformed input
        self.gate = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.Sigmoid(),
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        transformed = self.transform(x)
        gate_input = torch.cat([x, transformed], dim=-1)
        g = self.gate(gate_input)
        output = g * transformed + (1 - g) * x
        return self.norm(output)

Why this matters for financial data

Financial features are inherently noisy. On any given day, many features might be misleading (earnings whispers, random news spikes, thin liquidity). The gated residual lets the network "ignore" a bad transformation at a specific layer rather than propagating noise. This is why TFT models trained on market data tend to be more stable than standard transformers — the gating acts as a continuous regularization mechanism.
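The pass-through behaviour is easy to see with plain numbers. The sketch below applies the gated residual formula element-wise for three gate settings and compares it with a standard residual. The input and "noisy" transformation values are made up for illustration.

```python
def gated_residual(x: list[float], transformed: list[float], gate: list[float]) -> list[float]:
    """output = gate * transformed(x) + (1 - gate) * x, element-wise."""
    return [g * t + (1.0 - g) * xi for xi, t, g in zip(x, transformed, gate)]


def standard_residual(x: list[float], transformed: list[float]) -> list[float]:
    """output = layer(x) + x, element-wise."""
    return [t + xi for xi, t in zip(x, transformed)]


x = [0.5, -1.0, 2.0]
transformed = [3.0, 0.0, -2.0]  # hypothetical noisy transformation of x

closed = gated_residual(x, transformed, [0.0, 0.0, 0.0])  # gate shut: original passes through
opened = gated_residual(x, transformed, [1.0, 1.0, 1.0])  # gate open: transformation replaces x
half = gated_residual(x, transformed, [0.5, 0.5, 0.5])    # even blend of the two
plain = standard_residual(x, transformed)                  # always adds the transformation

print(closed, opened, half, plain)
```

With the gate shut the output equals the input exactly, whereas the standard residual has no way to drop the transformation's contribution.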


Component 4: Multi-Horizon Output Heads

Instead of predicting a single return horizon, TFT predicts multiple simultaneously: ret5 (1 week), ret10 (2 weeks), ret20 (1 month). This is valuable for several reasons:

  • Shared representations: The model can learn features useful for all horizons simultaneously, making the training signal denser
  • Horizon consistency checks: If a 5-day signal is strong but the 20-day signal is flat, that reveals information about the trade's expected duration
  • Options strategy mapping: Different horizons map to different options contract choices — this becomes critical in Part 8

class MultiHorizonHead(nn.Module):
    def __init__(self, d_model: int, horizons: list = [5, 10, 20]):
        super().__init__()
        # Separate head per horizon: each learns different temporal patterns
        self.horizon_heads = nn.ModuleDict({
            f"ret{h}": nn.Sequential(
                nn.Linear(d_model, d_model // 2),
                nn.ELU(),
                nn.Linear(d_model // 2, 1),
            )
            for h in horizons
        })

    def forward(self, x: torch.Tensor) -> dict:
        # x: (batch, d_model), the final temporal representation
        # Returns: {'ret5': tensor, 'ret10': tensor, 'ret20': tensor}, each of shape (batch,)
        return {
            name: head(x).squeeze(-1)
            for name, head in self.horizon_heads.items()
        }
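The heads above need one target series per horizon. The post does not spell out how ret5, ret10, and ret20 are computed, so the sketch below assumes simple forward returns over h rows of a daily price series. The function name and the synthetic prices are illustrative.

```python
def forward_returns(prices: list[float], horizons: tuple = (5, 10, 20)) -> dict:
    """Return {'ret5': [...], ...}; entry t is prices[t + h] / prices[t] - 1.

    Entries too close to the end of the series to have a future price are None.
    Assumes simple (not log) returns; the source does not specify which is used.
    """
    out = {}
    for h in horizons:
        out[f"ret{h}"] = [
            prices[t + h] / prices[t] - 1.0 if t + h < len(prices) else None
            for t in range(len(prices))
        ]
    return out


# Synthetic series growing 1% per day, for illustration only
prices = [100.0 * 1.01 ** t for t in range(30)]
targets = forward_returns(prices)

print(sorted(targets), round(targets["ret5"][0], 4))
```

The keys match the names produced by the per-horizon heads, so a training loop could pair each prediction with its target by name. Rows with a None target would need to be dropped or masked before computing a loss.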


The Full TFT Architecture

class TemporalFusionTransformer(nn.Module):
    def __init__(
        self,
        d_features: int = 30,
        d_model: int = 128,
        lstm_hidden: int = 128,
        n_lstm_layers: int = 2,
        n_attn_heads: int = 4,
        n_attn_layers: int = 2,
        horizons: list = [5, 10, 20],
        dropout: float = 0.1,
    ):
        super().__init__()

        # Stage 1: Variable selection gating
        self.var_selection = GatedVariableSelection(d_features, d_model)

        # Stage 2: LSTM temporal encoder
        self.temporal_encoder = TemporalEncoder(d_model, lstm_hidden, n_lstm_layers)

        # Stage 3: Gated residual after LSTM
        self.lstm_gate = GatedResidual(d_model, dropout)

        # Stage 4: Multi-head self-attention (contextual attention)
        self.attention = nn.MultiheadAttention(
            d_model, n_attn_heads, dropout=dropout, batch_first=True
        )

        # Stage 5: Gated residual after attention
        self.attn_gate = GatedResidual(d_model, dropout)

        # Stage 6: Final gated residual for position-wise processing
        self.final_gate = GatedResidual(d_model, dropout)

        # Stage 7: Multi-horizon prediction heads
        self.output_heads = MultiHorizonHead(d_model, horizons)

    def forward(self, x: torch.Tensor) -> dict:
        # x: (batch, seq_len, d_features)

        # 1. Gate and select features
        x = self.var_selection(x)           # (batch, seq_len, d_model)

        # 2. LSTM temporal encoding
        lstm_out = self.temporal_encoder(x)
        x = self.lstm_gate(x + lstm_out)    # gated residual

        # 3. Self-attention over temporal representations
        attn_out, _ = self.attention(x, x, x)
        x = self.attn_gate(x + attn_out)    # gated residual

        # 4. Final processing
        x = self.final_gate(x)

        # 5. Use last time step for prediction
        final_repr = x[:, -1, :]            # (batch, d_model)

        return self.output_heads(final_repr) # dict of horizon predictions

Sector Decomposition: Why It Mattered

Training a single TFT on all 103 NASDAQ-100 tickers still underperformed. The breakthrough came from sector decomposition — training separate models for each sector group.

The learnability principle

Predictability in financial ML is not uniform across all stocks. Some cohorts have:

  • More stable statistical relationships between features and returns
  • Lower gamma / options reflexivity (less driven by dealer hedging flows)
  • Stronger institutional ownership patterns (more persistent, less noisy)
  • Cleaner temporal structure (trend and mean-reversion more consistent)

Temporal Structure / Cleaner Signal
A stock has "clean temporal structure" when its past patterns have a more stable relationship to its future returns — because its price dynamics are driven by gradual, persistent processes (institutional flows, fundamental improvement) rather than reflexive, nonlinear ones (gamma squeezes, meme flows, macro shock cascades).
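Operationally, sector decomposition amounts to partitioning the ticker universe and fitting one model per partition. The sketch below shows that bookkeeping only. The ticker-to-sector mapping is illustrative rather than the grouping used in the experiment, and `train_fn` is a hypothetical stand-in for whatever routine trains a TFT on a list of tickers.

```python
from collections import defaultdict
from typing import Callable


def group_by_sector(sector_of: dict[str, str]) -> dict[str, list[str]]:
    """Invert a ticker -> sector mapping into sector -> sorted list of tickers."""
    groups = defaultdict(list)
    for ticker, sector in sector_of.items():
        groups[sector].append(ticker)
    return {sector: sorted(tickers) for sector, tickers in groups.items()}


def train_per_sector(groups: dict[str, list[str]], train_fn: Callable) -> dict:
    """Fit one model per sector; train_fn is a placeholder for the real training routine."""
    return {sector: train_fn(tickers) for sector, tickers in groups.items()}


# Illustrative mapping (assumed, not the author's sector definitions)
sector_of = {"NVDA": "Semis", "AMD": "Semis", "COST": "Consumer", "SBUX": "Consumer"}
groups = group_by_sector(sector_of)

# Stand-in training function that just records which tickers it was given
models = train_per_sector(groups, train_fn=lambda tickers: f"model({','.join(tickers)})")
print(models)
```

Keeping the models in a dict keyed by sector makes it straightforward to compare per-sector metrics side by side.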

Sector results

Variance Ratio (TFT, Consumer): 0.106