Nous Research Proposes Lighthouse Attention: A Training-Only Selection-Based Hierarchical Attention That Delivers 1.4–1.7× Pretraining Speedup at Long Context
Training large language models on long sequences has a well-known problem: attention is expensive. The scaled dot-product attention (SDPA) at the core of every transformer scales as Θ(N²) in both compute and memory with sequence length N. FlashAttention addressed the memory side through IO-aware tiling that avoids materializing the full N×N attention matrix in high-bandwidth memory, but the underlying Θ(N²) compute scaling remains. Researchers at Nous Research have introduced Lighthouse Attention, a method that addresses this bottleneck specifically at pretraining time, achieving a 1.40× to 1.69× end-to-end wall-clock speedup against a cuDNN-backed SDPA baseline, with matching or lower final training loss.
The core problem with existing sparse attention methods
To understand why Lighthouse works the way it does, it helps to know what existing sparse attention methods do. Most prior methods (NSA, HISA, DSA, MoBA) make the same two design decisions. First, they pool only the key and value side while leaving queries at full resolution (asymmetric compression). Second, their selection logic lives inside a custom attention kernel, which means teams cannot reuse the optimized dense-attention kernels that modern GPU tensor cores are built around.
There is also a concern specific to training that inference-only sparse methods don’t face. An inference-time sparse method is evaluated only against its dense backbone, so at best it matches that backbone. A training-time sparse method faces a harder test: once training is done, will the resulting weights still produce a competent dense-attention model at inference? Lighthouse treats that question as its central correctness criterion.
Lighthouse takes a different approach on both design decisions. It pools queries, keys, and values symmetrically across a multi-level pyramid, and it places selection entirely outside the attention kernel. After selection, the system gathers the chosen entries into a contiguous, dense sub-sequence and runs stock FlashAttention on it — the same kernel used by the dense baseline.

How the four-stage pipeline works
A Lighthouse attention layer wraps around, but does not modify, scaled dot-product attention. The pipeline has four stages.
In the first stage, average pooling constructs an L-level pyramid from Q, K, and V. With pooling factor p, level ℓ of the pyramid has N/p^ℓ tokens, each summarizing p^ℓ base positions. Crucially, the same pooling applies to all three projections, producing coherent (Q^(ℓ), K^(ℓ), V^(ℓ)) triples at every level. Total pyramid construction costs Θ(N) time and memory.
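In PyTorch-style pseudocode, the pyramid construction might look like the minimal sketch below. The tensor layout and the avg_pool_tokens helper are illustrative assumptions, not the released implementation:

```python
import torch

def avg_pool_tokens(x, p):
    # x: [batch, heads, tokens, head_dim] -> [batch, heads, tokens // p, head_dim]
    # (assumes the token count is divisible by p)
    b, h, n, d = x.shape
    return x.reshape(b, h, n // p, p, d).mean(dim=3)

def build_pyramid(q, k, v, num_levels=3, p=4):
    """Symmetric average pooling of Q, K, and V into an L-level pyramid.

    Level 0 is the base resolution; level l has N / p**l entries, each
    summarizing p**l consecutive base positions. Total work is Theta(N)
    because every level is a factor of p smaller than the one below it.
    """
    pyramid = [(q, k, v)]
    for _ in range(1, num_levels):
        q, k, v = (avg_pool_tokens(t, p) for t in (q, k, v))
        pyramid.append((q, k, v))
    return pyramid
```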
In the second stage, a parameter-free scorer assigns each pyramid entry two scalar scores using per-head ℓ₂ norms: a query score (∥Q^(ℓ)_i∥₂) and a key score (∥K^(ℓ)_i∥₂). Coarser levels inherit scores from finer ones via max-pooling, so a coarse span picks up the importance of its strongest token. A fused chunked-bitonic top-K kernel then selects k entries jointly across all pyramid levels. One design detail worth noting: the coarsest pyramid level is always retained in full — it is cheap and guarantees at least one contributor at every base position; the remaining selection budget is spent on finer levels. Additionally, the chunked-bitonic design produces a stratified top-K rather than a strict global top-K: the score stream is partitioned into fixed-size chunks, each maintaining an in-register top-m buffer, so if the k globally highest-scoring entries were clustered in one chunk, some would be displaced by lower-scoring entries from other chunks. The result is more balanced attention coverage across the sequence, which avoids selection collapse onto a narrow span.
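A simplified stand-in for the scorer and the stratified selection is sketched below. The real kernel is a fused chunked-bitonic sort; here a per-chunk torch.topk reproduces the stratified behaviour, and folding the query and key scores into a single max is a simplification of the two-score scheme described above. All names are illustrative:

```python
import torch

def score_level(q_lvl, k_lvl):
    # Parameter-free scores: per-head L2 norms of the pooled projections. The
    # paper keeps separate query and key scores; taking their max here is a
    # simplification for this sketch.
    return torch.maximum(q_lvl.norm(dim=-1), k_lvl.norm(dim=-1))

def stratified_topk(scores, k, chunk=1024):
    """Stand-in for the fused chunked-bitonic top-K: split the flattened score
    stream into fixed-size chunks and draw an equal share of the budget from
    each, so selections stay spread across the sequence instead of collapsing
    onto one high-scoring span. Assumes scores.numel() is a multiple of chunk."""
    num_chunks = scores.numel() // chunk
    per_chunk = max(1, k // num_chunks)
    local = scores.view(num_chunks, chunk).topk(per_chunk, dim=-1).indices
    offsets = torch.arange(num_chunks, device=scores.device).unsqueeze(1) * chunk
    return (local + offsets).flatten()  # global indices into the score stream
```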
The top-K step is discrete and non-differentiable — no straight-through estimator, no Gumbel softmax. Selection indices carry no gradient. Gradients flow only through the gathered Q, K, V entries into WQ, WK, WV, so the projections learn to produce values that are useful when selected rather than scores that are good at selecting.
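This gradient structure is easy to verify on a toy tensor; the example below is ours, not from the paper:

```python
import torch

# The top-K indices are integers and carry no grad_fn, but the gathered rows
# do, so gradients reach the projection weights only through entries that were
# actually selected.
k_rows = torch.randn(16, 8, requires_grad=True)   # stand-in for 16 pooled K entries
scores = k_rows.norm(dim=-1)                      # parameter-free key scores
idx = scores.topk(4).indices                      # discrete selection, no gradient
gathered = k_rows[idx]                            # differentiable gather
gathered.sum().backward()
print(k_rows.grad.abs().sum(dim=-1))              # non-zero only at the 4 selected rows
```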
In the third stage, the selected entries are gathered into a contiguous sub-sequence of length S = N/p^(L−1) + (L−1)·p·k and passed to standard FlashAttention. At N = 1,000,000 with L = 4, p = 4, k = 4,096, S ≈ 65,000 — far smaller than N. A critical property of the gathering process is that it guarantees no “holes” or empty spaces in the assembled sub-sequence. This matters specifically because Lighthouse also compresses queries: a gap in the sequence would mean those missing tokens have no gradient path during the backward pass and could cause training instabilities. Asymmetric methods that leave queries at full resolution don’t face this problem, but Lighthouse’s symmetric design requires that the gathered sub-sequence remains fully dense.
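A sketch of this gather-then-attend step is below, again with illustrative names; the positional bookkeeping needed for an exact causal mask over the mixed-level sub-sequence is omitted:

```python
import torch
import torch.nn.functional as F

def gather_and_attend(pyramid, selected):
    """Gather the selected (Q, K, V) entries from every pyramid level into one
    contiguous, dense sub-sequence and run a stock fused attention kernel on it.
    selected[lvl] holds the chosen entry indices at that level, shaped
    [batch, heads, k_lvl]; the coarsest level is taken in full."""
    qs, ks, vs = [], [], []
    for (q, k, v), idx in zip(pyramid, selected):
        idx = idx.unsqueeze(-1).expand(-1, -1, -1, q.shape[-1])
        qs.append(torch.gather(q, 2, idx))
        ks.append(torch.gather(k, 2, idx))
        vs.append(torch.gather(v, 2, idx))
    q_sub, k_sub, v_sub = (torch.cat(t, dim=2) for t in (qs, ks, vs))  # length S << N
    # Same dense kernel the baseline uses (FlashAttention / cuDNN SDPA underneath).
    return F.scaled_dot_product_attention(q_sub, k_sub, v_sub, is_causal=True)
```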
In the fourth stage, each output entry is scattered back to the p^ℓ base positions it represents via a deterministic integer-atomic scatter kernel, with a shift of p^ℓ − 1 to preserve causality. The per-position fan-in is bounded by L regardless of k.
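A rough sketch of the scatter-back is below. The exact index arithmetic for the causal shift is our reading of the description (shift each level-ℓ entry's write window forward by p^ℓ − 1 so it begins at the last base position that entry summarizes), and scatter_add_ stands in for the deterministic integer-atomic kernel:

```python
import torch

def scatter_back(outputs, selected, seq_len, p=4):
    """outputs: the attention output split back into per-level chunks, each
    [batch, heads, k_lvl, head_dim]; selected: the matching per-level indices.
    Each level-l entry is written to p**l base positions, shifted by p**l - 1,
    so per-position fan-in is bounded by the number of levels L."""
    b, h, _, d = outputs[0].shape
    out = torch.zeros(b, h, seq_len, d, dtype=outputs[0].dtype, device=outputs[0].device)
    for lvl, (o_lvl, idx) in enumerate(zip(outputs, selected)):
        span = p ** lvl                              # base positions per entry
        for offset in range(span):
            dest = idx * span + (span - 1) + offset  # causal shift of span - 1
            dest = dest.clamp(max=seq_len - 1)
            out.scatter_add_(2, dest.unsqueeze(-1).expand_as(o_lvl), o_lvl)
    return out
```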

Why symmetric pooling changes the compute
Pooling queries alongside keys and values changes the computational character of the attention call: where asymmetric KV-only compression still costs O(N S d) because every one of the N queries attends to the compressed keys, Lighthouse's symmetric pooling brings the training-time call down to O(S² d). Because S ≪ N at long contexts, this is what produces the latency advantage. Benchmarked on a single NVIDIA B200 at 512K context (bfloat16, B=1, H=8, head dimension 128, L=3, p=4, sparsity ≈ 1:64), Lighthouse is 21× faster on the forward pass and 17.3× faster on the combined forward+backward pass relative to cuDNN-backed SDPA.
From an asymptotic standpoint, setting L = log_p(N/k) gives a gathered sub-sequence size of S = Θ(k log N), which makes the dense FlashAttention call cost Θ(k² log² N · d) — polylogarithmic in N at fixed k. Combined with the linear-cost stages (pyramid construction, scoring, scatter-back), total per-layer compute is Θ(N d) at bounded k — the same asymptotic class as linear attention and SSMs — while preserving softmax attention’s recall properties on the selected sub-sequence.
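A quick back-of-the-envelope check of these numbers, using the 1M-token example configuration from above (attention-matmul FLOPs only, so the ratio is far larger than the measured kernel and end-to-end speedups, which also pay for the linear-cost stages):

```python
# Gathered sub-sequence length: S = N / p**(L - 1) + (L - 1) * p * k
N, L, p, k, d = 1_000_000, 4, 4, 4096, 128
S = N // p ** (L - 1) + (L - 1) * p * k
print(S)                               # 64_777, i.e. roughly 65K versus N = 1M

dense_flops = 2 * N * N * d            # Theta(N^2 d): full SDPA score + value matmuls
lighthouse_flops = 2 * S * S * d       # Theta(S^2 d): same matmuls on the sub-sequence
print(dense_flops / lighthouse_flops)  # ~238x fewer attention FLOPs at this setting
```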
Inference is a different constraint. Autoregressive decoding presents one query at a time, which violates the assumption that all queries co-occur in one forward pass. Lighthouse is a training-only method, and the symmetric pooling design cannot be used directly at inference.
The two-stage training recipe and recoverability
The experimental setup used a 530M-parameter Llama-3-style decoder (dmodel=1024, 30 layers, 8 heads, head dimension 128, FFN width 1536, byte-level tokenizer), trained on C4 at 98,304-token context with AdamW at learning rate 2×10⁻³, β1=0.9, β2=0.95, weight decay 0.1, linear warmup over 2k steps, gradient-norm clip 1, bfloat16, and FSDP. One implementation detail that matters for practitioners: of the 30 layers, layers {0, 1, 28, 29} retain dense SDPA throughout — only the other 26 layers use Lighthouse. The inner attention call within those 26 Lighthouse layers uses the same cuDNN-backed SDPA kernel as the dense baseline.
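For reference, the reported hyperparameters collected in one place (field names are ours and do not mirror any released configuration file):

```python
# 530M-parameter recoverability setup, as reported; illustrative field names.
config = dict(
    d_model=1024, n_layers=30, n_heads=8, head_dim=128, ffn_width=1536,
    tokenizer="byte-level", dataset="C4", context_len=98_304,
    optimizer=dict(name="AdamW", lr=2e-3, betas=(0.9, 0.95), weight_decay=0.1),
    warmup_steps=2_000, grad_clip=1.0, dtype="bfloat16", parallelism="FSDP",
    dense_sdpa_layers=[0, 1, 28, 29],   # always dense; the other 26 use Lighthouse
    lighthouse_layers=[i for i in range(30) if i not in (0, 1, 28, 29)],
)
```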
The training approach is two-stage. Stage 1 trains with Lighthouse selection enabled for the majority of the step budget. Stage 2 resumes the Stage 1 checkpoint under dense SDPA (same optimizer state, same dataloader) for a short tail. If Stage 1 had hollowed out the model’s dense-attention capability, Stage 2 recovery would fail.
It doesn’t fail. Testing at a total budget of 16,000 steps (~50.3B tokens), three split points (10k+6k, 11k+5k, 12k+4k) were evaluated against a dense-from-scratch SDPA baseline. At each resume point the training loss spikes transiently by 1.12–1.57 nats as the model is first run through attention it was not trained against, then recovers within approximately 1,000–1,500 SDPA steps and crosses below the dense baseline. By step 16,000, all three resumed Lighthouse runs reach final losses of 0.6980–0.7102, against the dense baseline’s 0.7237, while spending 22.5h to 27.0h wall-clock compared to 37.9h for dense-SDPA-from-scratch on the same token budget.
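In sketch form, the whole two-stage schedule amounts to something like the loop below, assuming a hypothetical per-layer toggle for Lighthouse selection (the 12k + 4k split is one of the three tested resume points):

```python
import torch

def train_two_stage(model, optimizer, dataloader, total_steps=16_000, stage1_steps=12_000):
    for step, batch in zip(range(total_steps), dataloader):
        # Stage 1: Lighthouse selection on. Stage 2: plain dense SDPA.
        model.set_lighthouse_enabled(step < stage1_steps)   # hypothetical toggle
        loss = model(batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
        # Optimizer state and dataloader position carry straight across the
        # switch: Stage 2 is a resume of the Stage 1 run, not a fresh start.
```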
Ablations and throughput
The full ablation grid covers scorer type, pooling factor p, number of pyramid levels L, and top-K budget k. Key findings: the projection-norm scorer is within ~0.01 of the dilated softmax-attention scorer in either direction (no uniform winner) but is roughly 9% cheaper in B200-hours, since it skips the attention pass over the pyramid entirely. Shallower pyramids (L=3) consistently outperform deeper ones (L=4, L=5) at matched budgets. Smaller k values produce lower post-resume loss within the tested range — the lowest-loss configuration across the grid is L=3, p=2, k=1536 with the dilated scorer, reaching a final loss of 0.6825 — a counter-intuitive result the research team attributes to hierarchical selection acting as a regularizer at this token-budget scale.
Stage-1 throughput across the ablation grid ranges from 84,000 to 126,000 tokens/s/GPU against approximately 46,000 for dense SDPA. The projection-norm scorer at L=3, p=4, k=1536 tops the range at 126,000 tokens/s/GPU by skipping the dilated-attention pass entirely.
Long-context retrieval
To complement the loss-based recoverability results, the research team ran a simplified Needle-in-a-Haystack (NIAH) evaluation: a single passkey digit hidden in random alphanumeric filler at depths of 0–100% across context lengths of 4K to 96K tokens, with retrieval scored as a one-token argmax over the ten digit tokens (random chance: 10%). Four Lighthouse configurations (varying k ∈ {1536, 2048} and scorer ∈ {dilated, norm} at L=3, p=4) were tested against the dense-SDPA-from-scratch baseline. Three of four Lighthouse runs match or beat the dense baseline’s mean retrieval rate of 0.72: k=2048 dilated reaches 0.76, k=1536 dilated reaches 0.73, and k=2048 norm matches the baseline at 0.72. Only k=1536 norm dips, to 0.65. A pattern emerges across the grid: larger k is the dominant axis for retrieval performance, and the norm scorer hurts retrieval more than it hurts training loss at the same k. The practical implication is that the optimal configuration depends on whether the downstream task is loss-driven or retrieval-driven.
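The scoring rule itself is simple; a per-example check, with hypothetical helper names, looks roughly like this:

```python
import torch

def passkey_retrieved(logits, digit_token_ids, true_digit):
    """logits: [vocab_size] at the answer position; digit_token_ids: the ten
    token ids for '0'..'9'; true_digit: the hidden passkey digit (0-9).
    Retrieval is a one-token argmax restricted to the digit tokens, so random
    chance is 10%."""
    digit_logits = logits[digit_token_ids]   # restrict the argmax to the ten digits
    return int(digit_logits.argmax().item()) == true_digit
```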
Context parallelism scaling
For sequences beyond ~100K tokens, Lighthouse runs under context parallelism (CP). Pyramid pooling, scoring, and top-K run shard-locally on each rank with no inter-rank communication, since the coarsest pool window (e.g., 64 tokens) is orders of magnitude smaller than the shard size. The gathered sub-sequence is dense, so it participates in standard ring attention without sparse-aware collectives — something sparse-index-based methods cannot do without engineering specific to the sparse layout. Context parallelism introduces approximately 10% per-rank throughput overhead from ring rotation, but the Lighthouse vs. SDPA speedup ratio is preserved. The method scales to 1M-token training across 32 Blackwell GPUs (4 nodes, CP degree 8) with no changes to the inner attention kernel.
[Visual explainer: Lighthouse Attention, Nous Research, arXiv:2605.06554]
Key Takeaways
- Nous Research's Lighthouse Attention pools Q, K, and V symmetrically across a multi-level pyramid — unlike NSA and HISA which only pool K and V — cutting the attention call from O(N S d) to O(S² d) and making the expensive step stock FlashAttention on a small dense sub-sequence.
- It's a training-only method: a brief dense-SDPA resumption at the end converts the checkpoint into a normal full-attention model that matches or beats dense-from-scratch at the same token budget (final loss 0.6980–0.7102 vs. 0.7237 baseline, 16k steps, ~50.3B tokens).
- At 512K context on a single B200, Lighthouse is 21× faster on the forward pass and 17.3× faster on forward+backward vs. cuDNN SDPA — translating to a 1.40×–1.69× end-to-end pretraining wall-clock speedup.
- The top-K selection step is deliberately non-differentiable — no straight-through estimator, no Gumbel softmax — so projection matrices learn to produce values that are useful when selected, not to game a learnable scorer.
- Scales to 1M-token training across 32 Blackwell GPUs (4 nodes, CP degree 8) under context parallelism with no changes to the inner attention kernel, because the gathered sub-sequence is dense and participates in standard ring attention.
Check out the paper and the GitHub repo for full technical details.