fla-org/flash-linear-attention
This repo aims to provide a collection of efficient Triton-based implementations of state-of-the-art linear attention models. All implementations are written purely in PyTorch and Triton, making them platform-agnostic; verified platforms currently include NVIDIA, AMD, and Intel. Pull requests are welcome!
News
- $\texttt{[2025-09]}$: 🐻 Thrilled to announce that GDN has been integrated into Qwen3-Next. Check out the PR and their blog post for more info!
- $\texttt{[2025-08]}$: 🌲 Add Log-Linear Attention implementation to fla (paper).
- $\texttt{[2025-08]}$: 🎓 Add MoM implementation to fla (paper).
- $\texttt{[2025-07]}$: 🐳 Add MLA implementation to fla (paper).
- $\texttt{[2025-07]}$: 🛣️ Add PaTH Attention implementation to fla (paper).
- $\texttt{[2025-06]}$: 🎉 Add MesaNet implementation to fla (paper).
- $\texttt{[2025-06]}$: 🐍 Add Comba implementation to fla (paper).
- $\texttt{[2025-05]}$: 🎉 Add Rodimus* implementation to fla (paper).
- $\texttt{[2025-04]}$: 🎉 Add DeltaProduct implementation to fla (paper).
- $\texttt{[2025-04]}$: 🎉 Add FoX implementation to fla (paper).
- $\texttt{[2025-03]}$: The default initializer_range was changed to the magic 🐳 0.006, then rolled back to the default value of 0.02. For actual training, we recommend trying both.
- $\texttt{[2025-02]}$: 🐳 Add NSA implementations to fla. See kernels here.
- $\texttt{[2025-01]}$: 🔥 We are migrating to a torchtitan-based training framework. Check out the flame repo for more details.
- $\texttt{[2025-01]}$: 🦅 Add RWKV7 implementations (both kernels and models) to fla.
- $\texttt{[2024-12]}$: Integrated flash-bidirectional-attention into fla-org (repo).
- $\texttt{[2024-12]}$: 🎉 Add Gated DeltaNet implementation to fla (paper).
- $\texttt{[2024-12]}$: 🚀 fla now officially supports kernels with variable-length inputs.
- $\texttt{[2024-11]}$: Inputs are now switched from head-first to seq-first format.
- $\texttt{[2024-11]}$: 💥 fla now provides a flexible way of training hybrid models.
- $\texttt{[2024-10]}$: 🔥 Announcing flame, a minimal and scalable framework for training fla models. Check out the details here.
- $\texttt{[2024-09]}$: fla now includes a fused linear and cross-entropy layer, significantly reducing memory usage during training.
- $\texttt{[2024-09]}$: 🎉 Add GSA implementation to fla (paper).
- $\texttt{[2024-05]}$: 🎉 Add DeltaNet implementation to fla (paper).
- $\texttt{[2024-05]}$: 💥 fla v0.1: a variety of subquadratic kernels/layers/models integrated (RetNet/GLA/Mamba/HGRN/HGRN2/RWKV6, etc.; see Models).
- $\texttt{[2023-12]}$: 💥 Launched fla, offering a collection of implementations for state-of-the-art linear attention models.
Models
Roughly sorted by the date they were added to fla. The recommended training mode is chunk when available.
Installation
The following requirements should be satisfied:
- PyTorch >= 2.5
- Triton >=3.0 (or nightly version, see FAQs)
- einops
- transformers >=4.45.0
- datasets >=3.3.0
Starting from v0.3.2, the packages published on PyPI are fla-core and flash-linear-attention. The former contains all our customized kernels and only depends on PyTorch, Triton, and einops. The latter is an extension package of the former, containing fla/layers and fla/models, and depends on transformers. We also provide Triton implementations for conv1d operations, so causal-conv1d is not required.
You can install fla with pip:
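For example, installing the full package (which pulls in fla-core automatically), or the kernels-only core package:

```shell
pip install flash-linear-attention
# or, if you only need the kernels without layers/models:
pip install fla-core
```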
As fla is under active development, an alternative is to install the package from source for the latest features and updates. Note that installing from git uses the default mode, so you need to uninstall both fla-core and flash-linear-attention first:
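A sketch of a from-source install, with the repository URL taken from this project's GitHub organization:

```shell
# Remove both released packages first, then install from git
pip uninstall -y fla-core flash-linear-attention
pip install -U git+https://github.com/fla-org/flash-linear-attention
```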
Or manage fla with git submodules:
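One way to wire it up as a submodule (the 3rdparty/ path is only a convention, not a requirement):

```shell
git submodule add https://github.com/fla-org/flash-linear-attention.git 3rdparty/flash-linear-attention
# expose the inner `fla` package on your import path
ln -sf 3rdparty/flash-linear-attention/fla fla
```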
If you have installed triton-nightly and a PyTorch pre-release, install fla without dependency resolution so that the nightly versions are not overridden:
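A hedged sketch: install the remaining dependencies yourself, then install fla with --no-deps so pip leaves the nightly Triton/PyTorch builds untouched (the exact dependency list may differ by release):

```shell
pip install einops datasets transformers
# --no-deps prevents pip from replacing triton-nightly / torch pre-releases
pip install -U --no-deps git+https://github.com/fla-org/flash-linear-attention
```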
Usage
Token Mixing
We provide "token mixing" linear attention layers in fla.layers for you to use.
You can replace the standard multihead attention layer in your model with any of these linear attention layers.
Example usage is as follows:
We provide the implementations of models that are compatible with 🤗 Transformers library.
Here’s an example of how to initialize a GLA model from the default configs in fla:
Fused Modules
We offer a collection of fused modules in fla.modules to facilitate faster training:
- Rotary Embedding: rotary positional embeddings as adopted by the Llama architecture, a.k.a. Transformer++.
- Norm Layers: RMSNorm, LayerNorm and GroupNorm, plus RMSNormLinear, LayerNormLinear and GroupNormLinear, which reduce memory usage of intermediate tensors for improved memory efficiency.
- Norm Layers with Gating: combine norm layers with element-wise sigmoid or swish gating, as used by RetNet/GLA.
- Cross Entropy: faster Triton implementation of cross entropy loss.
- Linear Cross Entropy: fused linear layer and cross entropy loss to avoid the materialization of large logits tensors. Also refer to implementations by mgmalek and Liger-Kernel.
- Linear KL Divergence: fused linear layer and KL divergence loss in a similar vein as the CE loss.
[!IMPORTANT] You can set fuse_linear_cross_entropy in the model configuration to enable/disable the fused linear cross entropy loss. The fused implementation is more memory-efficient but may reduce numerical precision. Due to this trade-off, it is disabled by default. If you enable this feature and encounter training instability (e.g., loss divergence), we recommend disabling it to see if the issue is resolved.
Generation
Once a model has been pretrained, you can generate text with the 🤗 text generation APIs. In the following, we give a generation example:
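A sketch of the flow (the checkpoint name is illustrative; any model from fla-hub should work the same way):

```python
import fla  # noqa: F401 -- importing fla registers its architectures with 🤗 Transformers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = 'fla-hub/gla-1.3B-100B'  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16).cuda()

input_ids = tokenizer("Hello everyone, I'm", return_tensors='pt').input_ids.cuda()
outputs = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```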
We also provide a simple script here for benchmarking the generation speed. Simply run it by:
All of the pretrained models currently available can be found in fla-hub.
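To list them programmatically (uses the huggingface_hub client and assumes network access):

```python
from huggingface_hub import list_models

# Enumerate all checkpoints published under the fla-hub organization
for model in list_models(author='fla-hub'):
    print(model.id)
```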
Hybrid Models
fla provides a flexible method to incorporate standard attention layers into existing linear attention models.
This is easily achieved by specifying the attn argument in the model configuration.
For example, to create a 2-layer Samba model with interleaved Mamba and local attention layers, using a sliding window size of 2048:
During inference, you DO NOT need to revise anything for generation! The model will produce output as-is, without any need for additional configurations or modifications.
Training
We provide a minimal framework called 🔥 flame built on top of torchtitan, for efficient training of fla models.
Check out the GLA example for more details.
Evaluation
The lm-evaluation-harness library allows you to easily perform (zero-shot) model evaluations. Follow the steps below to use this library:
- Install lm_eval following their instructions.
- Run evaluation with:
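An illustrative invocation (task names and flags follow lm-evaluation-harness conventions; adjust them to your setup):

```shell
MODEL='fla-hub/gla-1.3B-100B'
python -m evals.harness --model hf \
    --model_args pretrained=$MODEL,dtype=bfloat16 \
    --tasks wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge \
    --batch_size 64 \
    --num_fewshot 0 \
    --device cuda
```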
We have made fla compatible with hf-style evaluations; you can call evals.harness to run them.
Running the command above will provide the task results reported in the GLA paper.
- Multi-GPU Evaluation with Hugging Face accelerate 🚀
To perform data-parallel evaluation (where each GPU loads a separate full copy of the model), we leverage the accelerate launcher as follows:
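For example, launching the same harness module via accelerate (flags mirror the single-process run and are illustrative):

```shell
accelerate launch -m evals.harness --model hf \
    --model_args pretrained=fla-hub/gla-1.3B-100B,dtype=bfloat16 \
    --tasks wikitext,lambada_openai \
    --batch_size 32 \
    --num_fewshot 0
```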
- 📏 RULER Benchmark suite
The RULER benchmarks are commonly used for evaluating model performance on long-context tasks.
You can evaluate fla models on RULER directly using lm-evaluation-harness. RULER is only available in a relatively recent version of lm-evaluation-harness, so make sure you have the latest version installed.
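For example, installing the harness from its main branch:

```shell
pip install -U git+https://github.com/EleutherAI/lm-evaluation-harness
```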
Then, install the necessary dependencies for RULER:
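The harness ships RULER-specific dependencies as an optional extra (the extra name below is an assumption; check the harness documentation if it fails):

```shell
pip install "lm_eval[ruler]"
```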
and run evaluation by (e.g., 32k contexts):
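An illustrative 32k-context run (the task names and metadata flag follow the harness's RULER support; the exact values are assumptions to adapt):

```shell
accelerate launch -m evals.harness --model hf \
    --model_args pretrained=fla-hub/gla-1.3B-100B,dtype=bfloat16,max_length=32768 \
    --tasks niah_single_1,niah_single_2,niah_multikey_1 \
    --metadata '{"max_seq_lengths":[32768]}' \
    --batch_size 1 \
    --show_config
```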
If a GPU can’t load a full copy of the model, please refer to this link for FSDP settings.
[!Tip] If you are using lm-evaluation-harness as an external library and can't find (almost) any tasks available, simply run the following before calling lm_eval.evaluate() or lm_eval.simple_evaluate() to load the library's stock tasks:
Benchmarks
We compared our Triton-based RetNet implementation with the CUDA-based FlashAttention-2, using a batch size of 8, 32 heads, and a head dimension of 128 across different sequence lengths. The tests were conducted on a single H100 80GB GPU, as illustrated in the following graph.
Citation
If you find this repository helpful, please cite our work:
Star History
Acknowledgments
We extend our gratitude to Bitdeer for providing CI server resources that power our infrastructure.