Description
Problem Statement
Currently, src/neuronx_distributed/utils/tensor_utils.py only provides a basic cumsum function which:
- Only supports dim=0
- Uses a sequential chunked matrix multiplication approach
- Has O(n) depth complexity, limiting parallelization potential
As sequence lengths grow (e.g., 128K+ context windows in modern LLMs), and with the increasing adoption of Sequence Parallelism (SP) and Context Parallelism (CP), there's a need for efficient parallel scan primitives that can leverage the parallel nature of accelerators.
Proposed Solution
Extend tensor_utils.py with a family of work-efficient parallel scan operations based on the Blelloch algorithm (intended semantics are sketched after the list below):
1. parallel_inclusive_scan(tensor, dim, op='sum')
- Blelloch-style parallel prefix sum
- Support for multiple dimensions (not just dim=0)
- Work-efficient: O(n) work with O(log n) depth
2. parallel_exclusive_scan(tensor, dim, op='sum')
- Prefix sum excluding the current element
- Critical for operations like token routing in MoE
- Used in expert assignment computations
3. segmented_scan(tensor, segment_ids, dim, op='sum') (optional extension)
- Scan with segment boundaries
- Enables variable-length sequence batching
- Supports packed sequence formats
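To make the intended semantics of these three functions concrete, here is a minimal plain-PyTorch reference. These are the definitions the parallel versions would have to match, not the proposed implementation, and the segment_ids values are purely illustrative:
```python
import torch

x = torch.tensor([2., 3., 5., 7.])

# parallel_inclusive_scan(x, dim=0, op='sum') should match:
inclusive = torch.cumsum(x, dim=0)                        # [2., 5., 10., 17.]

# parallel_exclusive_scan(x, dim=0, op='sum') should match
# (identity element at position 0, current element excluded):
exclusive = torch.cat([x.new_zeros(1), inclusive[:-1]])   # [0., 2., 5., 10.]

# segmented_scan restarts at segment boundaries; with segment_ids = [0, 0, 1, 1]
# the inclusive-sum result would be [2., 5., 5., 12.].
```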
Technical Background
Blelloch Algorithm Complexity
| Metric | Sequential (Current) | Blelloch (Proposed) |
|---|---|---|
| Work | O(n) | O(n) |
| Depth | O(n) | O(log n) |
| Parallelism | O(1) | O(n / log n) |
The algorithm consists of two phases:
- Up-Sweep (Reduce): Build a reduction tree in O(log n) parallel steps
- Down-Sweep (Distribute): Propagate prefix sums in O(log n) parallel steps
Algorithm Visualization (n=8)
```
UP-SWEEP:
[a₀]   [a₁]   [a₂]   [a₃]   [a₄]   [a₅]   [a₆]   [a₇]
   ↘   ↙         ↘   ↙         ↘   ↙         ↘   ↙      (Level 0)
   [a₀₁]         [a₂₃]         [a₄₅]         [a₆₇]
        ↘         ↙                 ↘         ↙          (Level 1)
        [a₀₋₃]                      [a₄₋₇]
               ↘                   ↙                     (Level 2)
                     [a₀₋₇]  ← total sum

DOWN-SWEEP:
Set the last element to 0, then propagate prefix sums back down the tree.
→ Result: exclusive scan [0, a₀, a₀₁, a₀₋₂, ...]
```
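A concrete trace of both phases on a small input (the intermediate arrays follow the index pattern used in the prototype further down; the final assert is only a sanity check against torch.cumsum):
```python
import torch

x = torch.tensor([3, 1, 7, 0, 4, 1, 6, 3])

# Up-sweep: each level writes a partial sum into the right child of every pair.
#   level 0: [3, 4, 7,  7, 4,  5, 6,  9]
#   level 1: [3, 4, 7, 11, 4,  5, 6, 14]
#   level 2: [3, 4, 7, 11, 4,  5, 6, 25]   <- last element now holds the total sum
#
# Down-sweep: set the last element to 0, then swap-and-add back down the tree.
#   level 2: [3, 4, 7,  0, 4,  5, 6, 11]
#   level 1: [3, 0, 7,  4, 4, 11, 6, 16]
#   level 0: [0, 3, 4, 11, 11, 15, 16, 22]  <- exclusive scan of x

expected = torch.cat([torch.zeros(1, dtype=x.dtype), torch.cumsum(x, dim=0)[:-1]])
assert expected.tolist() == [0, 3, 4, 11, 11, 15, 16, 22]
```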
Use Cases in NxD
| Use Case | Current Approach | With Parallel Scan |
|---|---|---|
| MoE Token Routing | Sequential cumsum | O(log n) parallel depth |
| Cumulative Attention Masks | Matrix mult O(n²) | O(n) work, O(log n) depth |
| Context Parallelism Batching | Sequential slicing | Parallelizable |
| Distributed Softmax | Multiple collectives | Optimized partial sums |
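As an illustration of the MoE token routing row above, an exclusive scan over one-hot expert assignments yields each token's slot within its expert's buffer. This is a hedged sketch with made-up names and shapes, not the actual routing.py code; torch.cumsum stands in for the proposed parallel_exclusive_scan:
```python
import torch

def expert_slot_from_assignment(expert_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    """For each token, return its position within the expert it was routed to.

    expert_ids: [num_tokens] integer expert assignment per token (illustrative shape).
    """
    # One-hot expert mask: [num_tokens, num_experts]
    one_hot = torch.nn.functional.one_hot(expert_ids, num_experts)
    # Exclusive scan over the token dimension: how many earlier tokens chose each expert.
    # torch.cumsum is sequential today; this is the call a parallel exclusive scan would replace.
    exclusive = torch.cumsum(one_hot, dim=0) - one_hot
    # Select the column corresponding to each token's own expert.
    return exclusive.gather(1, expert_ids.unsqueeze(1)).squeeze(1)

# Tokens routed to experts [0, 1, 0, 0, 1] get slots [0, 0, 1, 2, 1]
slots = expert_slot_from_assignment(torch.tensor([0, 1, 0, 0, 1]), num_experts=2)
assert slots.tolist() == [0, 0, 1, 2, 1]
```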
Relevant Code Paths
- modules/moe/routing.py - Token assignment computations
- utils/batch_utils.py - Sequence slicing for CP
- parallel_layers/loss_functions.py - Distributed softmax normalization
Implementation Plan
Phase 1: Core Local Implementation
- Implement parallel_inclusive_scan and parallel_exclusive_scan
- Pure PyTorch implementation (no custom kernels initially)
- Support multiple reduction operations (sum, prod, max, min); see the sketch after this list
- Handle non-power-of-2 tensor sizes
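For the multiple-reduction-operations item, the Blelloch loops only need an associative combine function plus an identity element (used for padding and for the root of the down-sweep). A possible way Phase 1 could parameterize this; the names and identity choices are illustrative, and integer dtypes would need their own identities for max/min:
```python
import torch

# (combine, identity) per supported op; the scan loops stay unchanged except that
#   x[right_idx] = x[right_idx] + x[left_idx]
# becomes
#   x[right_idx] = combine(x[right_idx], x[left_idx])
# and padding / the down-sweep root use `identity` instead of 0.
SCAN_OPS = {
    "sum":  (torch.add,     0.0),
    "prod": (torch.mul,     1.0),
    "max":  (torch.maximum, float("-inf")),
    "min":  (torch.minimum, float("inf")),
}

combine, identity = SCAN_OPS["max"]
assert combine(torch.tensor([1., 5.]), torch.tensor([4., 2.])).tolist() == [4., 5.]
```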
Phase 2: Extended Functionality
- Add support for arbitrary dimensions
- Implement segmented_scan for variable-length sequences (reference sketch after this list)
- Add numerical precision options (fp32 accumulation for bf16 inputs)
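For the segmented_scan item, a slow but obviously correct reference is useful both to pin down the semantics and for Phase 2 validation. A sketch (the function name is made up; a production version would use the parallel scan rather than a Python loop):
```python
import torch

def segmented_inclusive_scan_reference(x: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
    """Reference segmented inclusive sum: the scan restarts at every segment boundary.

    x:           [n, ...] values, scanned along dim 0.
    segment_ids: [n] integer segment label per row (e.g. packed-sequence ids).
    """
    out = torch.empty_like(x)
    for seg in torch.unique(segment_ids):
        mask = segment_ids == seg
        out[mask] = torch.cumsum(x[mask], dim=0)
    return out

# Two packed sequences in one batch dimension scan independently:
vals = torch.tensor([1., 2., 3., 4., 5.])
segs = torch.tensor([0, 0, 0, 1, 1])
assert segmented_inclusive_scan_reference(vals, segs).tolist() == [1., 3., 6., 4., 9.]
```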
Phase 3: Distributed Extension (future)
- Extend to work across SP/TP groups using existing collective primitives
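A common way to build the distributed version on top of the local scan: each rank scans its own shard, then per-rank totals are exchanged and exclusively scanned to produce a per-rank offset. The sketch below uses plain torch.distributed calls as stand-ins for NxD's collective wrappers; the function name and the use of torch.cumsum for the local step are illustrative only:
```python
import torch
import torch.distributed as dist

def distributed_exclusive_scan_sketch(local_x: torch.Tensor, group=None) -> torch.Tensor:
    """Exclusive sum scan along dim 0 of a sequence that is sharded across `group`."""
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)

    # 1) Local scan on this rank's shard (a Blelloch scan would slot in here).
    local_inclusive = torch.cumsum(local_x, dim=0)
    local_exclusive = local_inclusive - local_x

    # 2) Exchange each rank's shard total.
    shard_total = local_inclusive[-1].contiguous()
    gathered = [torch.empty_like(shard_total) for _ in range(world_size)]
    dist.all_gather(gathered, shard_total, group=group)

    # 3) Exclusive scan over the per-rank totals gives this rank's starting offset.
    totals = torch.stack(gathered)                       # [world_size, ...]
    rank_offsets = torch.cumsum(totals, dim=0) - totals
    return local_exclusive + rank_offsets[rank]
```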
Testing Strategy
- Unit tests with CPU mode (NXD_CPU_MODE=1)
- Correctness validation against torch.cumsum
- Parameterized tests for various tensor shapes, dtypes, and dimensions
- Numerical stability tests for different precision modes
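A minimal sketch of the parameterized correctness check, assuming the blelloch_exclusive_scan prototype from the next section has been placed in tensor_utils.py (the import path and tolerances are assumptions):
```python
import pytest
import torch

# Assumed location; the prototype below does not exist in tensor_utils.py yet.
from neuronx_distributed.utils.tensor_utils import blelloch_exclusive_scan


@pytest.mark.parametrize("n", [1, 7, 8, 129, 1024])   # includes non-power-of-2 lengths
@pytest.mark.parametrize("dim", [0, 1])
def test_exclusive_scan_matches_cumsum(n, dim):
    torch.manual_seed(0)
    x = torch.randn(n, 16)
    # Exclusive scan == inclusive scan minus the current element.
    expected = torch.cumsum(x, dim=dim) - x
    actual = blelloch_exclusive_scan(x, dim=dim)
    # Loose tolerance: tree-based and sequential summation round differently.
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)
```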
Prototype Implementation
I have prepared a prototype implementation. Here's a sketch:
```python
import math

import torch


def blelloch_exclusive_scan(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Blelloch work-efficient parallel exclusive scan.

    Args:
        x: Input tensor
        dim: Dimension to scan over

    Returns:
        Exclusive prefix sum along the specified dimension

    Complexity:
        Work: O(n), Depth: O(log n)
    """
    x = x.movedim(dim, 0).clone()
    n = x.shape[0]
    log_n = int(math.ceil(math.log2(n))) if n > 1 else 1
    n_padded = 1 << log_n

    # Pad to a power of 2 if needed (zeros are the identity element for sum)
    if n_padded > n:
        x = torch.cat([x, torch.zeros(n_padded - n, *x.shape[1:],
                                      dtype=x.dtype, device=x.device)])

    # Up-sweep (reduce phase): each level adds the left child into the right child
    for d in range(log_n):
        stride = 1 << (d + 1)
        right_idx = torch.arange(stride - 1, n_padded, stride, device=x.device)
        left_idx = right_idx - (1 << d)
        x[right_idx] = x[right_idx] + x[left_idx]

    # Down-sweep (distribute phase): clear the root, then swap-and-add back down the tree
    x[-1] = 0
    for d in range(log_n - 1, -1, -1):
        stride = 1 << (d + 1)
        right_idx = torch.arange(stride - 1, n_padded, stride, device=x.device)
        left_idx = right_idx - (1 << d)
        temp = x[left_idx].clone()
        x[left_idx] = x[right_idx]
        x[right_idx] = x[right_idx] + temp

    return x[:n].movedim(0, dim)
```
Additional Context
I'm interested in contributing this feature and have prepared:
- Detailed mathematical derivation of the algorithm
- Correctness proofs for both phases
- Complexity analysis
- Prototype implementation
I have experience with distributed PyTorch (Data Parallelism with CUDA) and mathematical algorithm optimization. Happy to discuss the design approach and address any concerns before proceeding with the full implementation.
References
- Blelloch, G. E. (1990). "Prefix Sums and Their Applications". CMU Technical Report.
- Harris, M. (2007). "Parallel Prefix Sum (Scan) with CUDA". GPU Gems 3, Chapter 39.
- Sengupta, S. et al. (2008). "Efficient Parallel Scan Algorithms for GPUs". NVIDIA Technical Report.