Infinite context length is important for Large Language Models (LLMs) as it can significantly enhance their ability to understand and solve complex tasks. With unlimited context, users could provide vast amounts of data of any modality, enabling the model to dynamically select the most relevant information for a given query. This capability could lead to more advanced AI agents that continuously integrate and retain knowledge over extended periods. However, achieving infinite context length is challenging due to computational constraints and the need for efficient memory management. These limitations stem primarily from the quadratic complexity of Transformers with respect to sequence length, which makes scalability a major hurdle.
"What would it look like to combine Google search (shallow Knowledge Graph reasoning over an ultra-ultra-wide index) with LLM in-context learning (highly intelligent operations on a tiny index)?" — Dwarkesh Patel (@dwarkesh_sp), February 14, 2025
Transformers are the go-to architecture for large language models (LLMs), primarily due to their ability to model long context (reaching millions of tokens) using the self-attention mechanism. Self-attention enables the model to selectively focus on relevant tokens in the input sequence when generating the output. Mathematically, self-attention computes a weighted sum of values \( V \) based on the similarity between queries \( Q \) and keys \( K \):

\[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]

Here, \( Q, K, V \in \mathbb{R}^{N \times d_k} \) are the query, key, and value matrices, \( d_k \) is the dimension of the keys, and the softmax ensures that the attention weights sum to 1.
# Imports used throughout this post
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def attention(Q, K, V):
"""
Computes the scaled dot-product attention.
Args:
Q (torch.Tensor): Query tensor of shape (batch_size, seq_len, d_k).
K (torch.Tensor): Key tensor of shape (batch_size, seq_len, d_k).
V (torch.Tensor): Value tensor of shape (batch_size, seq_len, d_v).
Returns:
torch.Tensor: Output tensor of shape (batch_size, seq_len, d_v).
"""
    d_k = Q.size(-1)  # Dimension of the key vectors
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)  # QK^T / sqrt(d_k)
weights = F.softmax(scores, dim=-1) # Softmax over the last dimension
output = torch.matmul(weights, V) # Weighted sum of values
return output
This mechanism enables efficient parallelization during training because the attention scores for all tokens can be computed in parallel. However, Transformers suffer from quadratic complexity with respect to sequence length \( N \), i.e., \( O(N^2d_k) \). This makes training and inference computationally expensive for long sequences.
Unlike training, where self-attention is computed in parallel across tokens in a sequence, inference in autoregressive models proceeds sequentially. The model generates one token at a time, and each new token requires computing attention scores with all the previously generated tokens. This leads to increasing memory and compute costs as the sequence grows.
To mitigate this, Key-Value (KV) caching is commonly used (among other strategies) to improve computational efficiency. KV caching stores the attention keys and values from previous time steps, allowing the model to reuse them instead of recomputing them from scratch. This reduces redundant computation and significantly speeds up inference.
class KVCache:
    def __init__(self, max_length, d_model):
        """
        Initialize the KV cache.
        """
        self.max_length = max_length
        self.d_model = d_model
        self.keys = None    # Cache for keys, shape (batch_size, cached_len, d_model)
        self.values = None  # Cache for values, shape (batch_size, cached_len, d_model)

    def update(self, new_keys, new_values):
        """
        Update the KV cache with new keys and values.
        """
        if self.keys is None:
            # First update: initialize the cache
            self.keys, self.values = new_keys, new_values
        else:
            # Concatenate new keys and values along the sequence dimension
            self.keys = torch.cat([self.keys, new_keys], dim=1)
            self.values = torch.cat([self.values, new_values], dim=1)
        # Truncate the cache if it exceeds max_length
        if self.keys.size(1) > self.max_length:
            self.keys = self.keys[:, -self.max_length:, :]
            self.values = self.values[:, -self.max_length:, :]

    def get_attention_output(self, query):
        """
        Compute attention output using the cached keys and values.
        """
        return attention(query, self.keys, self.values)
# Example Usage
batch_size = 2
seq_len = 10
d_model = 64
max_length = 128
# Initialize KV cache
kv_cache = KVCache(max_length=max_length, d_model=d_model)
# Dummy input tensors (batch_size=2, seq_len=10, d_model=64)
new_keys = torch.randn(batch_size, seq_len, d_model)
new_values = torch.randn(batch_size, seq_len, d_model)
query = torch.randn(batch_size, seq_len, d_model)
# Update cache with new keys and values
kv_cache.update(new_keys, new_values)
# Compute attention output
output = kv_cache.get_attention_output(query)
KV caching improves inference speed by avoiding redundant computation, making it feasible to handle long sequences. However, for very long sequences, the cache may need to be pruned or truncated to fit within memory limits.
Recurrent Neural Networks (RNNs) are a class of neural networks designed for sequential data processing. They maintain a hidden state that captures past context and updates it at each time step based on the current input and the previous hidden state. Mathematically, the update rule and the output are given by:

\[ h(t) = f(W_h h(t-1) + W_i x(t)), \qquad y(t) = g(W_o h(t)) \]

where (a minimal code sketch follows these definitions):
\( h(t) \) is the hidden state at time \( t \), which serves as memory.
\( x(t) \) is the input at time \( t \).
\( W_h \) and \( W_i \) are weight matrices.
\( f \) is the activation function (e.g., ReLU, tanh).
\( W_o \) is the weight matrix for the output.
\( g \) is the output activation function (e.g., softmax for classification).
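As an illustrative sketch (not from the original implementation; the weights and dimensions here are made up), one step of this recurrence in PyTorch looks like:

import torch
import torch.nn.functional as F

# Minimal sketch of one RNN step: h_t = tanh(W_h h_{t-1} + W_i x_t), y_t = softmax(W_o h_t)
d_in, d_hidden, d_out = 32, 64, 10
W_i = torch.randn(d_hidden, d_in) * 0.01
W_h = torch.randn(d_hidden, d_hidden) * 0.01
W_o = torch.randn(d_out, d_hidden) * 0.01

def rnn_step(x_t, h_prev):
    h_t = torch.tanh(h_prev @ W_h.T + x_t @ W_i.T)  # update the hidden state (memory)
    y_t = F.softmax(h_t @ W_o.T, dim=-1)            # output distribution for this step
    return h_t, y_t

h = torch.zeros(1, d_hidden)
for x_t in torch.randn(5, 1, d_in):  # process a length-5 sequence one step at a time
    h, y = rnn_step(x_t, h)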
Theoretically, RNNs can model arbitrarily long sequences, but in practice, they struggle with long-range dependencies due to the vanishing gradient problem—gradients shrink as they propagate through time, making it difficult to learn from distant tokens. Also, the fixed size of the hidden state makes it difficult to compress long sequences, leading to information loss.
Long Short-Term Memory (LSTM) and Gated Recurrent Unit (GRU) architectures introduced gating mechanisms to control the flow of information. These gating mechanisms allow the model to selectively retain or forget information, improving its ability to capture long-range dependencies. However, they still struggle with very long sequences, and their sequential nature prevents parallelization, making training slow.
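For instance, a single gated update can be sketched with PyTorch's built-in nn.GRUCell (the dimensions here are arbitrary):

import torch
import torch.nn as nn

# One gated recurrent step: the GRU's gates decide how much of h_prev to keep
gru = nn.GRUCell(input_size=32, hidden_size=64)
x_t = torch.randn(8, 32)       # current input (batch, input_size)
h_prev = torch.zeros(8, 64)    # previous hidden state (batch, hidden_size)
h_t = gru(x_t, h_prev)         # updated hidden state (batch, hidden_size)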
State Space Models (SSMs) provide an alternative to RNNs by replacing discrete recurrence with structured state transitions, leading to faster computation and better long-range memory retention. SSMs are commonly used in control theory to model the dynamics of systems and have been adapted for sequence modeling. In their continuous form, an SSM evolves a hidden state according to \( h'(t) = A h(t) + B x(t) \) and produces an output \( y(t) = C h(t) + D x(t) \), where \( A \) is the state transition matrix, \( B \) the input matrix, \( C \) the output matrix, and \( D \) a skip connection.
Unlike RNNs, SSMs can process sequences in parallel by leveraging structured matrix multiplication (e.g., using convolutional or FFT operations), making them significantly faster and more memory efficient. However, SSMs are time-invariant, meaning the matrices \( A \), \( B \), \( C \), and \( D \) remain the same for each token. This limits their ability to perform content-aware reasoning, such as dynamically focusing on or ignoring specific parts of the input sequence.
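To make the time-invariance concrete, here is a toy sketch (my own, with arbitrary dimensions) of a discrete SSM recurrence where the same \( A, B, C, D \) are applied at every step:

import torch

# Minimal time-invariant SSM: h_t = A h_{t-1} + B x_t, y_t = C h_t + D x_t
d_state, d_in = 16, 1
A = torch.eye(d_state) * 0.9          # fixed state transition
B = torch.randn(d_state, d_in) * 0.1  # fixed input matrix
C = torch.randn(d_in, d_state) * 0.1  # fixed output matrix
D = torch.ones(d_in)                  # fixed skip connection

x = torch.randn(100, d_in)   # a length-100 sequence
h = torch.zeros(d_state, 1)
ys = []
for x_t in x:
    h = A @ h + B @ x_t.unsqueeze(-1)    # same A, B for every token
    y_t = (C @ h).squeeze(-1) + D * x_t  # same C, D for every token
    ys.append(y_t)
y = torch.stack(ys)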
Mamba introduces Selective State Space Models and utilizes hardware-aware optimizations (surprising fact: Mamba was rejected at ICLR 2024). It is designed to handle long sequences (like text or time-series data) faster and more efficiently than RNNs or Transformers. Mamba stands out because of its:
Selective Focus — Mamba can dynamically select parts of the input to pay attention to. This makes it great at understanding context.
Fast Inference — Unlike Transformers, which use a slow and memory-hungry attention mechanism, Mamba uses a lightweight State Space Model (SSM) block. This makes it much faster, especially for long sequences.
Hardware-Friendly — Mamba is designed to work efficiently on modern hardware. It minimizes memory usage by fusing multiple operations into a single step, reducing the need to constantly read and write data.
In essence, Mamba combines the best of both worlds: the parallel processing capabilities of Transformers with the efficient memory retention of RNNs.
At its core, Mamba is based on State Space Models (SSMs). However, it adds selectivity to these models. This enables it to dynamically decide which parts of the input to focus on and which to ignore. This selectivity is achieved by making some of the parameters input-dependent.
Mamba uses two main equations to process data. These equations describe how the hidden state evolves over time and how the output is computed:

\[ h'(t) = A h(t) + B x(t), \qquad y(t) = C h(t) + D x(t) \]

Here, \( h(t) \) is the hidden state, \( x(t) \) is the input, \( A \) controls how the state evolves, \( B \) maps the input into the state, \( C \) maps the state to the output, and \( D \) is a skip connection from input to output.
Since real-world data (like text) is discrete (e.g., words or tokens), these continuous-time equations are converted into a form that works for discrete sequences.
For this, the zero-order hold (ZOH) method is used to discretize the continuous equations. The discretized version is given as:

\[ h_t = A_{\Delta} h_{t-1} + B_{\Delta} x_t, \qquad y_t = C h_t + D x_t \]

where \( A_{\Delta} = \exp(\Delta A) \) and \( B_{\Delta} = (\Delta A)^{-1}(\exp(\Delta A) - I)\,\Delta B \) are the discretized versions of \( A \) and \( B \), and \( \Delta \) is the step size.
This discretization step is crucial because it allows Mamba to handle real-world data efficiently. Mamba makes \( B \), \( C \), and the step size \( \Delta \) input-dependent, which means it can adapt its behavior based on the input, making it context-dependent.
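A hedged sketch of ZOH discretization in code (my own illustration; actual Mamba implementations, including the simplified scan later in this post, often approximate \( B_{\Delta} \) with the simpler \( \Delta B \)):

import torch

# Zero-order hold (ZOH) discretization for a given step size Δ:
#   A_Δ = exp(Δ A),  B_Δ = (Δ A)^{-1} (exp(Δ A) - I) · Δ B
def zoh_discretize(A, B, delta):
    dA = delta * A
    A_disc = torch.matrix_exp(dA)
    B_disc = torch.linalg.solve(dA, (A_disc - torch.eye(A.shape[0])) @ (delta * B))
    return A_disc, B_disc

d_state, d_in = 4, 1
A = -torch.diag(torch.arange(1.0, d_state + 1))  # stable (negative) diagonal state matrix
B = torch.ones(d_state, d_in)
A_disc, B_disc = zoh_discretize(A, B, delta=0.1)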
One of Mamba's biggest strengths is its ability to process sequences in parallel, even though it is based on a recurrent model. This is achieved using a parallel scan algorithm.
Recurrent models compute the hidden state \( h_t \) sequentially, where each state depends on the previous one: \( h_t = f(h_{t-1}, x_t) \).
This sequential dependency makes parallelization challenging. Mamba breaks the sequence into smaller chunks, processes them in parallel, and then combines the results. The key insight is that the state update operation often satisfies the associative property, which allows for efficient parallel computation. Mathematically, if the state update function \( f \) is associative, then \( f(f(a, b), c) = f(a, f(b, c)) \).
This property enables Mamba to compute partial states independently and combine them later.
Mamba leverages this associative property of the state update operation to enable parallel processing. Here’s how it works:
Break the Sequence into Chunks — The input sequence \( X = [x_1, x_2, \dots, x_T] \) is divided into \( K \) smaller chunks \( C_1, C_2, \dots, C_K \), where each chunk \( C_i \) contains a subset of the sequence. For example, if \( T = 8 \) and \( K = 2 \), the chunks might be \( C_1 = [x_1, x_2, x_3, x_4] \) and \( C_2 = [x_5, x_6, x_7, x_8] \).
Compute Partial States in Parallel — Each chunk \( C_i \) is processed independently to compute its partial hidden state \( H_i \), using the same recurrent update rule but applied to each chunk in parallel. For example, for chunk \( C_1 \): \( H_1 = f(f(f(f(h_0, x_1), x_2), x_3), x_4) \),
where \( h_0 \) is the initial hidden state. Similarly, \( H_2 \) is computed for \( C_2 \).
Combine the Results — The partial hidden states \( H_1, H_2, \dots, H_K \) are combined using a parallel scan operation. This step ensures that the final hidden state for each token accounts for all previous tokens in the sequence. The parallel scan leverages the associative property of \( f \) to efficiently merge the partial states; for example, the summary of chunk \( C_1 \) (namely \( H_1 \)) is composed with the partial computation of \( C_2 \) to recover the exact sequential result for \( x_5, \dots, x_8 \).
By leveraging the associative property and parallel scan, Mamba achieves significant speedups over traditional sequential processing, making it highly efficient for long sequences.
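The toy sketch below (illustrative only, not Mamba's actual kernel) shows the idea for the linear recurrence \( h_t = a_t h_{t-1} + b_t \): each chunk is summarized as a pair \( (a, b) \) whose composition is associative, so chunk summaries computed independently can be merged into the exact sequential result.

import torch

# Each step (or chunk) is summarized as a pair (a, b) representing the map h -> a*h + b.
# Composing two such maps is associative, which is what a parallel scan exploits.
def combine(left, right):
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2   # apply the left map first, then the right map

def chunk_summary(a_chunk, b_chunk):
    """Reduce a chunk of steps into a single (a, b) pair (each chunk can be handled independently)."""
    summary = (torch.tensor(1.0), torch.tensor(0.0))   # identity map
    for a_t, b_t in zip(a_chunk, b_chunk):
        summary = combine(summary, (a_t, b_t))
    return summary

a = torch.rand(8) * 0.9   # per-step decay a_t
b = torch.randn(8)        # per-step input b_t

# Sequential reference: h_t = a_t * h_{t-1} + b_t with h_0 = 0
h = torch.tensor(0.0)
for a_t, b_t in zip(a, b):
    h = a_t * h + b_t

# Chunked version: summarize two chunks independently, then merge the summaries
s1 = chunk_summary(a[:4], b[:4])
s2 = chunk_summary(a[4:], b[4:])
a_tot, b_tot = combine(s1, s2)
h_chunked = a_tot * torch.tensor(0.0) + b_tot   # apply the merged map to h_0 = 0

assert torch.allclose(h, h_chunked, atol=1e-5)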
By processing chunks in parallel, Mamba avoids the sequential bottleneck of traditional RNNs. This makes it much faster at inference time, especially for long sequences. Here’s why:
Hardware Efficiency — GPUs are designed to handle parallel computations, and Mamba’s parallel scan algorithm takes full advantage of this capability.
Reduced Memory Overhead — Instead of storing intermediate hidden states for every token, Mamba only needs to store the partial results for each chunk, reducing memory usage.
The Mamba Block is designed for efficient sequence modeling. It replaces the quadratic complexity of attention mechanisms with a linear-time approach, combining State Space Models (SSMs), local feature extraction, and input-dependent parameterization.
The Mamba Block consists of the following components:
Root Mean Square Normalization (RMSNorm): This normalization stabilizes training by normalizing the input while preserving its direction, effectively controlling feature magnitudes.
\[ \hat{x} = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot \gamma \]

where \( \gamma \) is a learnable scaling parameter. This formulation ensures that the normalized output \( \hat{x} \) has unit root-mean-square magnitude before scaling, while the learnable parameter \( \gamma \) allows the model to scale the normalized features appropriately.
class RMSNorm(nn.Module):
"""
Adapted from:
- https://github.com/johnma2006/mamba-minimal
Root Mean Square Layer Normalization (RMSNorm)
"""
def __init__(self, d_model, eps=1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(d_model)) # Learnable Scale
def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)  # Reciprocal of the RMS
        return x * inv_rms * self.gamma
Local Convolution: Mamba uses a local convolution before the Selective State Space layer, which helps in processing local information. This enables mixing of information across tokens, preventing the model from treating tokens independently and enabling a better understanding of the relationship between nearby tokens. Hence the SSM receives an input where token interactions are already partially modeled.
class LocalConv(nn.Module):
def __init__(self, d_model, kernel_size=4, conv_bias=True):
super().__init__()
self.conv1d = nn.Conv1d(d_model, d_model, kernel_size, groups=d_model, bias=conv_bias, padding=kernel_size - 1)
def forward(self, x):
x = x.transpose(1, 2) # Change to (batch, channels, length) for Conv1D
x = self.conv1d(x)
        x = x[:, :, :-self.conv1d.kernel_size[0] + 1]  # Trim the extra positions added by padding so the convolution stays causal
return x.transpose(1, 2) # Restore original shape (batch, length, channels)
Selective State Space Model (SSM): The Selective State Space Model is designed to model long-range dependencies by discretizing a continuous-time SSM. The hidden state \( h_t \) is given as:

\[ h_t = A_{\Delta} h_{t-1} + B_{\Delta} x_t \]

Here, \( A_{\Delta} \) and \( B_{\Delta} \) are discretized versions of the learned matrices \( A \) and \( B \), and \( x_t \) is the input at time \( t \). Unlike a standard SSM, \( B \) and \( C \) here are input-dependent, allowing the model to dynamically focus on relevant parts of the input.

The output \( y_t \) is computed as:

\[ y_t = C h_t + D x_t \]
class SelectiveScan(nn.Module):
"""
Adapted from:
- https://github.com/johnma2006/mamba-minimal
Selective Scan module for state-space computation.
Args:
u: Input sequence (batch, length, dim)
dt: Time step scaling (batch, length, dim)
A: State transition matrix (dim, state_dim)
B: Input projection matrix (batch, length, state_dim)
C: Output projection matrix (batch, length, state_dim)
D: Skip connection (dim)
Returns:
Output sequence (batch, length, dim)
"""
def __init__(self):
super().__init__()
    def forward(self, u, dt, A, B, C, D):
        # Discretization exponent: dt * A (clamped for numerical stability before exponentiation)
        dA = torch.einsum('bld,dn->bldn', dt, A).clamp(min=-20)
        # Input-dependent state update term: B_Δ * x_t ≈ dt * u * B
        dB_u = torch.einsum('bld,bld,bln->bldn', dt, u, B)
        # Shifted cumulative product of A_Δ = exp(dt * A) along the sequence (cumsum in log space)
        dA_cumprod = torch.exp(F.pad(dA[:, 1:], (0, 0, 0, 0, 1, 0)).cumsum(1))
        # Parallel form of the recurrence h_t = A_Δ h_{t-1} + B_Δ x_t
        h_t = (dB_u / (dA_cumprod + 1e-12)).cumsum(1) * dA_cumprod
        # Output computation: y_t = C h_t
        y_t = torch.einsum('bldn,bln->bld', h_t, C)
        # Add skip connection: y_t = y_t + D * x_t
        return y_t + u * D
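To sanity-check the cumulative-sum formulation above, here is a step-by-step reference (my own sketch, closer to the explicit loop in mamba-minimal) that applies the recurrence \( h_t = A_{\Delta} h_{t-1} + B_{\Delta} x_t \), \( y_t = C h_t + D x_t \) one token at a time. On small random inputs the two should agree closely, though the division by the cumulative product makes the parallel form numerically fragile for very long sequences:

import torch

def selective_scan_sequential(u, dt, A, B, C, D):
    """Sequential reference: h_t = exp(dt*A) * h_{t-1} + (dt*u)*B, y_t = C h_t + D u_t."""
    b, l, d = u.shape
    n = A.shape[1]
    h = torch.zeros(b, d, n)
    ys = []
    for t in range(l):
        dA = torch.exp(dt[:, t].unsqueeze(-1) * A)                        # (b, d, n)
        dB_u = (dt[:, t] * u[:, t]).unsqueeze(-1) * B[:, t].unsqueeze(1)  # (b, d, n)
        h = dA * h + dB_u                                                 # recurrent state update
        ys.append(torch.einsum('bdn,bn->bd', h, C[:, t]))                 # project state to output
    return torch.stack(ys, dim=1) + u * D

# Quick comparison against the parallel SelectiveScan on random inputs
b, l, d, n = 2, 16, 8, 4
u, dt = torch.randn(b, l, d), torch.rand(b, l, d) * 0.1
A, B, C, D = -torch.rand(d, n), torch.randn(b, l, n), torch.randn(b, l, n), torch.ones(d)
print(torch.allclose(selective_scan_sequential(u, dt, A, B, C, D),
                     SelectiveScan()(u, dt, A, B, C, D), atol=1e-4))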
SiLU Activation: The SiLU (Sigmoid Linear Unit) activation function combines a sigmoid gate with a linear transformation. The SiLU activation is defined as:

\[ \text{SiLU}(x) = x \cdot \sigma(x) \]
Here \( x \) is the input, and \( \sigma(x) \) is the sigmoid function.
In the Mamba block, SiLU is applied to:
The convolutional output to introduce non-linearity.
The residual connection to gate the skip connection.
SiLU has a smooth gradient, which helps with stable training.
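As a quick sanity check (my own snippet), PyTorch's F.silu matches the definition above:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
assert torch.allclose(F.silu(x), x * torch.sigmoid(x))  # SiLU(x) = x * sigmoid(x)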
Residual Connection: The gated residual branch in the Mamba block stabilizes training and lets the model combine information from the current token with the context computed by the SSM. The SSM output is multiplied by a gated (SiLU-activated) projection of the input, enabling the model to dynamically balance long-range dependencies against local, token-level information.
The Mamba Block combines input projection, convolutional filtering, and selective-space computation to process sequences efficiently. The block is designed to handle long-range dependencies by dynamically focusing on relevant parts of the input. The key computations in the MAMBA architecture are:
Input Projection: The input is projected into a higher-dimensional space and split into two parts: one for further processing and one for the residual (gating) connection.
Convolutional Filtering: A 1D convolution is applied to capture local patterns in the sequence.
Selective State-Space Computation: The input-dependent parameters \( \Delta, B, C \) are computed, and the selective scan mechanism is applied to update the hidden state and compute the output.
Output Projection: The processed sequence is projected back to the original dimension.
class MambaBlock(nn.Module):
"""
Adapted from:
- https://github.com/johnma2006/mamba-minimal
Mamba Block: Combines input projection, convolution, and selective state-space computation.
"""
def __init__(self, d_model, d_state=16, expand=2, dt_rank='auto', d_conv=4, bias=False):
super().__init__()
d_inner = expand * d_model
dt_rank = math.ceil(d_model / 16) if dt_rank == 'auto' else dt_rank
self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias)
self.conv = LocalConv(d_inner, d_conv)
self.x_proj = nn.Linear(d_inner, dt_rank + d_state * 2, bias=False)
self.dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1).float().repeat(d_inner, 1)))  # A = -exp(A_log), initialized to -[1..d_state] per channel
self.D = nn.Parameter(torch.ones(d_inner))
self.out_proj = nn.Linear(d_inner, d_model, bias=bias)
self.scan = SelectiveScan()
def forward(self, x):
# Input projection and split into x and residual
x, res = self.in_proj(x).chunk(2, dim=-1)
# Apply 1D convolution
x = self.conv(x)
# Apply activation and selective scan
x = F.silu(x)
        y = self.scan(x, *self.compute_params(x))
# Combine with residual and project output
return self.out_proj(y * F.silu(res))
def compute_params(self, x):
"""
Compute input-dependent parameters for selective scan.
Args:
x: Input tensor (batch, length, dim)
Returns:
delta: Time step scaling (batch, length, dim)
A: State transition matrix (dim, state_dim)
B: Input projection matrix (batch, length, state_dim)
C: Output projection matrix (batch, length, state_dim)
D: Skip connection (dim)
"""
# Project input to delta, B, and C
delta, B, C = self.x_proj(x).split([self.dt_proj.in_features, self.A_log.shape[1], self.A_log.shape[1]], dim=-1)
# Compute delta and A
delta = F.softplus(self.dt_proj(delta)) # Time step scaling
A = -torch.exp(self.A_log) # State transition matrix
return delta, A, B, C, self.D
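A minimal usage sketch (mirroring the KV-cache example earlier; dimensions are arbitrary), wrapping the block with RMSNorm and a residual connection the way a full Mamba layer would:

import torch

batch_size, seq_len, d_model = 2, 32, 64

block = MambaBlock(d_model=d_model)
norm = RMSNorm(d_model)

x = torch.randn(batch_size, seq_len, d_model)
out = x + block(norm(x))   # pre-norm Mamba block with a residual connection
print(out.shape)           # torch.Size([2, 32, 64])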
Mamba is a powerful alternative to transformers, offering linear time complexity for sequence modeling. Its selective scanning mechanism dynamically focuses on relevant tokens, making it faster and more memory-efficient than transformers. This efficiency allows Mamba to handle long context lengths effectively, as its context cache (hidden states) does not grow with the sequence length, unlike transformers.
While Mamba shows great promise for scaling to very long sequences, achieving truly infinite context length remains a challenge. However, recent works like Falcon Mamba-7B have introduced interesting strategies to push these boundaries.
The full code is available in this repository.