Feat/sampler memory efficient logits #525
Open
# Sampler: Implement Memory-Efficient Top-K Logit Storage in `SamplingState`

## Summary
Implements efficient storage of top-k logits and their corresponding indices in the `SamplingState`, to enable detailed analysis of model predictions without the substantial memory overhead of storing full vocabulary logits.

## Problem Statement
The previous implementation of `SamplingState` contained a `predicted_logits` field, typed as `Float['B max_out_length V']`, which was intended to store the full probability distribution at each sampling step. Given typical Large Language Model (LLM) vocabulary sizes (e.g., 256k tokens), storing full logits for every step incurs prohibitive memory costs: Batch Size (B) * Sequence Length (L) * Vocabulary Size (V) * 4 bytes (Float32), i.e. `1 * 2048 * 256,000 * 4 ≈ 2.1 GB` of memory per sample.

This excessive memory consumption led to the functionality being commented out (`TODO(epot): Better way to extract logits. This takes a lot of memory.`), leaving users with no access to the model's confidence scores or alternative token candidates during generation.

## Solution
This PR introduces a memory-efficient alternative by storing only the top-k logits and their indices. This drastically reduces memory usage while preserving the most critical information required for analysis (e.g., probability mass coverage, alternative high-probability candidates).
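As an illustration of the kind of analysis this enables, here is a minimal sketch of inspecting the stored candidates for one step. The concrete arrays are made-up stand-ins; real values would come from `state.predicted_top_k_values` / `state.predicted_top_k_indices`:

```python
import jax
import jax.numpy as jnp

# Stand-in for one step's stored top-k results (illustrative values only).
values = jnp.array([8.1, 7.9, 5.2, 4.8, 4.1, 3.7, 3.3, 2.9, 2.6, 2.2])
indices = jnp.array([42, 1337, 7, 256, 99, 11, 500, 3, 88, 1024])

# Softmax over the stored top-k yields relative odds among the 10 candidates
# (not absolute probabilities: the discarded vocabulary tail is unknown).
rel_probs = jax.nn.softmax(values)
for tok_id, p in zip(indices.tolist(), rel_probs.tolist()):
    print(f"token {tok_id}: {p:.3f}")
```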
## Key Changes
**Schema Update:**
- Removed `predicted_logits` from `SamplingState`.
- Added `predicted_top_k_values: Float['B max_out_length 10']`.
- Added `predicted_top_k_indices: Int['B max_out_length 10']`.
- `k=10` was chosen as a reasonable default balancing inspection depth and memory efficiency.

**Logic Update:**
- In `SamplerLoop._sample_step`, we invoke `jax.lax.top_k(logits, 10)` immediately after logit computation.
- The resulting values and indices are written into the `SamplingState` buffers at the current step `state.step` (see the sketch below).
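The following is a minimal sketch of the new schema and update step, assuming a `flax.struct` dataclass for the state. The real `SamplingState` carries additional fields, and `record_top_k` is an illustrative helper rather than the actual `_sample_step` signature:

```python
import jax
import jax.numpy as jnp
import flax.struct

TOP_K = 10  # default chosen in this PR


@flax.struct.dataclass
class SamplingState:
  step: jnp.ndarray                     # current decoding step, shape []
  predicted_top_k_values: jnp.ndarray   # Float['B max_out_length 10']
  predicted_top_k_indices: jnp.ndarray  # Int['B max_out_length 10']


def record_top_k(state: SamplingState, logits: jnp.ndarray) -> SamplingState:
  """Stores only the top-k logits/indices for the current step."""
  # logits: Float['B V'], computed by the forward pass at this step.
  values, indices = jax.lax.top_k(logits, TOP_K)  # Float['B 10'], Int['B 10']
  return state.replace(
      predicted_top_k_values=(
          state.predicted_top_k_values.at[:, state.step].set(values)),
      predicted_top_k_indices=(
          state.predicted_top_k_indices.at[:, state.step].set(indices)),
  )
```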
## Impact Analysis

- Memory complexity is reduced from `O(B * L * V)` to `O(B * L * k)`.
- With `k=10`, this represents a reduction factor of `V / 10 ≈ 25,600x` for a 256k vocabulary.
- The computational overhead of the `top_k` operation on the accelerator (TPU/GPU) is negligible compared to the matrix-multiplication cost of the forward pass.
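A quick back-of-the-envelope check of these numbers (plain Python, using the same figures as above):

```python
B, L, V, K = 1, 2048, 256_000, 10
bytes_f32 = 4  # float32

full_logits = B * L * V * bytes_f32  # 2_097_152_000 bytes ≈ 2.1 GB
top_k_vals = B * L * K * bytes_f32   # 81_920 bytes ≈ 82 KB

print(f"full logits: {full_logits / 1e9:.2f} GB")
print(f"top-k store: {top_k_vals / 1e3:.0f} KB "
      f"({full_logits // top_k_vals:,}x smaller)")
# The int32 indices buffer adds another ~82 KB, still negligible next to 2.1 GB.
```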
## Verification

- Verified that the new `SamplingState` structure and `_sample_step` logic are type-safe and consistent with Flax/JAX patterns.
- Added a standalone verification script, `verify_sampler_logits.py` (mocking JAX/Flax dependencies), to confirm that:
  - the `SamplingState` class correctly exposes the new `predicted_top_k_values` and `predicted_top_k_indices` fields;
  - `_sample_step` correctly computes and assigns the top-k values.
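For reference, a self-contained check in the same spirit (a sketch, not the contents of `verify_sampler_logits.py`):

```python
import jax
import jax.numpy as jnp

# A small vocabulary keeps the check fast; the logic is shape-independent.
logits = jax.random.normal(jax.random.PRNGKey(0), (2, 1_000))  # Float['B V']
values, indices = jax.lax.top_k(logits, 10)

assert values.shape == (2, 10) and indices.shape == (2, 10)
# top_k returns values in descending order...
assert jnp.all(values[:, :-1] >= values[:, 1:])
# ...and the indices point back at exactly those logits.
assert jnp.allclose(values, jnp.take_along_axis(logits, indices, axis=-1))
print("top-k extraction behaves as expected")
```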