
@shreed27 commented on Feb 1, 2026

Sampler: Implement Memory-Efficient Top-K Logit Storage in SamplingState

Summary

Implements efficient storage of top-k logits and their corresponding indices in the SamplingState to enable detailed analysis of model predictions without the substantial memory overhead of storing full vocabulary logits.

Problem Statement

The previous implementation of SamplingState contained a predicted_logits field, typed as Float['B max_out_length V'], which was intended to store the full vocabulary logits at each sampling step.

Given typical Large Language Model (LLM) vocabulary sizes (e.g., 256k tokens), storing full logits for every step incurs prohibitive memory costs:

  • Memory Usage: Batch Size (B) * Sequence Length (L) * Vocabulary Size (V) * 4 bytes (Float32)
  • For a single sequence (B=1, L=2048) and V=256k: 1 * 2048 * 256000 * 4 ≈ 2.1 GB of memory per sample (a quick check of this figure follows below).
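A quick back-of-the-envelope check of the figure above (plain Python; B, L, and V are taken from the example, not from the code):

```python
# Memory needed to store full-vocabulary logits for every sampling step.
B, L, V = 1, 2048, 256_000        # batch size, max_out_length, vocabulary size
bytes_per_float32 = 4

full_logits_bytes = B * L * V * bytes_per_float32
print(f"{full_logits_bytes / 1e9:.1f} GB")  # -> 2.1 GB
```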

This excessive memory consumption led to the functionality being commented out (TODO(epot): Better way to extract logits. This takes a lot of memory.), leaving users with no access to the model's confidence scores or alternative token candidates during generation.

Solution

This PR introduces a memory-efficient alternative by storing only the top-k logits and their indices. This drastically reduces memory usage while preserving the most critical information required for analysis (e.g., probability mass coverage, alternative high-probability candidates).

Key Changes

  1. Schema Update:

    • Removed commented-out predicted_logits from SamplingState.
    • Added predicted_top_k_values: Float['B max_out_length 10'].
    • Added predicted_top_k_indices: Int['B max_out_length 10'].
    • A fixed value of k=10 was chosen as a reasonable default, balancing inspection depth against memory efficiency.
  2. Logic Update:

    • In SamplerLoop._sample_step, we invoke jax.lax.top_k(logits, 10) immediately after logit computation.
    • The resulting values and indices are scattered into the SamplingState buffers at the current step (state.step); a minimal sketch of both changes follows this list.
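A minimal sketch of the schema and logic changes, assuming a flax.struct-style SamplingState and per-step logits of shape [B, V]. The helper name _store_top_k and the use of dynamic_update_slice_in_dim are illustrative, not the exact implementation:

```python
import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class SamplingState:
  # ... existing fields (step counter, predicted tokens, KV cache, ...) ...
  step: jnp.ndarray                      # Int[''] -- current decoding step
  # New fields: top-k logit values and their vocabulary indices per step.
  predicted_top_k_values: jnp.ndarray    # Float['B max_out_length 10']
  predicted_top_k_indices: jnp.ndarray   # Int['B max_out_length 10']


def _store_top_k(state: SamplingState, logits: jnp.ndarray) -> SamplingState:
  """Scatters the top-10 logits ([B, V] input) into the buffers at state.step."""
  values, indices = jax.lax.top_k(logits, 10)  # both [B, 10], sorted descending
  return state.replace(
      predicted_top_k_values=jax.lax.dynamic_update_slice_in_dim(
          state.predicted_top_k_values, values[:, None, :], state.step, axis=1),
      predicted_top_k_indices=jax.lax.dynamic_update_slice_in_dim(
          state.predicted_top_k_indices, indices[:, None, :], state.step, axis=1),
  )
```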

Impact Analysis

  • Memory Complexity: Reduced from O(B * L * V) to O(B * L * k).
    • For k=10 and V=256k, this represents a reduction factor of V / k = 25,600x.
  • Computational Overhead: The added top_k operation on the accelerator (TPU/GPU) is negligible compared to the matrix-multiplication cost of the forward pass.
  • Usability: Users can now inspect the model's top predictions and confidence levels for every generated token, enabling better debugging and interpretability of the generation process (see the usage sketch below).
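As a usage illustration, assuming state is the final SamplingState after generation (the batch index, step index, and softmax normalization are illustrative choices; note the softmax here is over the 10 kept logits only, not the full vocabulary):

```python
import jax

t = 5  # inspect the 6th generated token of batch element 0
values = state.predicted_top_k_values[0, t]    # [10] raw logit values
indices = state.predicted_top_k_indices[0, t]  # [10] vocabulary ids

# Relative confidence among the 10 stored candidates.
probs = jax.nn.softmax(values)
for token_id, p in zip(indices.tolist(), probs.tolist()):
  print(token_id, f"{p:.3f}")
```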

Verification

  • Static Analysis: Verified that SamplingState structure and _sample_step logic are type-safe and consistent with Flax/JAX patterns.
  • Mock Verification: Created a standalone script verify_sampler_logits.py (mocking JAX/Flax dependencies) to confirm:
    • The SamplingState class correctly exposes the new predicted_top_k_values and predicted_top_k_indices fields.
    • The storage logic in _sample_step correctly computes and assigns top-k values (an illustrative check is sketched below).
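This is not the contents of verify_sampler_logits.py, but an illustrative, self-contained check of the same property using real JAX:

```python
import jax
import jax.numpy as jnp

B, V, K = 2, 32, 10
logits = jax.random.normal(jax.random.PRNGKey(0), (B, V))

values, indices = jax.lax.top_k(logits, K)
# Stored values must equal a gather of the stored indices from the raw logits...
assert jnp.allclose(jnp.take_along_axis(logits, indices, axis=-1), values)
# ...and must be sorted in descending order.
assert jnp.all(values[:, :-1] >= values[:, 1:])
print("top-k storage check passed")
```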

- Implemented  for  model loading to prevent redundant IO and parsing when creating multiple instances.
- Added auto-download capability to : now automatically downloads/copies remote files (e.g., gs://) to the local cache if missing.
- Refactored  to separate model loading logic into standalone cached functions.
- Updated  to verify download behavior.