
MAMBA: Can We Achieve Infinite Context Length?

Infinite context length is important for Large Language Models (LLMs) because it could significantly enhance their ability to understand and solve complex tasks. With unlimited context, users could provide vast amounts of data of any modality, and the model could dynamically select the most relevant information for a given query. This capability could lead to more advanced AI agents that continuously integrate and retain knowledge over extended periods. However, achieving infinite context length is challenging due to computational constraints and the need for efficient memory management. For Transformers, these limitations stem primarily from the quadratic complexity of self-attention with respect to sequence length, which makes scalability a major hurdle.

1. Are Transformers all we need?

Transformers are the go-to architecture for large language models (LLMs), primarily because the self-attention mechanism lets them model long contexts (now reaching millions of tokens). Self-attention enables the model to selectively focus on relevant tokens in the input sequence when generating the output. Mathematically, self-attention computes a weighted sum of the values \( V \), weighted by the similarity between the queries \( Q \) and keys \( K \):

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

Here, \( Q, K, V \in \mathbb{R}^{N \times d_k} \) are the query, key, and value matrices, \( d_k \) is the dimension of the keys, and the softmax ensures that the attention weights sum to 1.


import math

import torch
import torch.nn.functional as F


def attention(Q, K, V):
    """
    Computes the scaled dot-product attention.

    Args:
        Q (torch.Tensor): Query tensor of shape (batch_size, seq_len, d_k).
        K (torch.Tensor): Key tensor of shape (batch_size, seq_len, d_k).
        V (torch.Tensor): Value tensor of shape (batch_size, seq_len, d_v).

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, seq_len, d_v).
    """
    d_k = Q.size(-1)  # Dimension of the key vectors
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)  # QK^T / sqrt(d_k)
    weights = F.softmax(scores, dim=-1)  # Softmax over the last dimension
    output = torch.matmul(weights, V)  # Weighted sum of values
    return output

This mechanism enables efficient parallelization during training because the attention scores for all tokens can be computed in parallel. However, Transformers suffer from quadratic complexity with respect to sequence length \( N \), i.e., \( O(N^2d_k) \). This makes training and inference computationally expensive for long sequences.

1.1 Inference Bottlenecks and KV Caching

Unlike training, where self-attention is computed in parallel across tokens in a sequence, inference in autoregressive models proceeds sequentially. The model generates one token at a time, and each new token requires computing attention scores with all the previously generated tokens. This leads to increasing memory and compute costs as the sequence grows.

To mitigate this, Key-Value (KV) caching is one of the most commonly used strategies (among many others) for improving computational efficiency. KV caching stores the attention keys and values from previous time steps, allowing the model to reuse them instead of recomputing them from scratch. This removes redundant computation and significantly speeds up inference.


class KVCache:
    def __init__(self, max_length, d_model):
        """
        Initialize an empty KV cache.
        """
        self.max_length = max_length
        self.d_model = d_model
        self.keys = None    # Cached keys, shape (batch_size, cached_len, d_model)
        self.values = None  # Cached values, shape (batch_size, cached_len, d_model)

    def update(self, new_keys, new_values):
        """
        Append new keys and values to the cache.
        """
        if self.keys is None:
            self.keys, self.values = new_keys, new_values
        else:
            # Concatenate along the sequence dimension
            self.keys = torch.cat([self.keys, new_keys], dim=1)
            self.values = torch.cat([self.values, new_values], dim=1)

        # Truncate the cache if it exceeds max_length
        if self.keys.size(1) > self.max_length:
            self.keys = self.keys[:, -self.max_length:, :]
            self.values = self.values[:, -self.max_length:, :]

    def get_attention_output(self, query):
        """
        Compute the attention output using the cached keys and values.
        """
        return attention(query, self.keys, self.values)

# Example Usage
batch_size = 2
seq_len = 10
d_model = 64
max_length = 128

# Initialize KV cache
kv_cache = KVCache(max_length=max_length, d_model=d_model)

# Dummy input tensors (batch_size=2, seq_len=10, d_model=64)
new_keys = torch.randn(batch_size, seq_len, d_model)
new_values = torch.randn(batch_size, seq_len, d_model)
query = torch.randn(batch_size, seq_len, d_model)

# Update cache with new keys and values
kv_cache.update(new_keys, new_values)

# Compute attention output
output = kv_cache.get_attention_output(query)

KV caching improves inference speed by avoiding redundant computation, making it feasible to handle long sequences. However, the cache itself grows linearly with the sequence length, so for very long sequences it may need to be pruned or truncated to fit within memory limits.

2. Can RNNs help?

Recurrent Neural Networks (RNNs) are a class of neural networks designed for sequential data processing. They maintain a hidden state that captures past context and update it at each time step based on the current input and the previous hidden state. Mathematically, the update rule and the output are given by:

Hidden State Update:

\[ h(t) = f\big(W_h h(t-1) + W_i x(t)\big) \]

Output:

\[ y(t) = g\big(W_o h(t)\big) \]
where \( h(t) \) is the hidden state at time step \( t \), \( x(t) \) is the input, \( W_h \), \( W_i \), and \( W_o \) are learned weight matrices, and \( f \) and \( g \) are activation functions (e.g., \( \tanh \) and softmax).

Theoretically, RNNs can model arbitrarily long sequences, but in practice they struggle with long-range dependencies due to the vanishing gradient problem: gradients shrink as they propagate through time, making it difficult to learn from distant tokens. In addition, the fixed size of the hidden state makes it difficult to compress long sequences, leading to information loss.
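To make the recurrence concrete, here is a minimal sketch of a single vanilla RNN step in PyTorch. The dimensions, random weights, and the choice of \( f = \tanh \) with an identity output activation are illustrative assumptions, not a specific published configuration.

import torch

# Illustrative dimensions (assumptions for this sketch)
d_input, d_hidden, d_output = 16, 32, 8

W_i = torch.randn(d_hidden, d_input) * 0.1   # input-to-hidden weights
W_h = torch.randn(d_hidden, d_hidden) * 0.1  # hidden-to-hidden weights
W_o = torch.randn(d_output, d_hidden) * 0.1  # hidden-to-output weights

def rnn_step(h_prev, x_t):
    """One step of the recurrence: h(t) = f(W_h h(t-1) + W_i x(t)), y(t) = g(W_o h(t))."""
    h_t = torch.tanh(W_h @ h_prev + W_i @ x_t)  # f = tanh
    y_t = W_o @ h_t                             # g = identity, for simplicity
    return h_t, y_t

# Process a toy sequence one token at a time
h = torch.zeros(d_hidden)
for x_t in torch.randn(5, d_input):
    h, y = rnn_step(h, x_t)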

Long Short-Term Memory (LSTM) and Gated Recurrent Unit (GRU) architectures introduced gating mechanisms to control the flow of information. These gates allow the model to selectively retain or forget information, improving its ability to capture long-range dependencies. However, they still struggle with very long sequences, and their sequential nature prevents parallelization, making training slow.
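As a brief illustration, PyTorch's built-in nn.LSTMCell packages these gates: it maintains both a hidden state and a gated cell state that is selectively updated at every step. The dimensions below are arbitrary placeholders.

import torch
import torch.nn as nn

d_input, d_hidden = 16, 32
cell = nn.LSTMCell(d_input, d_hidden)

h = torch.zeros(1, d_hidden)  # hidden state
c = torch.zeros(1, d_hidden)  # cell state (the gated "memory")

# At each step the input, forget, and output gates decide what to write to
# and erase from the cell state c.
for x_t in torch.randn(5, 1, d_input):
    h, c = cell(x_t, (h, c))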

3. How about State Space Models?

State Space Models (SSMs) provide an alternative to RNNs by replacing discrete recurrence with structured state transitions, leading to faster computation and better long-range memory retention. SSMs are widely used in control theory to model the dynamics of systems, and have more recently been adapted for sequence modeling.

Continuous-Time SSM Equations:

\[ \dot{h}(t) = A h(t) + B x(t) \]
\[ y(t) = C h(t) + D x(t) \]

Discrete-Time SSM Equations:

\[ h(t) = \overline{A} h(t-1) + \overline{B} x(t) \]
\[ y(t) = C h(t) + D x(t) \]
where \( h(t) \) is the hidden state, \( x(t) \) the input, \( y(t) \) the output, \( A \) the state transition matrix, \( B \) the input matrix, \( C \) the output matrix, and \( D \) a direct feed-through (skip) term; \( \overline{A} \) and \( \overline{B} \) are the discretized counterparts of \( A \) and \( B \).

Unlike RNNs, SSMs can be trained on sequences in parallel by leveraging structured computation (e.g., casting the recurrence as a convolution, which can be evaluated with FFT operations), making them significantly faster and more memory-efficient. However, classical SSMs are time-invariant: the matrices \( A \), \( B \), \( C \), and \( D \) are the same for every token. This limits their ability to perform content-aware reasoning, such as dynamically focusing on or ignoring specific parts of the input sequence. The sketch below illustrates the two equivalent views of a time-invariant discrete SSM.
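Here is a small self-contained sketch (with made-up matrices, not a trained model) showing that a single-input, single-output discrete SSM can be evaluated either recurrently, one step at a time, or as a causal convolution with a precomputed kernel. The second view is what enables parallel training.

import torch

torch.manual_seed(0)

# Toy discretized SISO SSM (values are illustrative only)
n, T = 4, 10                       # state size, sequence length
A_bar = torch.randn(n, n) * 0.3    # discretized state transition matrix
B_bar = torch.randn(n, 1)          # discretized input matrix
C = torch.randn(1, n)              # output matrix
x = torch.randn(T)                 # input sequence

# 1) Recurrent view: step through time like an RNN
h = torch.zeros(n, 1)
y_rec = []
for t in range(T):
    h = A_bar @ h + B_bar * x[t]
    y_rec.append((C @ h).item())
y_rec = torch.tensor(y_rec)

# 2) Convolutional view: precompute the kernel K_k = C A_bar^k B_bar
K, M = [], torch.eye(n)
for k in range(T):
    K.append((C @ M @ B_bar).item())
    M = A_bar @ M
K = torch.tensor(K)

# Causal convolution of x with K reproduces the recurrent outputs
y_conv = torch.stack([sum(K[k] * x[t - k] for k in range(t + 1)) for t in range(T)])

print(torch.allclose(y_rec, y_conv, atol=1e-4))  # True: both views agree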

4. What Makes Mamba Special?

Mamba introduces Selective State Space Models and pairs them with hardware-aware optimizations (surprising fact: the Mamba paper was rejected at ICLR 2024). It is designed to handle long sequences (such as text or time series) faster and more efficiently than RNNs or Transformers. Mamba stands out because of its selective, input-dependent parameters, its linear-time scaling with sequence length, and its hardware-aware parallel scan.

In essence, Mamba combines the best of both worlds: the parallel processing capabilities of Transformers with the efficient memory retention of RNNs.

4.1 How Does Mamba Work?

At its core, Mamba is based on State Space Models (SSMs), but it adds selectivity to them. This enables the model to dynamically decide which parts of the input to focus on and which to ignore. Selectivity is achieved by making some of the parameters input-dependent.

4.2 The Math Behind Mamba

Mamba uses two main equations to process data. These equations describe how the hidden state evolves over time and how the output is computed:

Hidden State Update

$$h(t) = A h(t-1) + B x(t)$$

Here, \( h(t) \) is the hidden state, \( x(t) \) is the current input, \( A \) determines how much of the previous state is carried forward, and \( B \) determines how the current input is written into the state.

Output Equation

$$y(t) = C h(t) + D x(t)$$

Here, \( y(t) \) is the output, \( C \) maps the hidden state to the output, and \( D \) acts as a skip connection that passes the input directly to the output.

4.3 Discretization: Making It Work for Sequences

Since real-world data (like text) is discrete (e.g., words or tokens), these continuous-time equations must be converted into a form that operates on discrete sequences.

For this, the zero-order hold (ZOH) method is used to discretize the continuous equations. The discretized parameters are given by:

$$A_{\Delta} = e^{A \Delta t}, \quad B_{\Delta} = A^{-1}(e^{A \Delta t} - I)B$$

where \( \Delta t \) is the step size, \( I \) is the identity matrix, and \( A_{\Delta} \) and \( B_{\Delta} \) are the discretized state and input matrices.

This discretization step is crucial because it allows Mamba to handle real-world, discrete data efficiently. In addition, Mamba makes \( B \), \( C \), and the step size \( \Delta \) input-dependent, so the model can adapt its behavior to the content of the input; a small numerical sketch of the ZOH formulas follows.
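The following is a minimal numerical sketch of the ZOH formulas above for a toy system. The matrices and step size are arbitrary placeholders, and Mamba itself uses a diagonal \( A \) with a simplified discretization of \( B \), so treat this only as an illustration of the general formula.

import torch

# Toy continuous-time system (placeholder values)
n = 4                                          # state dimension
A = -torch.eye(n) - 0.1 * torch.randn(n, n)    # continuous state matrix (roughly stable)
B = torch.randn(n, 1)                          # continuous input matrix
dt = 0.1                                       # step size (input-dependent in Mamba)

# Zero-order hold discretization:
#   A_delta = exp(A * dt)
#   B_delta = A^{-1} (exp(A * dt) - I) B
A_delta = torch.linalg.matrix_exp(A * dt)
B_delta = torch.linalg.solve(A, (A_delta - torch.eye(n)) @ B)

# The discrete recurrence h(t) = A_delta h(t-1) + B_delta x(t) can now be applied to token sequences.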

4.4 Parallel Processing

One of Mamba’s biggest strengths is its ability to process sequences in parallel, even though it is based on a recurrent model. This is achieved using a parallel scan algorithm.

Recurrent models compute the hidden state \( h_t \) sequentially, where each state depends on the previous one:

$$h_t = f(h_{t-1}, x_t),$$

This sequential dependency makes parallelization challenging. Mamba breaks the sequence into smaller chunks, processes them in parallel, and then combines the results. The key insight is that the state update operation often satisfies the associative property, which allows for efficient parallel computation. Mathematically, if the state update function \( f \) is associative, then:

$$f(f(h_1, x_2), x_3) = f(h_1, f(x_2, x_3)).$$

This property enables Mamba to compute partial states independently and combine them later.

Mamba leverages this associative property of the state update operation to enable parallel processing. Here’s how it works:

  1. Break the Sequence into Chunks — The input sequence \( X = [x_1, x_2, \dots, x_T] \) is divided into \( K \) smaller chunks \( C_1, C_2, \dots, C_K \), where each chunk \( C_i \) contains a contiguous subset of the sequence. For example, if \( T = 8 \) and \( K = 2 \), the chunks are \( C_1 = [x_1, x_2, x_3, x_4] \) and \( C_2 = [x_5, x_6, x_7, x_8] \).

  2. Compute Partial States in Parallel — Each chunk \( C_i \) is processed independently (and therefore in parallel) to compute its partial hidden state \( H_i \), using the same recurrent update rule applied within the chunk. For example, for chunk \( C_1 \):

    $$H_1 = f(f(f(f(h_0, x_1), x_2), x_3), x_4), $$

    where \( h_0 \) is the initial hidden state. Similarly, \( H_2 \) is computed for \( C_2 \).

  3. Combine the Results — The partial hidden states \( H_1, H_2, \dots, H_K \) are combined using a parallel scan operation. This step ensures that the final hidden state for each token takes into account all previous tokens in the sequence. The parallel scan operation leverages the associative property of \( f \) to efficiently merge the partial states. For example:

    $$ h_4 = H_1, \quad h_8 = f(H_1, H_2).$$

By leveraging the associative property and a parallel scan, Mamba achieves significant speedups over purely sequential processing, making it highly efficient for long sequences. A toy chunk-and-combine sketch for a linear recurrence is shown below.
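To make the chunk-and-combine idea concrete, here is a toy sketch for a scalar linear recurrence \( h_t = a_t h_{t-1} + b_t \), used here as a stand-in for the SSM update; the values and the two-chunk split are arbitrary assumptions. The per-chunk reductions are independent (and could run in parallel), and the associative combine operator merges them exactly.

import torch

torch.manual_seed(0)

T, K = 8, 2                      # sequence length, number of chunks
a = torch.rand(T) * 0.9          # per-step decay factors
b = torch.randn(T)               # per-step inputs
h0 = torch.zeros(())             # initial state

def combine(f1, f2):
    """Compose two affine updates h -> a*h + b; this composition is associative."""
    a1, b1 = f1
    a2, b2 = f2
    return (a1 * a2, a2 * b1 + b2)

# 1) Reference: plain sequential recurrence
h, h_seq = h0, []
for t in range(T):
    h = a[t] * h + b[t]
    h_seq.append(h)

# 2) Chunked version: reduce each chunk independently, then merge the summaries
chunk = T // K
chunk_ops = []
for c in range(K):                                # each chunk could run on a separate core
    op = (a[c * chunk], b[c * chunk])
    for t in range(c * chunk + 1, (c + 1) * chunk):
        op = combine(op, (a[t], b[t]))
    chunk_ops.append(op)

op_12 = combine(chunk_ops[0], chunk_ops[1])       # merge with the same associative operator
h4 = chunk_ops[0][0] * h0 + chunk_ops[0][1]       # state after chunk 1 (h_4 in the text)
h8 = op_12[0] * h0 + op_12[1]                     # state after both chunks (h_8 in the text)

print(torch.allclose(h4, h_seq[chunk - 1]), torch.allclose(h8, h_seq[-1]))  # True True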

Why Is This Fast?

By processing chunks in parallel, Mamba avoids the sequential bottleneck of traditional RNNs: the work within each chunk is independent, and only the short combine step needs information from other chunks. This makes it much faster, especially for long sequences.

5. Mamba Architecture

The Mamba Block is designed for efficient sequence modeling. It replaces the quadratic complexity of attention mechanisms with a linear-time approach, combining State Space Models (SSMs), local feature extraction, and input-dependent parameterization.

5.1 Core Components

The Mamba Block consists of the following components (mirroring the implementation below):

  1. Input projection: a linear layer that expands the model dimension and splits the result into a main branch and a gating (residual) branch.
  2. Local convolution: a short 1D convolution over the sequence for local feature extraction.
  3. Selective SSM: a state-space scan whose parameters \( \Delta \), \( B \), and \( C \) are computed from the input, making the recurrence content-aware.
  4. Gating and output projection: the scan output is modulated by the gating branch (via SiLU) and projected back to the model dimension.

5.2 Implementation

The Mamba Block combines input projection, convolutional filtering, and selective state-space computation to process sequences efficiently. The block handles long-range dependencies by dynamically focusing on relevant parts of the input. The key computations are shown in the implementation below:


import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: LocalConv (a short depthwise 1D convolution) and SelectiveScan are assumed to be
# defined elsewhere, as in the referenced repository.

class MambaBlock(nn.Module):
    """
    Adapted from:
    - https://github.com/johnma2006/mamba-minimal
    Mamba Block: Combines input projection, convolution, and selective state-space computation.
    """
    def __init__(self, d_model, d_state=16, expand=2, dt_rank='auto', d_conv=4, bias=False):
        super().__init__()
        d_inner = expand * d_model
        dt_rank = math.ceil(d_model / 16) if dt_rank == 'auto' else dt_rank
        
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias)
        self.conv = LocalConv(d_inner, d_conv)
        self.x_proj = nn.Linear(d_inner, dt_rank + d_state * 2, bias=False)
        self.dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1).repeat(d_inner, 1)))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=bias)
        self.scan = SelectiveScan()
        
    def forward(self, x):
        # Input projection and split into x and residual
        x, res = self.in_proj(x).chunk(2, dim=-1)
        
        # Apply 1D convolution
        x = self.conv(x)
        
        # Apply activation and selective scan
        x = F.silu(x)
        y = self.scan(x, * self.compute_params(x))
        
        # Combine with residual and project output
        return self.out_proj(y * F.silu(res))
    
    def compute_params(self, x):
        """
        Compute input-dependent parameters for selective scan.
        Args:
            x: Input tensor (batch, length, dim)
        Returns:
            delta: Time step scaling (batch, length, dim)
            A: State transition matrix (dim, state_dim)
            B: Input projection matrix (batch, length, state_dim)
            C: Output projection matrix (batch, length, state_dim)
            D: Skip connection (dim)
        """
        # Project input to delta, B, and C
        delta, B, C = self.x_proj(x).split([self.dt_proj.in_features, self.A_log.shape[1], self.A_log.shape[1]], dim=-1)
        
        # Compute delta and A
        delta = F.softplus(self.dt_proj(delta))  # Time step scaling
        A = -torch.exp(self.A_log)  # State transition matrix
        
        return delta, A, B, C, self.D
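
For completeness, here is a minimal sequential reference of what the selective scan computes, adapted conceptually from mamba-minimal. It uses the simplified discretization \( \overline{B} x \approx \Delta B x \) rather than the full ZOH formula, and a real SelectiveScan would use the parallel scan and fused hardware-aware kernels described earlier, so treat this purely as an illustrative sketch.

import torch

def selective_scan_reference(x, delta, A, B, C, D):
    """
    Sequential reference for the selective SSM update.

    Shapes (matching compute_params above):
        x, delta: (batch, length, d_inner)
        A:        (d_inner, d_state)
        B, C:     (batch, length, d_state)
        D:        (d_inner,)
    """
    batch, length, d_inner = x.shape
    h = torch.zeros(batch, d_inner, A.shape[1], device=x.device)  # hidden state per channel
    ys = []
    for t in range(length):
        A_bar = torch.exp(delta[:, t].unsqueeze(-1) * A)          # input-dependent decay
        Bx = delta[:, t].unsqueeze(-1) * B[:, t].unsqueeze(1) * x[:, t].unsqueeze(-1)
        h = A_bar * h + Bx                                        # selective state update
        ys.append((h * C[:, t].unsqueeze(1)).sum(dim=-1))         # project state to output
    return torch.stack(ys, dim=1) + x * D                         # skip connection via D

Wrapped in a small nn.Module, a function like this could stand in for SelectiveScan to exercise the MambaBlock above on a dummy (batch, length, d_model) tensor.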

6. Conclusion

Mamba is a powerful alternative to transformers, offering linear time complexity for sequence modeling. Its selective scanning mechanism dynamically focuses on relevant tokens, making it faster and more memory-efficient than transformers. This efficiency allows Mamba to handle long context lengths effectively, as its context cache (hidden states) does not grow with the sequence length, unlike transformers.

While Mamba shows great promise for scaling to very long sequences, achieving truly infinite context length remains a challenge. However, recent works like Falcon Mamba-7B have introduced interesting strategies to push these boundaries.

The full code is available in this repository.

7. References

  1. A Visual Guide to Mamba and State Space Models by Maarten Grootendorst
  2. Gu, Albert, and Tri Dao. "Mamba: Linear-time sequence modeling with selective state spaces." arXiv preprint arXiv:2312.00752 (2023).
  3. Code adapted from mamba-minimal and mamba-tiny, simplified implementations of MAMBA.
  4. MAMBA and State Space Models Explained, a video by AI Coffee Break with Letitia.
  5. Zuo, Jingwei, et al. "Falcon mamba: The first competitive attention-free 7b language model." arXiv preprint arXiv:2410.05355 (2024).