[Feature Request] Extend tensor_utils.py with Parallel Scan Operations for Sequence Parallelism #61

Problem Statement

Currently, src/neuronx_distributed/utils/tensor_utils.py provides only a basic cumsum function, which:

  • Only supports dim=0
  • Uses a sequential chunked matrix multiplication approach
  • Has O(n) depth complexity, limiting parallelization potential

As sequence lengths grow (e.g., 128K+ context windows in modern LLMs), and with the increasing adoption of Sequence Parallelism (SP) and Context Parallelism (CP), there's a need for efficient parallel scan primitives that can leverage the parallel nature of accelerators.

Proposed Solution

Extend tensor_utils.py with a family of work-efficient parallel scan operations based on the Blelloch algorithm:

1. parallel_inclusive_scan(tensor, dim, op='sum')

  • Blelloch-style parallel prefix sum
  • Support for multiple dimensions (not just dim=0)
  • Work-efficient: O(n) work with O(log n) depth

2. parallel_exclusive_scan(tensor, dim, op='sum')

  • Prefix sum excluding the current element
  • Critical for operations like token routing in MoE
  • Used in expert assignment computations

3. segmented_scan(tensor, segment_ids, dim, op='sum') (optional extension)

  • Scan with segment boundaries
  • Enables variable-length sequence batching
  • Supports packed sequence formats
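
To make the intended semantics concrete, here is a small illustration with torch.cumsum as the sequential reference (the proposed function names are the ones listed above; the segmented example assumes an inclusive scan that restarts at each segment boundary):

import torch

x = torch.tensor([3., 1., 7., 0.])
# parallel_inclusive_scan(x, dim=0, op="sum")  ->  [3, 4, 11, 11]   (== torch.cumsum(x, 0))
# parallel_exclusive_scan(x, dim=0, op="sum")  ->  [0, 3, 4, 11]    (inclusive shifted right, identity first)
# segmented_scan(x, segment_ids=torch.tensor([0, 0, 1, 1]), dim=0)
#                                              ->  [3, 4, 7, 7]     (scan restarts at the new segment)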

Technical Background

Blelloch Algorithm Complexity

| Metric      | Sequential (Current) | Blelloch (Proposed) |
|-------------|----------------------|---------------------|
| Work        | O(n)                 | O(n)                |
| Depth       | O(n)                 | O(log n)            |
| Parallelism | O(1)                 | O(n / log n)        |

The algorithm consists of two phases:

  1. Up-Sweep (Reduce): Build a reduction tree in O(log n) parallel steps
  2. Down-Sweep (Distribute): Propagate prefix sums in O(log n) parallel steps

Algorithm Visualization (n=8)

UP-SWEEP:
[a₀] [a₁] [a₂] [a₃] [a₄] [a₅] [a₆] [a₇]
  ↘   ↙     ↘   ↙     ↘   ↙     ↘   ↙    (Level 0)
 [a₀₁]     [a₂₃]     [a₄₅]     [a₆₇]
     ↘       ↙           ↘       ↙       (Level 1)
      [a₀₋₃]               [a₄₋₇]
           ↘               ↙              (Level 2)
                 [a₀₋₇]  ← Total sum

DOWN-SWEEP:
Set last to 0, then propagate prefix sums back down
→ Result: Exclusive scan [0, a₀, a₀₁, a₀₋₂, ...]
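
For concreteness, the same flow on a small numeric input (the exclusive output plus the input gives the inclusive scan):

Input:          [3, 1, 7, 0, 4, 1, 6, 3]
Up-sweep total: 25
Exclusive scan: [0, 3, 4, 11, 11, 15, 16, 22]
Inclusive scan: [3, 4, 11, 11, 15, 16, 22, 25]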

Use Cases in NxD

| Use Case                     | Current Approach       | With Parallel Scan        |
|------------------------------|------------------------|---------------------------|
| MoE Token Routing            | Sequential cumsum      | O(log n) parallel depth   |
| Cumulative Attention Masks   | Matrix mult, O(n²)     | O(n) work, O(log n) depth |
| Context Parallelism Batching | Sequential slicing     | Parallelizable            |
| Distributed Softmax          | Multiple collectives   | Optimized partial sums    |
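
As a concrete illustration of the MoE routing case, each token's offset within its expert can be derived from an exclusive scan over the one-hot expert-assignment mask. A minimal sketch (the function and variable names here are hypothetical, and torch.cumsum stands in for the proposed parallel_exclusive_scan):

import torch

def token_positions_in_expert(expert_assignment: torch.Tensor, num_experts: int) -> torch.Tensor:
    """For each token, its 0-based position among tokens routed to the same expert.

    expert_assignment: [num_tokens] long tensor of expert indices.
    """
    one_hot = torch.nn.functional.one_hot(expert_assignment, num_experts)  # [T, E]
    # Exclusive scan over the token dimension = count of earlier tokens assigned
    # to each expert. torch.cumsum is the sequential stand-in for the proposed
    # parallel_exclusive_scan(one_hot, dim=0).
    exclusive = torch.cumsum(one_hot, dim=0) - one_hot
    return exclusive.gather(1, expert_assignment.unsqueeze(1)).squeeze(1)

# Example: tokens routed to experts [1, 0, 1, 1, 0] -> positions [0, 0, 1, 2, 1]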

Relevant Code Paths

  • modules/moe/routing.py - Token assignment computations
  • utils/batch_utils.py - Sequence slicing for CP
  • parallel_layers/loss_functions.py - Distributed softmax normalization

Implementation Plan

Phase 1: Core Local Implementation

  • Implement parallel_inclusive_scan and parallel_exclusive_scan
  • Pure PyTorch implementation (no custom kernels initially)
  • Support multiple reduction operations (sum, prod, max, min)
  • Handle non-power-of-2 tensor sizes
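
For the multi-op and non-power-of-2 cases, one possible design (a sketch of an assumed interface, not a committed one) is to pad the scan dimension with the identity element of the selected operation, so the padding can never affect the result, and to dispatch the combine step through a small op table:

import torch

# Identity element per reduction op, used to pad the scan dimension up to a
# power of two. (Integer dtypes would need integer sentinels; omitted here.)
_SCAN_IDENTITY = {"sum": 0.0, "prod": 1.0, "max": float("-inf"), "min": float("inf")}

# Binary combine functions; the up-sweep/down-sweep would use _SCAN_OP[op]
# in place of the hard-coded addition in the prototype below.
_SCAN_OP = {"sum": torch.add, "prod": torch.mul, "max": torch.maximum, "min": torch.minimum}

def _pad_to_pow2(x: torch.Tensor, op: str) -> torch.Tensor:
    n = x.shape[0]
    n_padded = 1 << (n - 1).bit_length() if n > 1 else 1
    if n_padded == n:
        return x
    pad = torch.full((n_padded - n, *x.shape[1:]), _SCAN_IDENTITY[op],
                     dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=0)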

Phase 2: Extended Functionality

  • Add support for arbitrary dimensions
  • Implement segmented_scan for variable-length sequences
  • Add numerical precision options (fp32 accumulation for bf16 inputs)
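
For the precision option, a minimal wrapper could upcast to fp32 for accumulation and cast back afterwards (illustrative only; the actual flag name and where it lives in the API are open questions):

import torch

def scan_with_fp32_accum(x: torch.Tensor, scan_fn, dim: int = 0) -> torch.Tensor:
    """Run `scan_fn` (e.g. the proposed parallel scan) in fp32 and cast back.

    Long bf16/fp16 scans lose precision because every partial sum is rounded;
    accumulating in fp32 and casting once at the end avoids most of that error.
    """
    if x.dtype in (torch.bfloat16, torch.float16):
        return scan_fn(x.float(), dim).to(x.dtype)
    return scan_fn(x, dim)

# e.g. scan_with_fp32_accum(x, blelloch_exclusive_scan, dim=0)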

Phase 3: Distributed Extension (future)

  • Extend to work across SP/TP groups using existing collective primitives
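
A common pattern for the distributed case is a three-step block scan: local scan on each rank, exchange of per-rank totals, then adding each rank's offset. A hedged sketch using plain torch.distributed (the real implementation would go through NxD's own process groups and collective wrappers, and the function name is hypothetical):

import torch
import torch.distributed as dist

def distributed_exclusive_scan(x_local: torch.Tensor, group=None) -> torch.Tensor:
    """Exclusive scan over a sequence sharded along dim 0 across ranks (sketch)."""
    # 1. Local exclusive scan on this rank's shard (cumsum stands in for the parallel scan).
    local_excl = torch.cumsum(x_local, dim=0) - x_local
    local_total = x_local.sum(dim=0, keepdim=True)

    # 2. Gather every rank's total and scan them to get per-rank offsets.
    world_size = dist.get_world_size(group)
    totals = [torch.zeros_like(local_total) for _ in range(world_size)]
    dist.all_gather(totals, local_total, group=group)
    totals = torch.cat(totals, dim=0)                # [world_size, ...]
    offsets = torch.cumsum(totals, dim=0) - totals   # exclusive scan of block sums

    # 3. Shift the local result by the sum of everything on preceding ranks.
    return local_excl + offsets[dist.get_rank(group)]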

Testing Strategy

  • Unit tests with CPU mode (NXD_CPU_MODE=1)
  • Correctness validation against torch.cumsum
  • Parameterized tests for various tensor shapes, dtypes, and dimensions
  • Numerical stability tests for different precision modes
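
A sketch of the correctness check against torch.cumsum, using the blelloch_exclusive_scan prototype from the next section as the implementation under test (test and parameter names are illustrative):

import pytest
import torch

@pytest.mark.parametrize("n", [1, 7, 8, 31, 1024])
@pytest.mark.parametrize("dim", [0, 1])
def test_exclusive_scan_matches_cumsum(n, dim):
    x = torch.randn(n, 5)
    # Exclusive reference: inclusive cumsum minus the element itself.
    expected = torch.cumsum(x, dim=dim) - x
    actual = blelloch_exclusive_scan(x, dim=dim)
    # Loose tolerance: the tree reduction reorders floating-point additions.
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)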

Prototype Implementation

I have prepared a prototype implementation. Here's a sketch:

import math

import torch


def blelloch_exclusive_scan(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Blelloch work-efficient parallel exclusive scan.
    
    Args:
        x: Input tensor
        dim: Dimension to scan over
        
    Returns:
        Exclusive prefix sum along specified dimension
        
    Complexity:
        Work: O(n)
        Depth: O(log n)
    """
    x = x.movedim(dim, 0).clone()
    n = x.shape[0]
    log_n = int(math.ceil(math.log2(n))) if n > 1 else 1
    n_padded = 1 << log_n
    
    # Pad to power of 2 if needed
    if n_padded > n:
        x = torch.cat([x, torch.zeros(n_padded - n, *x.shape[1:], 
                                       dtype=x.dtype, device=x.device)])
    
    # Up-sweep (reduce phase)
    for d in range(log_n):
        stride = 1 << (d + 1)
        right_idx = torch.arange(stride - 1, n_padded, stride)
        left_idx = right_idx - (1 << d)
        x[right_idx] = x[right_idx] + x[left_idx]
    
    # Down-sweep (distribute phase)
    x[-1] = 0
    for d in range(log_n - 1, -1, -1):
        stride = 1 << (d + 1)
        right_idx = torch.arange(stride - 1, n_padded, stride)
        left_idx = right_idx - (1 << d)
        temp = x[left_idx].clone()
        x[left_idx] = x[right_idx]
        x[right_idx] = x[right_idx] + temp
    
    return x[:n].movedim(0, dim)
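
A quick self-check of the sketch against torch.cumsum (the inclusive scan equals the exclusive result plus the input; float64 is used so the comparison is not limited by rounding):

x = torch.randn(257, 8, dtype=torch.float64)   # non-power-of-2 length exercises the padding path
excl = blelloch_exclusive_scan(x, dim=0)
assert torch.allclose(excl + x, torch.cumsum(x, dim=0))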

Additional Context

I'm interested in contributing this feature and have prepared:

  • Detailed mathematical derivation of the algorithm
  • Correctness proofs for both phases
  • Complexity analysis
  • Prototype implementation

I have experience with distributed PyTorch (Data Parallelism with CUDA) and mathematical algorithm optimization. Happy to discuss the design approach and address any concerns before proceeding with the full implementation.

References

  1. Blelloch, G. E. (1990). "Prefix Sums and Their Applications". CMU Technical Report.
  2. Harris, M. (2007). "Parallel Prefix Sum (Scan) with CUDA". GPU Gems 3, Chapter 39.
  3. Sengupta, S. et al. (2008). "Efficient Parallel Scan Algorithms for GPUs". NVIDIA Technical Report.
