A gentle and comprehensive introduction to DeltaNet
This blog post series accompanies our NeurIPS ‘24 paper - Parallelizing Linear Transformers with the Delta Rule over Sequence Length (w/ Bailin Wang, Yu Zhang, Yikang Shen and Yoon Kim). You can find the implementation here and the presentation slides here.
Notation: we use capital bold letters to represent matrices, lowercase bold letters to represent vectors, and regular lowercase letters to represent scalars.
The vanilla softmax attention mechanism, though powerful, suffers from quadratic complexity in sequence length. Let's see how linear attention addresses this issue, starting from standard causal softmax attention (assuming a single head):

$$\mathbf{o}_t = \sum_{i=1}^{t} \frac{\exp(\mathbf{q}_t^\top \mathbf{k}_i)}{\sum_{j=1}^{t} \exp(\mathbf{q}_t^\top \mathbf{k}_j)}\, \mathbf{v}_i$$

Here, $\mathbf{q}_t, \mathbf{k}_t, \mathbf{v}_t \in \mathbb{R}^{d}$ are the query, key, and value vectors at time step $t$, obtained by linear projections of the input, and $\mathbf{o}_t$ is the output.
What linear attention does is simply drop the softmax (and its normalizer):

$$\mathbf{o}_t = \sum_{i=1}^{t} (\mathbf{q}_t^\top \mathbf{k}_i)\, \mathbf{v}_i$$
While removing softmax alone doesn't immediately reduce computational complexity, it enables a crucial mathematical property: linearity. This property, particularly associativity, allows us to restructure the computation in ways that significantly improve efficiency. For training, researchers have developed chunkwise parallel techniques that strike a balance between the fully parallel and fully recurrent forms; we will return to these when we discuss parallelizing DeltaNet.
For inference, we can also rearrange the computation using associativity:

$$\mathbf{o}_t = \sum_{i=1}^{t} (\mathbf{q}_t^\top \mathbf{k}_i)\, \mathbf{v}_i = \Big(\sum_{i=1}^{t} \mathbf{v}_i \mathbf{k}_i^\top\Big)\, \mathbf{q}_t$$

Let's define a state matrix $\mathbf{S}_t := \sum_{i=1}^{t} \mathbf{v}_i \mathbf{k}_i^\top \in \mathbb{R}^{d \times d}$, which can be computed recurrently:

$$\mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{v}_t \mathbf{k}_t^\top, \qquad \mathbf{o}_t = \mathbf{S}_t \mathbf{q}_t$$
This formulation reveals that linear attention is essentially a linear RNN with a matrix-valued state $\mathbf{S}_t$, in contrast to the vector-valued hidden state of a classical RNN.
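To make this concrete, here is a minimal NumPy sketch (variable names and sizes are illustrative, not taken from the paper's codebase) that computes the softmax-free attention both in its parallel form and via the matrix-state recurrence, and checks that the two agree:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 8, 4                      # sequence length, head dimension
Q, K, V = (rng.standard_normal((T, d)) for _ in range(3))

# Parallel form: o_t = sum_{i<=t} (q_t . k_i) v_i  (causal mask, no softmax)
A = np.tril(Q @ K.T)             # T x T score matrix, lower-triangular = causal
O_parallel = A @ V

# Recurrent form: S_t = S_{t-1} + v_t k_t^T,  o_t = S_t q_t
S = np.zeros((d, d))
O_recurrent = np.zeros((T, d))
for t in range(T):
    S = S + np.outer(V[t], K[t])     # add the new key-value association
    O_recurrent[t] = S @ Q[t]        # read out with the current query

print(np.allclose(O_parallel, O_recurrent))  # True
```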
With this approach, we only need to store and update the fixed-size state $\mathbf{S}_t$ during decoding, so memory stays constant and the per-token cost no longer grows with sequence length. This makes linear attention especially attractive for:

- Long-sequence modeling, where the quadratic complexity of softmax attention can be a significant bottleneck.
- Generation, where computation is usually memory-bound: removing the KV cache can significantly improve inference latency, especially for large batch sizes and long sequences.
Unfortunately, there is no free lunch. The fixed-size state matrix in linear attention means it cannot perfectly preserve all historical information, making exact retrieval particularly challenging.
More formally, linear attention implements a key-value associative memory, which is the sum of outer products between keys and values:

$$\mathbf{S}_t = \sum_{i=1}^{t} \mathbf{v}_i \mathbf{k}_i^\top$$

Querying this memory with one of the stored keys $\mathbf{k}_j$ (assuming unit-norm keys) gives

$$\mathbf{S}_t \mathbf{k}_j = \mathbf{v}_j + \underbrace{\sum_{i \ne j} (\mathbf{k}_i^\top \mathbf{k}_j)\, \mathbf{v}_i}_{\text{retrieval error}}$$

To minimize the retrieval error term, we would need the keys to be (close to) mutually orthogonal. But a $d$-dimensional space admits at most $d$ mutually orthogonal vectors, so once the number of stored associations exceeds the key dimension, interference between keys becomes unavoidable.
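This degradation is easy to observe numerically. The sketch below (purely illustrative; dimensions chosen arbitrarily) stores random unit-norm key-value pairs in $\mathbf{S} = \sum_i \mathbf{v}_i \mathbf{k}_i^\top$ and measures how well the first stored value can be read back as the number of associations grows past the key dimension:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 32                                   # key/value dimension

for n_pairs in (8, 32, 128, 512):
    K = rng.standard_normal((n_pairs, d))
    K /= np.linalg.norm(K, axis=1, keepdims=True)   # unit-norm keys
    V = rng.standard_normal((n_pairs, d))
    S = V.T @ K                          # S = sum_i v_i k_i^T

    # Retrieve the first stored value with its own key.
    v_hat = S @ K[0]
    err = np.linalg.norm(v_hat - V[0]) / np.linalg.norm(V[0])
    print(f"{n_pairs:4d} pairs stored -> relative retrieval error {err:.2f}")
```

The relative error grows steadily with the number of stored pairs: the cross-terms from non-orthogonal keys accumulate and eventually swamp the stored value.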
This theoretical limitation manifests in practice: vanilla linear attention has underperformed compared to softmax attention (by a large margin) in language modeling. The primary cause is memory “overload”: in this key-value associative memory system, we can only add new key-value associations without the ability to erase existing information. As sequences grow longer, this leads to accumulating “retrieval errors” that degrade performance. Indeed, as noted by David Eagleman in his book “Livewired: The Inside Story of the Ever-Changing Brain”,
“The enemy of memory is not time; it’s other memories.”
(Thanks to Kazuki Irie for the reference!)

Recent advances in gated variants of linear attention (such as GLA and Mamba) alleviate this problem by introducing a decay (forget) gate into the update, allowing old associations to be gradually erased:

$$\mathbf{S}_t = \mathbf{G}_t \odot \mathbf{S}_{t-1} + \mathbf{v}_t \mathbf{k}_t^\top$$

with different structured parameterizations for the gating matrix $\mathbf{G}_t$: Decaying Fast Weight, GLA, Mamba1, and Mamba2 each restrict $\mathbf{G}_t$ in their own way (e.g., a data-dependent scalar decay versus a decay vector broadcast along one dimension of the state). Cf. Table 1 of the GLA paper for the exact parameterizations and a side-by-side comparison.
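As a concrete toy example of this family, the sketch below implements the gated recurrence with a per-step scalar decay (a Mamba2-style special case, $\mathbf{G}_t = \gamma_t \mathbf{1}\mathbf{1}^\top$); the way $\gamma_t$ is produced here is just a placeholder, not any particular model's parameterization:

```python
import numpy as np

def gated_linear_attention(Q, K, V, gamma):
    """Recurrence S_t = gamma_t * S_{t-1} + v_t k_t^T,  o_t = S_t q_t.

    gamma is a per-step scalar decay in (0, 1): a special case of the
    general gating matrix G_t (GLA / Mamba1 use richer structures).
    """
    T, d = Q.shape
    S = np.zeros((d, d))
    O = np.zeros((T, d))
    for t in range(T):
        S = gamma[t] * S + np.outer(V[t], K[t])   # decay old memory, write new pair
        O[t] = S @ Q[t]
    return O

rng = np.random.default_rng(0)
T, d = 8, 4
Q, K, V = (rng.standard_normal((T, d)) for _ in range(3))
gamma = 1.0 / (1.0 + np.exp(-rng.standard_normal(T)))  # placeholder data-dependent decay
print(gated_linear_attention(Q, K, V, gamma).shape)     # (8, 4)
```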
The Delta Rule
To understand this intuitively, imagine teaching a child to aim at a target. If they shoot too far to the left, you'd tell them to adjust right; too far right, adjust left. The size of the adjustment depends on how far they missed. This is exactly what the Delta Rule (Widrow & Hoff) does: given an input $\mathbf{x}_t$, a target $y_t$, and a prediction $\hat{y}_t = \mathbf{w}_{t-1}^\top \mathbf{x}_t$, it updates the weights in proportion to the prediction error, $\mathbf{w}_t = \mathbf{w}_{t-1} + \eta\, (y_t - \hat{y}_t)\, \mathbf{x}_t$.
```python
import numpy as np

def delta_rule(x, y, epochs=100, lr=0.1):
    """
    Simple delta rule implementation
    x: input features (N samples by D features)
    y: target values (N samples)
    """
    # Initialize weights
    w = np.zeros(x.shape[1])
    # Train
    for _ in range(epochs):
        for i in range(len(x)):
            # Forward pass
            pred = np.dot(x[i], w)
            # Compute error
            error = y[i] - pred
            # Update weights
            w += lr * error * x[i]
    return w

# Example usage
if __name__ == "__main__":
    # Generate toy data
    x = np.random.randn(100, 3)  # 100 samples, 3 features
    true_w = np.array([0.5, -0.2, 0.1])
    y = np.dot(x, true_w) + 0.1 * np.random.randn(100)
    # Train
    w = delta_rule(x, y)
    print("True weights:", true_w)
    print("Learned weights:", w)
```
DeltaNet
DeltaNet applies this delta rule to the key-value memory of linear attention, treating each incoming key-value pair as a training example for the matrix-valued memory:

$$\mathbf{S}_t = \mathbf{S}_{t-1} - \beta_t\, (\mathbf{S}_{t-1}\mathbf{k}_t - \mathbf{v}_t)\, \mathbf{k}_t^\top$$

The parallel to the Delta Rule becomes clear when we break down the components: the state $\mathbf{S}_{t-1}$ plays the role of the weights $\mathbf{w}_{t-1}$, the key $\mathbf{k}_t$ is the input, the value $\mathbf{v}_t$ is the target, $\mathbf{S}_{t-1}\mathbf{k}_t$ is the current prediction, and $\beta_t \in (0, 1)$, predicted from the input, acts as a data-dependent learning rate.
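Here is a naive, purely sequential NumPy sketch of this recurrence (the paper's actual implementation uses a chunkwise-parallel algorithm covered in the next post; names and sizes below are illustrative):

```python
import numpy as np

def deltanet_recurrent(Q, K, V, beta):
    """Naive DeltaNet recurrence:
        S_t = S_{t-1} - beta_t * (S_{t-1} k_t - v_t) k_t^T,   o_t = S_t q_t
    """
    T, d = Q.shape
    S = np.zeros((d, d))
    O = np.zeros((T, d))
    for t in range(T):
        pred = S @ K[t]                       # value currently stored for k_t
        error = pred - V[t]                   # delta-rule prediction error
        S = S - beta[t] * np.outer(error, K[t])
        O[t] = S @ Q[t]
    return O

rng = np.random.default_rng(0)
T, d = 8, 4
Q, K, V = (rng.standard_normal((T, d)) for _ in range(3))
K /= np.linalg.norm(K, axis=1, keepdims=True)          # keys are typically L2-normalized
beta = 1.0 / (1.0 + np.exp(-rng.standard_normal(T)))   # beta_t in (0, 1)
print(deltanet_recurrent(Q, K, V, beta).shape)          # (8, 4)
```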
We will revisit this form later, showing how it emerges naturally from a single gradient descent step on an online loss function.
There's another intuitive way to understand this update rule. Think of $\mathbf{v}_t^{\text{old}} = \mathbf{S}_{t-1}\mathbf{k}_t$ as the value the memory currently associates with the key $\mathbf{k}_t$. The update can then be rewritten as erasing the old value and writing a new one:

$$\mathbf{S}_t = \mathbf{S}_{t-1} - \mathbf{v}_t^{\text{old}}\mathbf{k}_t^\top + \mathbf{v}_t^{\text{new}}\mathbf{k}_t^\top$$

where $\mathbf{v}_t^{\text{new}} = (1-\beta_t)\,\mathbf{v}_t^{\text{old}} + \beta_t\,\mathbf{v}_t$ is a convex combination of the old and incoming values: $\beta_t = 0$ leaves the memory untouched, while $\beta_t = 1$ completely overwrites the stored value.
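A quick numerical check (illustrative only) that this erase-then-write view is algebraically identical to the delta-rule update above:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
S = rng.standard_normal((d, d))
k = rng.standard_normal(d)
k /= np.linalg.norm(k)                 # unit-norm key
v = rng.standard_normal(d)
beta = 0.7

# Delta-rule form
S_delta = S - beta * np.outer(S @ k - v, k)

# Erase-then-write form
v_old = S @ k
v_new = (1 - beta) * v_old + beta * v
S_write = S - np.outer(v_old, k) + np.outer(v_new, k)

print(np.allclose(S_delta, S_write))   # True
```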
MQAR (Multi-Query Associative Recall)
The MQAR task works as follows: Each letter is associated with a number, and the model is asked to correctly recall the number associated with each letter in a query sequence.
For example, given the input:
A 4 B 3 C 6 F 1 E 2 → A ? C ? F ? E ? B ?
The format consists of a context of key-value pairs (each letter followed by its associated number), followed by a query sequence of letters whose numbers the model must recall.
The correct output for this example would be:
4, 6, 1, 2, 3
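For concreteness, here is a simple generator of MQAR-style examples in the format above (our own toy code for exposition, not the benchmark's official implementation; function and variable names are made up):

```python
import random
import string

def make_mqar_example(num_pairs=5, num_queries=5, seed=0):
    """Build one MQAR-style example: key-value pairs, then queries over those keys."""
    rng = random.Random(seed)
    letters = rng.sample(string.ascii_uppercase, num_pairs)
    kv = {letter: rng.randint(0, 9) for letter in letters}
    queries = rng.sample(letters, num_queries)
    prompt = " ".join(f"{k} {v}" for k, v in kv.items()) + " → " + " ".join(f"{q} ?" for q in queries)
    answers = [kv[q] for q in queries]
    return prompt, answers

prompt, answers = make_mqar_example()
print(prompt)    # e.g. "A 4 B 3 C 6 F 1 E 2 → A ? C ? F ? E ? B ?"
print(answers)   # the numbers the model must recall
```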
While conventional gated convolution and recurrent models generally underperform on this task, in our experiments we show that DeltaNet solves it perfectly.
This initial success was particularly exciting—achieving perfect performance on MQAR exceeded our expectations. What makes this result especially promising is that MQAR performance strongly correlates with “Associative-Recall-Hit” in real-world language modeling tasks.
We’ve also conducted experiments on the MAD (Mechanistic Architecture Design) benchmark, a suite of synthetic tasks for probing different architecture capabilities:
| Model | Compress | Fuzzy Recall | In-Context Recall | Memorize | Noisy Recall | Selective Copy | Average |
|---|---|---|---|---|---|---|---|
| Transformer | 51.6 | 29.8 | 94.1 | 85.2 | 86.8 | 99.6 | 74.5 |
| Hyena | 45.2 | 7.9 | 81.7 | 89.5 | 78.8 | 93.1 | 66.0 |
| Multihead Hyena | 44.8 | 14.4 | 99.0 | 89.4 | 98.6 | 93.0 | 73.2 |
| Mamba | 52.7 | 6.7 | 90.4 | 89.5 | 90.1 | 86.3 | 69.3 |
| GLA | 38.8 | 6.9 | 80.8 | 63.3 | 81.6 | 88.6 | 60.0 |
| DeltaNet | 42.2 | 35.7 | 100.0 | 52.8 | 100.0 | 100.0 | 71.8 |
Here, DeltaNet demonstrates strong in-context recall capabilities, topping the recall-oriented tasks. These synthetic tasks are inexpensive to run and offer clear evidence that DeltaNet is likely to perform well at scale. This motivated us to focus on developing DeltaNet’s training algorithm and kernel implementation; after all, scaling up an arbitrary architecture without first demonstrating its potential would risk wasting significant time and resources.
In the next post, we’ll explore a beautiful algorithm that parallelizes DeltaNet across sequence length. But first, let’s build some intuition about why DeltaNet is particularly well-suited for in-context retrieval tasks.
DeltaNet’s update rule can be derived by sequentially minimizing the mean squared error (MSE) between the desired output and the predicted output at each time step:

$$\mathcal{L}_t(\mathbf{S}) = \frac{1}{2}\,\lVert \mathbf{S}\mathbf{k}_t - \mathbf{v}_t \rVert^2$$
Applying one step of gradient descent to this MSE loss gives:

$$\mathbf{S}_t = \mathbf{S}_{t-1} - \beta_t \nabla_{\mathbf{S}} \mathcal{L}_t(\mathbf{S}_{t-1}) = \mathbf{S}_{t-1} - \beta_t\, (\mathbf{S}_{t-1}\mathbf{k}_t - \mathbf{v}_t)\, \mathbf{k}_t^\top$$

When the learning rate $\beta_t$ is made data-dependent, this is exactly DeltaNet’s update rule.
In contrast, vanilla linear attention corresponds to a linear loss function:

$$\mathcal{L}_t(\mathbf{S}) = -\langle \mathbf{S}\mathbf{k}_t,\, \mathbf{v}_t \rangle$$

The corresponding gradient-descent update is:

$$\mathbf{S}_t = \mathbf{S}_{t-1} - \beta_t \nabla_{\mathbf{S}} \mathcal{L}_t(\mathbf{S}_{t-1}) = \mathbf{S}_{t-1} + \beta_t\, \mathbf{v}_t \mathbf{k}_t^\top$$

By setting $\beta_t = 1$, we recover exactly the vanilla linear attention update $\mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{v}_t \mathbf{k}_t^\top$.
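As a sanity check on this derivation, the snippet below (illustrative only) compares a finite-difference gradient of each loss against the corresponding closed-form update: one gradient step on the MSE loss reproduces DeltaNet's update, and one step on the linear loss with $\beta_t = 1$ reproduces the vanilla linear attention update.

```python
import numpy as np

def numerical_grad(loss, S, eps=1e-6):
    """Central finite-difference gradient of a scalar loss w.r.t. the matrix S."""
    G = np.zeros_like(S)
    for idx in np.ndindex(*S.shape):
        E = np.zeros_like(S)
        E[idx] = eps
        G[idx] = (loss(S + E) - loss(S - E)) / (2 * eps)
    return G

rng = np.random.default_rng(0)
d = 4
S = rng.standard_normal((d, d))
k = rng.standard_normal(d)
v = rng.standard_normal(d)
beta = 0.5

# MSE loss: one gradient step reproduces DeltaNet's update.
mse = lambda S_: 0.5 * np.sum((S_ @ k - v) ** 2)
S_deltanet = S - beta * np.outer(S @ k - v, k)
print(np.allclose(S - beta * numerical_grad(mse, S), S_deltanet, atol=1e-4))  # True

# Linear loss: one gradient step with beta_t = 1 reproduces linear attention.
lin = lambda S_: -np.dot(S_ @ k, v)
S_linear = S + np.outer(v, k)
print(np.allclose(S - numerical_grad(lin, S), S_linear, atol=1e-4))           # True
```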
Thus, DeltaNet’s superior performance in in-context retrieval becomes evident—it minimizes MSE at each step, making it ideal for tasks like associative recall where reducing large errors is crucial for accurate retrieval.