
Add beam search, diverse beam search, and beam sampling to generate()#1539

Open
justsml wants to merge 3 commits into huggingface:main from justsml:dan/add-beam-support

Conversation


@justsml justsml commented Feb 20, 2026

Summary

Implements beam search decoding in generate(), resolving the long-standing TODO: Support beam search at the core of the generation loop. This covers all three beam-based decoding strategies:

  1. Beam search — deterministic decoding that keeps the num_beams highest-scoring hypotheses at each step (num_beams > 1)
  2. Diverse beam search — group-based decoding with a diversity penalty to encourage distinct hypotheses (num_beam_groups > 1)
  3. Beam sampling — stochastic beam decoding that samples candidates rather than taking argmax (do_sample = true, num_beams > 1)

Both encoder-decoder (e.g., T5, BART, Whisper) and decoder-only (e.g., LLaMA, GPT-2) architectures are supported.

Motivation

Beam search is the default decoding strategy for many seq2seq tasks (translation, summarization, ASR) and is required by models whose configs ship with num_beams > 1. Until now, the generation loop silently ignored all but the first sampled token (break on line 1162), making num_beams a no-op. This PR makes the existing GenerationConfig parameters (num_beams, num_beam_groups, diversity_penalty, length_penalty, early_stopping, num_return_sequences, do_sample) functional.

Design

Rather than adding separate decoding methods (which would duplicate ~80 lines of setup logic), all three strategies are integrated inline into the existing generation loop, consistent with the original architecture comment: "Generic search which handles 4 generation modes."

Architecture

generate()
├─ [pre-loop]  if num_beams > 1:
│   ├─ validate config
│   ├─ create BeamSearchScorer (one per group for diverse beam search)
│   ├─ expand inputs × num_beams (interleaved repeat)
│   └─ init beam_scores (one live beam per group, rest = -1e9)
├─ [loop]
│   ├─ forward pass (batch_size × num_beams sequences in parallel)
│   ├─ logits processing (repetition penalty, forced tokens, etc.)
│   ├─ if group_beam_search:
│   │   └─ for each group sequentially:
│   │       ├─ extract group's beams/logits
│   │       ├─ apply diversity penalty from earlier groups' token selections
│   │       ├─ score candidates, select top 2×group_size
│   │       ├─ group's BeamSearchScorer.process()
│   │       └─ record selected tokens for next group's penalty
│   ├─ elif beam_sample (do_sample=true):
│   │   ├─ top-k filter → softmax → stochastic sampling (num_beams samples per beam)
│   │   ├─ score = beam_score + log(sampled_prob)
│   │   └─ BeamSearchScorer.process()
│   ├─ else (standard beam search):
│   │   ├─ log_softmax(processed_logits) + beam_scores → candidate scores
│   │   ├─ top 2×num_beams candidates per batch (across all beams)
│   │   └─ BeamSearchScorer.process()
│   ├─ reorder all_input_ids by beam parent indices
│   ├─ _reorder_cache() → reindex KV cache along batch dim
│   └─ check stopping: all scorers done || max_length || EOS
└─ [post-loop]
    ├─ if group_beam_search: merge hypotheses across groups, rank by score
    └─ select top num_return_sequences per input

Key decisions

Bypass BeamSearchSampler, score directly. The existing BeamSearchSampler returns only top num_beams tokens per beam. Proper beam search requires considering all num_beams × vocab_size candidates per batch item to find the globally best continuations. We compute log_softmax + cumulative beam scores directly, then take top candidates with a two-level sort.
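The scoring step above can be sketched with plain arrays (the real implementation works on project tensors; function names here are illustrative, not the PR's actual helpers). For one batch item, each beam's cumulative score is added to the log-softmax of its next-token logits, the num_beams × vocab_size candidates are flattened, and the top 2 × num_beams survive:

```javascript
// Numerically stable log-softmax over a plain array of logits.
function logSoftmax(logits) {
  const max = Math.max(...logits);
  const logSumExp =
    Math.log(logits.reduce((s, x) => s + Math.exp(x - max), 0)) + max;
  return logits.map((x) => x - logSumExp);
}

// beamScores: cumulative log-prob per live beam.
// beamLogits: one logits array per beam. Returns the top 2*numBeams
// candidates across ALL beams, each tagged with its parent beam and token.
function topBeamCandidates(beamScores, beamLogits, numBeams) {
  const candidates = [];
  beamLogits.forEach((logits, beam) => {
    logSoftmax(logits).forEach((logProb, token) => {
      candidates.push({ score: beamScores[beam] + logProb, beam, token });
    });
  });
  candidates.sort((a, b) => b.score - a.score);
  return candidates.slice(0, 2 * numBeams);
}
```

Keeping 2 × num_beams candidates (rather than num_beams) matches the convention that up to half may terminate with EOS and be routed to the hypothesis pool.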

Group-sequential diversity penalty. For diverse beam search, groups are processed sequentially within each step. A token-count map accumulates tokens selected by earlier groups, and later groups subtract diversity_penalty × count from those token scores. This ensures each group explores distinct regions of the output space.
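A minimal sketch of that penalty (names are made up for illustration): earlier groups' selections this step are tallied into a token-count map, and the current group's scores drop by diversity_penalty per prior occurrence of each token:

```javascript
// scores: per-token scores for the current group's candidates.
// previousTokenCounts: token id -> times earlier groups picked it this step.
// Returns scores with diversityPenalty * count subtracted per token.
function applyDiversityPenalty(scores, previousTokenCounts, diversityPenalty) {
  return scores.map(
    (s, token) => s - diversityPenalty * (previousTokenCounts[token] ?? 0)
  );
}
```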

Beam sampling via direct top-k + softmax. Rather than routing through the MultinomialSampler, beam sampling applies top-k filtering inline, computes softmax, and uses sampler.randomSelect() to draw num_beams stochastic samples per beam. The combined score is beam_score + log(prob), maintaining proper cumulative log-probability tracking.
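The score update can be sketched as follows, with the sampler abstracted as a callback (all names illustrative; the real code uses sampler.randomSelect() on project tensors). Note that ties at the top-k threshold may keep slightly more than k tokens in this simplified version:

```javascript
// For one beam: top-k filter the logits, softmax, draw numSamples tokens,
// and score each as cumulative beamScore + log(sampled probability).
// randomSelect(probs) -> token index; injected so tests can be deterministic.
function sampleBeamCandidates(beamScore, logits, k, numSamples, randomSelect) {
  // Mask everything below the k-th largest logit to -Infinity.
  const threshold = [...logits].sort((a, b) => b - a)[k - 1];
  const filtered = logits.map((x) => (x >= threshold ? x : -Infinity));
  // Softmax over the surviving logits (exp(-Infinity) is 0).
  const max = Math.max(...filtered);
  const exps = filtered.map((x) => Math.exp(x - max));
  const sum = exps.reduce((s, x) => s + x, 0);
  const probs = exps.map((x) => x / sum);
  // Each sample's candidate score maintains the cumulative log-probability.
  return Array.from({ length: numSamples }, () => {
    const token = randomSelect(probs);
    return { token, score: beamScore + Math.log(probs[token]) };
  });
}
```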

KV cache reordering via index_select. When beams are reordered, past key-value tensors are reindexed along the batch dimension. Added index_select (sync, CPU) and index_select_async (GPU download fallback) to tensor.js. _reorder_cache() passes encoder PKVs through unchanged and disposes stale decoder PKV tensors.

What's supported

| Feature | Status |
| --- | --- |
| Standard beam search (num_beams > 1) | Implemented |
| Diverse beam search (num_beam_groups > 1, diversity_penalty) | Implemented |
| Beam sampling (do_sample = true, num_beams > 1) | Implemented |
| length_penalty (exponential length normalization) | Implemented |
| early_stopping (true / false / "never") | Implemented |
| num_return_sequences (up to num_beams) | Implemented |
| Encoder-decoder models | Implemented |
| Decoder-only models | Implemented |
| Diverse beam sampling (num_beam_groups > 1 + do_sample) | Throws (not yet supported) |
| Classifier-free guidance + beam search | Throws (incompatible) |
| Streaming + beam search | Throws (incompatible) |

Changes

New: src/generation/beam_search.js

  • BeamHypotheses — per-batch-element bounded priority queue of completed hypotheses, scored as sum_logprobs / (length ^ length_penalty). Implements three early_stopping modes for is_done().
  • BeamSearchScorer — orchestrates beam state across the batch. process() routes EOS-terminated beams to hypotheses and selects continuing beams. finalize() / finalize_with_scores() returns top hypotheses per batch item. The _with_scores variant exposes normalized scores for cross-group ranking in diverse beam search.
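The hypothesis score named above is simple enough to state directly. A sketch of the normalization (assuming the formula as described; not the file's literal code):

```javascript
// Length-normalized hypothesis score: sum of token log-probs divided by
// length^lengthPenalty. lengthPenalty > 1 favors longer hypotheses,
// < 1 favors shorter ones, and 0 disables normalization entirely.
function hypothesisScore(sumLogprobs, length, lengthPenalty) {
  return sumLogprobs / Math.pow(length, lengthPenalty);
}
```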

New: index_select / index_select_async in src/utils/tensor.js

Row-level gather along dimension 0. The sync variant operates on CPU typed arrays; the async variant calls ort_tensor.getData(true) for GPU/ML-tensor locations.
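In spirit, the sync gather looks like the following sketch on a flat typed array (the real tensor.js variant must also handle dims, dtypes, and the async GPU path):

```javascript
// Gather rows along dimension 0: data is a flat [numRows, rowLength] buffer,
// indices lists source rows in output order (duplicates allowed, which is
// exactly what beam reordering needs when one beam forks into several).
function indexSelectRows(data, rowLength, indices) {
  const out = new Float32Array(indices.length * rowLength);
  indices.forEach((src, dst) => {
    out.set(data.subarray(src * rowLength, (src + 1) * rowLength), dst * rowLength);
  });
  return out;
}
```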

Modified: src/models/modeling_utils.js

  • Validation: num_beam_groups divisibility, num_return_sequences <= num_beams, incompatible combos (CFG, streaming, diverse + sampling).
  • Pre-loop: input expansion, scorer creation (one BeamSearchScorer per group for diverse mode), beam score initialization (one live beam per group).
  • In-loop: three branches — group beam search (sequential group processing with diversity penalty), beam sampling (top-k + stochastic sampling), standard beam search (deterministic top-k). All share the same KV cache reordering and stopping logic.
  • Post-loop: group beam search merges hypotheses across groups via finalize_with_scores() and re-ranks globally; standard/sampling modes use finalize().
  • _reorder_cache(): async method that reindexes decoder PKVs, passes encoder PKVs through, disposes stale GPU tensors.

Modified: src/transformers.js

Exports BeamHypotheses and BeamSearchScorer.

Test plan

  • Standard beam search — encoder-decoder (T5) and decoder-only (LLaMA)
  • Diverse beam search (num_beam_groups=2, diversity_penalty=0.5) — both architectures
  • Beam sampling (do_sample=true, top_k=10) — both architectures
  • num_return_sequences > 1 output shape [N, seq_len] — both architectures
  • Greedy vs beam search produce valid outputs
  • CFG + beam search throws
  • num_return_sequences > num_beams throws
  • All pre-existing generation tests pass (no regressions)

@justsml justsml changed the title Add beam search support with BeamHypotheses and BeamSearchScorer classes Add beam search decoding to generate() Feb 20, 2026
@justsml justsml marked this pull request as ready for review February 20, 2026 22:11
@justsml justsml changed the title Add beam search decoding to generate() Add beam search, diverse beam search, and beam sampling to generate() Feb 20, 2026