Add beam search, diverse beam search, and beam sampling to generate() #1539
Open
justsml wants to merge 3 commits into huggingface:main
Summary
Implements beam search decoding in `generate()`, resolving the long-standing `TODO: Support beam search` at the core of the generation loop. This covers all three beam-based decoding strategies:

- Beam search (`num_beams > 1`)
- Diverse beam search (`num_beam_groups > 1`)
- Beam sampling (`do_sample = true`, `num_beams > 1`)

Both encoder-decoder (e.g., T5, BART, Whisper) and decoder-only (e.g., LLaMA, GPT-2) architectures are supported.
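The dispatch between these strategies is driven entirely by existing `GenerationConfig` fields. A minimal sketch of how the mode could be derived (illustrative only; the function name and exact logic are hypothetical, not the PR's code):

```javascript
// Sketch: deriving the decoding strategy from GenerationConfig fields.
// (Illustrative only; not the PR's actual dispatch code.)
function getGenerationMode({ num_beams = 1, num_beam_groups = 1, do_sample = false }) {
  if (num_beams > 1) {
    if (num_beam_groups > 1) {
      // Diverse beam search is incompatible with sampling (rejected by validation).
      if (do_sample) throw new Error('num_beam_groups > 1 cannot be combined with do_sample');
      return 'group_beam_search';
    }
    return do_sample ? 'beam_sample' : 'beam_search';
  }
  return do_sample ? 'sample' : 'greedy';
}
```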
Motivation
Beam search is the default decoding strategy for many seq2seq tasks (translation, summarization, ASR) and is required by models whose configs ship with `num_beams > 1`. Until now, the generation loop silently ignored all but the first sampled token (`break` on line 1162), making `num_beams` a no-op. This PR makes the existing `GenerationConfig` parameters (`num_beams`, `num_beam_groups`, `diversity_penalty`, `length_penalty`, `early_stopping`, `num_return_sequences`, `do_sample`) functional.

Design
Rather than adding separate decoding methods (which would duplicate ~80 lines of setup logic), all three strategies are integrated inline into the existing generation loop, consistent with the original architecture comment: "Generic search which handles 4 generation modes."
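The shared beam step described under Key decisions below (log-softmax plus cumulative beam scores over all `num_beams × vocab_size` candidates) can be sketched as follows. This is a simplified illustration on plain arrays, not the PR's tensor-based implementation; function names are hypothetical:

```javascript
// Numerically stable log-softmax over one beam's next-token logits.
function logSoftmax(logits) {
  const max = Math.max(...logits);
  const sumExp = logits.reduce((s, x) => s + Math.exp(x - max), 0);
  const logSum = max + Math.log(sumExp);
  return logits.map(x => x - logSum);
}

// For one batch item: score every (beam, token) continuation as the beam's
// cumulative log-probability plus the token's log-probability, then keep the
// globally best candidates across all num_beams × vocab_size options.
// beamScores: number[num_beams]; logitsPerBeam: number[num_beams][vocab_size]
function topCandidates(beamScores, logitsPerBeam, numBeams) {
  const candidates = [];
  logitsPerBeam.forEach((logits, beamIdx) => {
    logSoftmax(logits).forEach((logProb, tokenId) => {
      candidates.push({ beamIdx, tokenId, score: beamScores[beamIdx] + logProb });
    });
  });
  // Keep 2 × num_beams so EOS-terminated candidates can be routed to the
  // hypothesis pool while still leaving num_beams live continuations.
  return candidates.sort((a, b) => b.score - a.score).slice(0, 2 * numBeams);
}
```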
Architecture
Key decisions
- **Bypass `BeamSearchSampler`, score directly.** The existing `BeamSearchSampler` returns only the top `num_beams` tokens per beam. Proper beam search requires considering all `num_beams × vocab_size` candidates per batch item to find the globally best continuations. We compute `log_softmax` + cumulative beam scores directly, then take the top candidates with a two-level sort.
- **Group-sequential diversity penalty.** For diverse beam search, groups are processed sequentially within each step. A token-count map accumulates tokens selected by earlier groups, and later groups subtract `diversity_penalty × count` from those token scores. This ensures each group explores distinct regions of the output space.
- **Beam sampling via direct top-k + softmax.** Rather than routing through the `MultinomialSampler`, beam sampling applies top-k filtering inline, computes softmax, and uses `sampler.randomSelect()` to draw `num_beams` stochastic samples per beam. The combined score is `beam_score + log(prob)`, maintaining proper cumulative log-probability tracking.
- **KV cache reordering via `index_select`.** When beams are reordered, past key-value tensors are reindexed along the batch dimension. Added `index_select` (sync, CPU) and `index_select_async` (GPU download fallback) to `tensor.js`. `_reorder_cache()` passes encoder PKVs through unchanged and disposes stale decoder PKV tensors.

What's supported
- Beam search (`num_beams > 1`)
- Diverse beam search (`num_beam_groups > 1`, `diversity_penalty`)
- Beam sampling (`do_sample = true`, `num_beams > 1`)
- `length_penalty` (exponential length normalization)
- `early_stopping` (true/false/"never")
- `num_return_sequences` (up to `num_beams`)
- Rejected with a validation error: `num_beam_groups > 1` + `do_sample`

Changes
New: `src/generation/beam_search.js`

- `BeamHypotheses` — per-batch-element bounded priority queue of completed hypotheses, scored as `sum_logprobs / (length ^ length_penalty)`. Implements the three `early_stopping` modes for `is_done()`.
- `BeamSearchScorer` — orchestrates beam state across the batch. `process()` routes EOS-terminated beams to hypotheses and selects continuing beams. `finalize()` / `finalize_with_scores()` returns the top hypotheses per batch item; the `_with_scores` variant exposes normalized scores for cross-group ranking in diverse beam search.

New: `index_select` / `index_select_async` in `src/utils/tensor.js`

- Row-level gather along dimension 0. The sync variant operates on CPU typed arrays; the async variant calls `ort_tensor.getData(true)` for GPU/ML-tensor locations.

Modified: `src/models/modeling_utils.js`

- Validation: `num_beam_groups` divisibility, `num_return_sequences <= num_beams`, incompatible combos (CFG, streaming, diverse + sampling).
- Beam state setup: one `BeamSearchScorer` per group for diverse mode, beam score initialization (one live beam per group).
- Finalization: diverse mode calls `finalize_with_scores()` and re-ranks globally; standard/sampling modes use `finalize()`.
- `_reorder_cache()`: async method that reindexes decoder PKVs, passes encoder PKVs through, disposes stale GPU tensors.

Modified: `src/transformers.js`

- Exports `BeamHypotheses` and `BeamSearchScorer`.

Test plan
- Diverse beam search (`num_beam_groups=2`, `diversity_penalty=0.5`) — both architectures
- Beam sampling (`do_sample=true`, `top_k=10`) — both architectures
- `num_return_sequences > 1` output shape `[N, seq_len]` — both architectures
- `num_return_sequences > num_beams` throws
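For reference, the row-level gather performed by the new `index_select` helper (see Changes above) can be sketched on a flat typed array. This simplified version skips the library's `Tensor` wrapper and the async GPU path; the function name is hypothetical:

```javascript
// Sketch of a row-level gather along dimension 0, analogous to the
// index_select helper added to tensor.js. Used here the way KV cache
// reordering would use it: rows = batch-of-beams entries.
function indexSelectDim0(data, dims, indices) {
  // Number of elements per "row" (everything after dimension 0).
  const rowSize = dims.slice(1).reduce((a, b) => a * b, 1);
  const out = new Float32Array(indices.length * rowSize);
  indices.forEach((srcRow, dstRow) => {
    // Copy row srcRow of the input into row dstRow of the output.
    out.set(data.subarray(srcRow * rowSize, (srcRow + 1) * rowSize), dstRow * rowSize);
  });
  return { data: out, dims: [indices.length, ...dims.slice(1)] };
}
```

Note that the same source row may appear multiple times in `indices`, which is exactly what happens when several surviving beams continue from one parent beam.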