Skip to content

Conversation

@larryliu0820
Copy link
Contributor

Add CudaSampler class that provides a high-level interface for GPU sampling:

  • cuda_sampler.h: Class declaration with sample_argmax() method.
    Pre-allocates GPU memory to avoid allocation in hot path.
  • cuda_sampler.cu: Implementation using the default CUDA stream (nullptr)
    for implicit synchronization with the CUDA backend's stream.

The default stream approach ensures proper ordering between decoder
output and argmax without requiring explicit cross-stream synchronization
or access to the backend's internal stream.

[ghstack-poisoned]
@larryliu0820
Copy link
Contributor Author

larryliu0820 commented Dec 24, 2025

@pytorch-bot
Copy link

pytorch-bot bot commented Dec 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16387

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 3 Unrelated Failures

As of commit 8573c9a with merge base c5d66a5 (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 24, 2025
// - The argmax kernel will wait for the decoder to finish writing logits
// - No explicit cudaDeviceSynchronize() or cross-stream synchronization needed
//
// 4. Trade-off: Using the default stream prevents concurrent execution between
Copy link
Contributor

@Gasoonjia Gasoonjia Dec 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im wondering if we really need to make sampler and cuda backend using same cuda stream, since the sampling and decoding should be able to work in parallel: the argmax process of logits_{i} should be able to work with the decoder generating logits_{i+1} since they do not have any dependency, and such parallelism may not happen if argmax and decoder share the same cudastream.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants