Skip to content

Conversation

@Rachmanino
Copy link
Collaborator

@Rachmanino Rachmanino commented Oct 30, 2025

TODO

  • SIMT
  • TMA

Summary by CodeRabbit

  • New Features

    • Remote copy support with per-endpoint src/dst specification for distributed transfers
    • Distributed memory intrinsics and host-visible metadata for tile-scale operations
    • New example demonstrating distributed tile-scale memory copy
    • New helper APIs: symmetric-buffer declaration, tensor-like construction, and tilescale-aware expression rendering
  • Tests

    • Added CUDA tests for distributed copy kernels

@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link

coderabbitai bot commented Oct 30, 2025

Walkthrough

Adds remote-PE support and symmetric-buffer wiring for distributed TileLang copies, new tilescale intrinsics (get_remote_base/get_local_base), host-visible metadata, compiler passes to declare symm buffers, JIT/descriptor updates, CUDA codegen changes, and example/tests demonstrating remote tile-scale copy.

Changes

Cohort / File(s) Summary
Examples & tests
examples/distributed/primitives/example_tilescale_copy.py, examples/distributed/primitives/test_tilescale_copy.py
New example demonstrating tile-scale remote copy kernels and new tests that spawn two processes to validate cross-rank copy semantics.
Copy op API & implementation
src/op/copy.h, src/op/copy.cc
Added src_pe, dst_pe, symm_buffer to CopyNode; added is_remote_copy/is_remote_push/is_remote_pull helpers; parse and validate remote copy args and enforce memory/global constraints.
Distributed intrinsics (registration & headers)
src/op/distributed.cc, src/op/distributed.h, src/tl_templates/cuda/distributed.h
Introduced intrinsics get_remote_base, get_local_base, get_local_base_ptr; updated docs/branding to tilescale; exposed host_meta_data and host/device-aware accessors.
CUDA codegen & templates
src/target/codegen_cuda.cc, src/tl_templates/cuda/common.h, src/tl_templates/cuda/sync.h, src/tl_templates/cuda/threadblock_swizzle.h
Emit new intrinsics in CUDA codegen, add host_meta_data alias, introduce TL_HOST/host-device macros, switch get_remote_base_ptr→get_remote_base in sync/templates, and add default offset template parameters.
Symmetric buffer pass & pipeline insertion
src/transform/declare_symm_buffer.cc, tilelang/transform/__init__.py, tilelang/engine/phase.py
New DeclareSymmBuffer transform that computes symmetry pointers and wraps remote copy calls with LetStmt bindings; added pass to LowerAndLegalize pipeline and exported it.
Hopper TMA lowering changes
src/transform/lower_hopper_intrin.cc
Track TMA-related vars and inline let-bound expressions into descriptor construction; added VisitStmt_(LetStmtNode) and related state.
Remote-copy lowering updates
src/op/remote_copy.cc, src/op/sync.cc
Replace uses of get_remote_base_ptr with get_remote_base when computing base offsets for Put/Get/Barrier lowering.
JIT adapter & utils
tilelang/jit/adapter/utils.py, tilelang/jit/adapter/wrapper.py
Added tilescale_pythonic_expr() for PrimExpr→Python string (TileScale-aware); integrated into wrapper, added _tilescale_pythonic_expr helper and host_meta_data initialization; expanded descriptor arg matching for _symm_ symbols.
Language API additions
tilelang/language/copy.py, tilelang/language/proxy.py, tilelang/language/__init__.py
Extended copy() to accept src_pe/dst_pe; added make_tensor_like() to create buffers from existing tensors with optional overrides and exported it.

Sequence Diagram(s)

sequenceDiagram
    participant User as User/Main
    participant Lang as Language API
    participant Transform as Transform Passes
    participant Codegen as CUDA Codegen
    participant Runtime as Runtime (host/device)

    User->>Lang: call copy(src,dst,src_pe,dst_pe)
    Lang->>Lang: emit tl.copy intrinsic with src_pe/dst_pe
    Lang->>Transform: LowerAndLegalize
    Transform->>Transform: DeclareSymmBuffer detects remote push/pull
    alt remote_push
        Transform->>Transform: compute symm_ptr via get_remote_base(dst_pe)
        Transform->>Transform: wrap copy with LetStmt(symm_ptr)
    else remote_pull
        Transform->>Transform: compute symm_ptr via get_remote_base(src_pe)
        Transform->>Transform: wrap copy with LetStmt(symm_ptr)
    end
    Transform->>Codegen: transformed IR
    Codegen->>Codegen: emit tl::get_remote_base / get_local_base calls
    Codegen->>Runtime: produced CUDA + host init (host_meta_data)
    Runtime->>Runtime: host/device branch returns meta_data or host_meta_data
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Review focus areas:
    • src/transform/declare_symm_buffer.cc (IR mutation, correctness of pointer replacement and LetStmt wrapping).
    • src/tl_templates/cuda/distributed.h and src/target/codegen_cuda.cc (host/device guards, pointer casts, host_meta_data exposure).
    • src/transform/lower_hopper_intrin.cc (TMA-related variable tracking and descriptor argument correctness).
    • JIT adapter changes (tilelang/jit/adapter/*) for expression translation and descriptor argument generation.

Possibly related PRs

Suggested reviewers

  • chengyupku
  • tzj-fxz

Poem

🐇 Hopping bytes from PE to PE,

Symm pointers set for all to see.
Tiles align and kernels leap,
Remote copies wake from sleep.
CUDA sings, the rabbits cheer — data's near!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 28.57% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: adding unified lowering support for T.copy to both SIMT and TMA backends for intra-node copy operations.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch wt/dev

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@Rachmanino Rachmanino force-pushed the wt/dev branch 3 times, most recently from 0874b4f to 7aaa99e Compare November 2, 2025 06:58
@Rachmanino Rachmanino requested a review from chengyupku November 2, 2025 07:01
@Rachmanino Rachmanino marked this pull request as ready for review November 2, 2025 07:02
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/transform/lower_hopper_intrin.cc (1)

154-166: Also mark the im2col base pointer as TMA-related.

We only record the base pointer var for create_tma_descriptor. When the same pattern occurs with create_tma_im2col_descriptor, the hoisted tvm_call_packed still references the Let-bound var, so the var is out of scope and the descriptor init sees a dangling symbol. Please add the same tma_related_vars_ bookkeeping in the im2col branch so its base pointer gets inlined as well.

     } else if (call->op.same_as(create_tma_im2col_descriptor())) {
@@
-        prefetch_calls_.push_back(
-            Evaluate(Call(DataType::Handle(), builtin::call_extern(),
-                          {StringImm("tl::prefetch_tma_descriptor"), var})));
+        prefetch_calls_.push_back(
+            Evaluate(Call(DataType::Handle(), builtin::call_extern(),
+                          {StringImm("tl::prefetch_tma_descriptor"), var})));
+        if (auto base_var = call->args[2].as<Var>()) {
+          tma_related_vars_.insert(base_var.value());
+        }
       }
src/op/copy.cc (1)

133-176: Initialize remote PE defaults before calling is_remote_*.

If any existing tl.copy call site still uses the legacy 5-argument form, args.size() stays ≤5 so src_pe/dst_pe remain undefined. The subsequent is_remote_push()/is_remote_pull() immediately dereference those unset PrimExprs, triggering an ICHECK failure. Defaulting them to Integer(-1) (and then overriding when extra args are present) keeps the legacy path safe.

   ObjectPtr<CopyNode> node = make_object<CopyNode>();
+  node->src_pe = Integer(-1);
+  node->dst_pe = Integer(-1);
   Array<Range> rgs[2];
   Buffer bf[2];
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fae99e9 and 7aaa99e.

📒 Files selected for processing (17)
  • examples/distributed/primitives/example_tilescale_copy.py (1 hunks)
  • src/op/copy.cc (2 hunks)
  • src/op/copy.h (1 hunks)
  • src/op/distributed.cc (1 hunks)
  • src/op/distributed.h (1 hunks)
  • src/target/codegen_cuda.cc (2 hunks)
  • src/tl_templates/cuda/common.h (1 hunks)
  • src/tl_templates/cuda/distributed.h (1 hunks)
  • src/transform/declare_symm_buffer.cc (1 hunks)
  • src/transform/lower_hopper_intrin.cc (4 hunks)
  • tilelang/engine/phase.py (1 hunks)
  • tilelang/jit/adapter/utils.py (1 hunks)
  • tilelang/jit/adapter/wrapper.py (5 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/copy.py (2 hunks)
  • tilelang/language/proxy.py (2 hunks)
  • tilelang/transform/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (12)
tilelang/engine/phase.py (2)
src/transform/declare_symm_buffer.cc (2)
  • DeclareSymmBuffer (220-227)
  • DeclareSymmBuffer (220-220)
tilelang/transform/__init__.py (1)
  • DeclareSymmBuffer (492-498)
src/transform/declare_symm_buffer.cc (5)
tilelang/language/proxy.py (1)
  • ptr (273-295)
tilelang/language/distributed/common.py (1)
  • get_rank (8-11)
src/tl_templates/cuda/distributed.h (1)
  • get_uintptr_t (47-49)
tilelang/language/ast/ir.py (1)
  • LetStmt (880-908)
tilelang/transform/__init__.py (1)
  • DeclareSymmBuffer (492-498)
examples/distributed/primitives/example_tilescale_copy.py (8)
src/tl_templates/cuda/reduce.h (1)
  • T (208-280)
tilelang/distributed/utils.py (1)
  • init_dist (40-62)
tilelang/language/distributed/common.py (1)
  • get_rank (8-11)
tilelang/language/copy.py (1)
  • copy (11-98)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/utils/allocator.py (1)
  • get_allocator (236-248)
tilelang/jit/kernel.py (1)
  • initialize (407-416)
tilelang/utils/tensor.py (1)
  • tensor (43-56)
src/op/copy.cc (2)
src/op/copy.h (1)
  • is_remote_copy (100-124)
tilelang/language/copy.py (1)
  • copy (11-98)
src/tl_templates/cuda/distributed.h (2)
src/op/distributed.h (1)
  • tl (11-238)
tilelang/language/distributed/common.py (2)
  • get_rank (8-11)
  • get_num_ranks (14-17)
tilelang/language/__init__.py (1)
tilelang/language/proxy.py (1)
  • make_tensor_like (305-313)
src/op/copy.h (1)
src/op/copy.cc (6)
  • is_remote_copy (190-192)
  • is_remote_copy (190-190)
  • is_remote_push (180-183)
  • is_remote_push (180-180)
  • is_remote_pull (185-188)
  • is_remote_pull (185-185)
tilelang/transform/__init__.py (1)
src/transform/declare_symm_buffer.cc (2)
  • DeclareSymmBuffer (220-227)
  • DeclareSymmBuffer (220-220)
tilelang/language/copy.py (2)
tilelang/utils/language.py (1)
  • get_buffer_region_from_load (124-146)
tilelang/language/tir/op.py (1)
  • call_intrin (120-145)
tilelang/jit/adapter/utils.py (1)
src/op/distributed.h (1)
  • tvm (10-239)
tilelang/jit/adapter/wrapper.py (2)
tilelang/jit/adapter/utils.py (2)
  • pythonic_expr (110-218)
  • tilescale_pythonic_expr (221-346)
src/op/distributed.h (1)
  • tvm (10-239)
src/transform/lower_hopper_intrin.cc (4)
tilelang/language/ast/ir.py (1)
  • LetStmt (880-908)
tilelang/language/builtin.py (1)
  • create_tma_descriptor (74-83)
src/op/builtin.cc (2)
  • cuTensorMapType (39-39)
  • cuTensorMapType (39-39)
tilelang/language/tir/op.py (1)
  • call_extern (173-195)
🪛 Clang (14.0.6)
src/transform/declare_symm_buffer.cc

[error] 9-9: 'tvm/ffi/reflection/registry.h' file not found

(clang-diagnostic-error)

🪛 GitHub Actions: CI
tilelang/language/proxy.py

[error] 307-309: UP045 Use X | None for type annotations in function signature: 'shape: Optional[tuple[PrimExpr, ...]]', 'dtype: Optional[str]', 'strides: Optional[tuple[PrimExpr, ...]]'.


[error] 308-309: UP045 Use X | None for type annotations in function signature: 'dtype' and 'strides'.


[error] 307-309: Ruff check: Multiple UP045 suggestions; convert annotations to use '| None' where appropriate.

tilelang/engine/phase.py

[error] 1-1: Pre-commit yapf: code formatting changes were made. Please re-run the pipeline or commit formatted files.

tilelang/language/__init__.py

[error] 1-1: Pre-commit yapf: code formatting changes were made. Please re-run the pipeline or commit formatted files.

tilelang/transform/__init__.py

[error] 1-1: Pre-commit yapf: code formatting changes were made. Please re-run the pipeline or commit formatted files.

tilelang/language/copy.py

[error] 14-15: UP045 Use X | None for type annotations in 'src_pe' and 'dst_pe' parameters.


[error] 30-30: W293 Blank line contains whitespace in docstring/header area.


[error] 14-16: Ruff check: Use X | None for type annotations (UP045) and related fixes suggested by the linter.

tilelang/jit/adapter/utils.py

[error] 1-1: Pre-commit yapf: code formatting changes were made. Please re-run the pipeline or commit formatted files.

tilelang/jit/adapter/wrapper.py

[error] 1-1: Pre-commit yapf: code formatting changes were made. Please re-run the pipeline or commit formatted files.

🪛 Ruff (0.14.2)
examples/distributed/primitives/example_tilescale_copy.py

15-15: Unused function argument: rank

(ARG001)


182-182: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/__init__.py

19-19: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

* on peer's symm heap
*/

#include <tvm/ffi/reflection/registry.h>
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix the registry header include
<tvm/ffi/reflection/registry.h> does not exist, so this include breaks the build. Please pull in the registry helpers from <tvm/runtime/registry.h> instead.

-#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/registry.h>
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/registry.h>
🧰 Tools
🪛 Clang (14.0.6)

[error] 9-9: 'tvm/ffi/reflection/registry.h' file not found

(clang-diagnostic-error)

🤖 Prompt for AI Agents
In src/transform/declare_symm_buffer.cc around line 9, the include currently
references a non-existent header <tvm/ffi/reflection/registry.h>; replace it
with the correct registry header <tvm/runtime/registry.h> so the file pulls in
the registry helpers used by the code and the build can succeed.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (3)
tilelang/language/copy.py (2)

14-15: Fix type annotations to match the default values.

The type hints specify tir.PrimExpr | tir.IntImm | None but the default value is -1 (a Python int), creating a type mismatch.

Apply this diff to include int in the type union:

-    src_pe: tir.PrimExpr | tir.IntImm | None = -1,
-    dst_pe: tir.PrimExpr | tir.IntImm | None = -1,
+    src_pe: tir.PrimExpr | tir.IntImm | int | None = -1,
+    dst_pe: tir.PrimExpr | tir.IntImm | int | None = -1,

Alternatively, change the defaults to None and normalize inside the function:

-    src_pe: tir.PrimExpr | tir.IntImm | None = -1,
-    dst_pe: tir.PrimExpr | tir.IntImm | None = -1,
+    src_pe: tir.PrimExpr | tir.IntImm | None = None,
+    dst_pe: tir.PrimExpr | tir.IntImm | None = None,
...
+    if src_pe is None:
+        src_pe = -1
+    if dst_pe is None:
+        dst_pe = -1

95-95: Assertion breaks with symbolic PrimExpr arguments.

The assertion src_pe == -1 or dst_pe == -1 attempts to evaluate symbolic PrimExpr values in a boolean context (e.g., src_pe=1 - T.get_rank() as used in the TMA example), which raises TypeError: Cannot convert PrimExpr to bool at runtime.

Apply this diff to check only constant values:

-    assert src_pe == -1 or dst_pe == -1, "At least one of src_pe or dst_pe must be local rank"
+    def _is_const_remote(pe):
+        if isinstance(pe, int):
+            return pe != -1
+        if isinstance(pe, tir.IntImm):
+            return pe.value != -1
+        return False
+
+    if _is_const_remote(src_pe) and _is_const_remote(dst_pe):
+        raise ValueError("At least one of src_pe or dst_pe must be local rank")
examples/distributed/primitives/example_tilescale_copy.py (1)

15-15: Remove the unused rank parameter.

The rank parameter is never referenced in the function body. While the comment at line 158 suggests it's for TMA kernels, those implementations use T.get_rank() directly (lines 98, 128) instead of a compile-time constant.

Apply this diff:

-def get_kernel(M, N, block_M, block_N, threads, kernel='simt_push_tile', rank=None):
+def get_kernel(M, N, block_M, block_N, threads, kernel='simt_push_tile'):

And update the call site at line 158:

     kernel = get_kernel(
-        M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel,
-        rank=local_rank)  # only TMA kernels need compile-time aware peer rank
+        M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel)
🧹 Nitpick comments (1)
examples/distributed/primitives/example_tilescale_copy.py (1)

163-183: Validation logic correctly verifies peer-to-peer copy.

The test correctly:

  • Allocates tensors on the distributed allocator
  • Synchronizes before and after kernel execution
  • Uses all_gather to collect each rank's source tensor
  • Validates that local dst matches the peer's src using local_rank ^ 1 indexing

Minor: Line 181 triggers a style warning (TRY003) for the inline exception message. Consider defining a custom exception or using a simpler message:

-        raise ValueError("Test failed")
+        raise ValueError
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 11c8336 and 34b7e46.

📒 Files selected for processing (3)
  • examples/distributed/primitives/example_tilescale_copy.py (1 hunks)
  • tilelang/language/copy.py (2 hunks)
  • tilelang/language/proxy.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/language/copy.py (1)
tilelang/language/tir/op.py (1)
  • call_intrin (120-145)
examples/distributed/primitives/example_tilescale_copy.py (6)
tilelang/distributed/utils.py (1)
  • init_dist (40-62)
tilelang/language/distributed/common.py (1)
  • get_rank (8-11)
tilelang/language/copy.py (1)
  • copy (11-98)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/utils/allocator.py (1)
  • get_allocator (236-248)
tilelang/utils/tensor.py (1)
  • tensor (43-56)
🪛 Ruff (0.14.2)
examples/distributed/primitives/example_tilescale_copy.py

15-15: Unused function argument: rank

(ARG001)


181-181: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-nvidia
🔇 Additional comments (5)
tilelang/language/proxy.py (1)

305-313: LGTM! Clean helper function.

The make_tensor_like function provides a clean interface for creating tensor-like buffers with optional property overrides. The None-check fallback pattern is clear and the type hints appropriately reflect the optional nature of the parameters.

tilelang/language/copy.py (1)

97-98: LGTM! Intrinsic call correctly extended.

The tir.call_intrin invocation correctly passes the new src_pe and dst_pe parameters, maintaining consistency with the updated function signature and the underlying intrinsic's expanded interface.

examples/distributed/primitives/example_tilescale_copy.py (3)

17-136: LGTM! Comprehensive demonstration of remote copy patterns.

The five kernel implementations effectively demonstrate different distributed copy strategies:

  • SIMT push/pull patterns with explicit per-endpoint signaling
  • TMA load/store patterns using runtime rank expressions
  • Proper staging through shared memory with swizzled layouts

All kernels correctly use 1 - rank to target the peer in a 2-rank configuration.


139-162: LGTM! Proper distributed setup and kernel initialization.

The initialization sequence correctly:

  • Validates the 2-rank constraint required by the peer-signaling logic
  • Configures the distributed allocator with appropriate parameters
  • Compiles the selected kernel and prints source on rank 0 only

186-194: LGTM! Clean script entrypoint.

The argument parser and multiprocessing spawn setup is clean and appropriate for the 2-rank distributed test scenario.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 34b7e46 and 7aaf6a3.

📒 Files selected for processing (4)
  • examples/distributed/primitives/example_tilescale_copy.py (1 hunks)
  • src/op/copy.cc (2 hunks)
  • src/tl_templates/cuda/distributed.h (1 hunks)
  • tilelang/language/copy.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/tl_templates/cuda/distributed.h
🧰 Additional context used
🧬 Code graph analysis (3)
src/op/copy.cc (2)
src/op/copy.h (1)
  • is_remote_copy (100-124)
tilelang/language/copy.py (1)
  • copy (11-96)
examples/distributed/primitives/example_tilescale_copy.py (10)
tilelang/distributed/utils.py (1)
  • init_dist (40-62)
tilelang/env.py (1)
  • disable_cache (286-287)
tilelang/jit/__init__.py (1)
  • jit (233-306)
tilelang/language/allocate.py (1)
  • alloc_shared (24-39)
tilelang/language/distributed/common.py (1)
  • get_rank (8-11)
tilelang/language/copy.py (1)
  • copy (11-96)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/utils/allocator.py (1)
  • get_allocator (236-248)
tilelang/jit/kernel.py (1)
  • initialize (407-416)
tilelang/utils/tensor.py (1)
  • tensor (43-56)
tilelang/language/copy.py (1)
tilelang/language/tir/op.py (1)
  • call_intrin (120-145)
🪛 Ruff (0.14.2)
examples/distributed/primitives/example_tilescale_copy.py

179-179: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-nvidia
🔇 Additional comments (9)
src/op/copy.cc (2)

158-178: LGTM! Remote copy parameter parsing and validation is correct.

The parameter parsing and validation logic is well-structured:

  • Correctly parses src_pe and dst_pe from args[5] and args[6]
  • Enforces mutual exclusion (at least one endpoint must be local)
  • Validates memory scopes appropriately (remote push requires global dst, remote pull requires global src)

183-195: LGTM! Query helper implementations are correct.

The query helpers correctly identify remote copy operations by checking whether src_pe or dst_pe differ from the local sentinel value (-1).

tilelang/language/copy.py (3)

11-19: LGTM! Function signature is reasonable.

The type hints tir.PrimExpr | tir.IntImm | None = -1 correctly include the implicit int type that Python's -1 literal represents. While the hint could explicitly list int, the current form is acceptable since Python int literals are compatible with the PrimExpr type system in TVM.


20-36: LGTM! Documentation is clear and accurate.

The docstring properly documents the new src_pe and dst_pe parameters, explaining that -1 indicates a local endpoint.


94-96: LGTM! Parameter propagation is correct.

The function correctly passes src_pe and dst_pe through to the tl.copy intrinsic. Validation is appropriately handled at the C++ layer (in src/op/copy.cc), avoiding issues with symbolic PrimExpr values that would arise from Python-level assertions.

examples/distributed/primitives/example_tilescale_copy.py (4)

1-12: LGTM! Imports and setup are appropriate.

The imports and environment configuration are well-suited for a distributed copy example. Disabling the cache and silencing NCCL logs improve the testing experience.


14-136: LGTM! Kernel variants demonstrate remote copy patterns effectively.

The get_kernel function provides a comprehensive set of examples:

  • simt_push_buffer: Full-buffer remote push
  • simt_push_tile: Tiled push via shared memory
  • simt_pull_tile: Tiled pull via shared memory
  • tma_load_tile: TMA-based remote load
  • tma_store_tile: TMA-based remote store

The distinction between using rank[0] (SIMT kernels) and T.get_rank() (TMA kernels) is noted in the comment on line 99, which helps users understand the current limitation.


139-181: LGTM! Main function demonstrates proper distributed workflow.

The main function correctly implements:

  • Distributed initialization with barrier synchronization
  • Allocator configuration for multi-rank execution
  • Kernel compilation and initialization
  • Tensor allocation on distributed allocator
  • Proper synchronization before and after kernel execution
  • Result validation using all_gather and peer data comparison (via XOR to get opposite rank)
  • Cleanup of distributed process group

The validation logic on line 172 (dst_torchs[local_rank ^ 1]) cleverly uses XOR to select the peer's source data for comparison.


184-192: LGTM! Script entrypoint is well-structured.

The argument parser provides sensible defaults and the use of torch.multiprocessing.spawn correctly launches the two required processes for this distributed example.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

♻️ Duplicate comments (1)
src/transform/declare_symm_buffer.cc (1)

9-9: Fix the registry header include.

This issue was already flagged in a previous review. The header <tvm/ffi/reflection/registry.h> does not exist and breaks the build. Replace it with <tvm/runtime/registry.h>.

-#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/registry.h>
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7aaf6a3 and fbb103b.

📒 Files selected for processing (1)
  • src/transform/declare_symm_buffer.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/declare_symm_buffer.cc (4)
tilelang/language/distributed/common.py (1)
  • get_rank (8-11)
src/tl_templates/cuda/distributed.h (1)
  • get_uintptr_t (47-49)
tilelang/language/ast/ir.py (1)
  • LetStmt (880-908)
tilelang/transform/__init__.py (1)
  • DeclareSymmBuffer (492-498)
🪛 Clang (14.0.6)
src/transform/declare_symm_buffer.cc

[error] 9-9: 'tvm/ffi/reflection/registry.h' file not found

(clang-diagnostic-error)

🔇 Additional comments (3)
src/transform/declare_symm_buffer.cc (3)

50-67: LGTM!

The Apply method correctly implements the standard TVM pass pattern: early return for undefined bodies, buffer map initialization, and stmt visitation.


70-79: LGTM!

The block visitor correctly populates the buffer map from both alloc_buffers and match_buffers, then delegates to the parent visitor.


242-254: Pass creation and FFI registration look correct.

The pass creation follows the standard TVM pattern, and the FFI registration is properly structured. However, this code depends on fixing the header include at line 9 to function correctly.

Ensure that after fixing the header include at line 9, the FFI registration compiles and the pass is accessible from Python as tl.transform.DeclareSymmBuffer().

Comment on lines +29 to +30
static int name_suffix_id =
0; // Avoid name collision for symm buffers, start from 0
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Race condition: global mutable state in transformation pass.

The global name_suffix_id is incremented without synchronization (lines 110, 178), causing data races when multiple passes run concurrently. Move this counter into the SymmBufferDeclarer class as a member variable.

-static int name_suffix_id =
-    0; // Avoid name collision for symm buffers, start from 0

Add to the SymmBufferDeclarer class private members:

  Map<Var, Buffer> buffer_data_to_buffer_;
+ int name_suffix_id_ = 0;

Then update lines 110 and 178 to use name_suffix_id_++ instead of name_suffix_id++.

🤖 Prompt for AI Agents
In src/transform/declare_symm_buffer.cc around lines 29-30 (and usages at ~110
and ~178), a global mutable variable `name_suffix_id` is used and incremented,
causing race conditions when passes run concurrently; move this counter into the
SymmBufferDeclarer class as a private member (e.g., add int name_suffix_id_ = 0
alongside buffer_data_to_buffer_), then replace usages of the global
`name_suffix_id++` at the noted locations with `name_suffix_id_++` and remove
the global variable to eliminate the shared mutable state.

Comment on lines +95 to +162
if (parsed_op.as<CopyNode>()->is_remote_push()) {
// LOG(INFO) << "Found remote push";

Buffer dst = parsed_op.as<CopyNode>()->dst;
Array<Range> dst_range = parsed_op.as<CopyNode>()->dst_range;

// 1. Calculate symm dst ptr
PrimExpr symm_dst_ptr_expr =
CalculateSymmPtr(dst->data, parsed_op.as<CopyNode>()->dst_pe);
// LOG(INFO) << "Symm dst ptr expr: " << symm_dst_ptr_expr;

// 2. Create a let binding
String storage_scope =
dst->data->type_annotation.as<PointerTypeNode>()->storage_scope;
Var symm_dst_var =
Var(dst->name + "_symm_" + std::to_string(name_suffix_id++),
PointerType(PrimType(dst->dtype), storage_scope));

// 3. Create modified dst buffer with symm var
dst.CopyOnWrite()->data = symm_dst_var;

// 4. Rebuild the destination region call with the modified buffer
// RegionOp args: [BufferLoad(min_indices), access_mask, extent_0,
// extent_1, ...]
Array<PrimExpr> dst_region_mins;
Array<PrimExpr> dst_region_extents;
for (const Range &r : dst_range) {
dst_region_mins.push_back(r->min);
dst_region_extents.push_back(r->extent);
}
BufferLoad dst_load(dst, dst_region_mins);

Array<PrimExpr> dst_region_args;
dst_region_args.push_back(dst_load);
dst_region_args.push_back(
IntImm(DataType::Int(32), call_op->args[1]
.as<CallNode>()
->args[1]
.as<IntImmNode>()
->value)); // access_mask
for (const PrimExpr &extent : dst_region_extents) {
dst_region_args.push_back(extent);
}

// Create new Call for the destination region
Call dst_region_call =
Call(call_op->args[1].as<CallNode>()->dtype,
call_op->args[1].as<CallNode>()->op, dst_region_args,
call_op->args[1].as<CallNode>()->span);

// 5. Rebuild the Copy call with modified args
Array<PrimExpr> new_copy_args;
new_copy_args.push_back(call_op->args[0]); // src region (unchanged)
new_copy_args.push_back(dst_region_call); // modified dst region
// Copy remaining args
for (size_t i = 2; i < call_op->args.size(); i++) {
new_copy_args.push_back(call_op->args[i]);
}

// Create the modified copy call
Call new_copy_call =
Call(call_op->dtype, call_op->op, new_copy_args, call_op->span);

// Wrap it in an Evaluate statement
Stmt modified_stmt = Evaluate(new_copy_call);

// Wrap with LetStmt that defines the symm pointer
return LetStmt(symm_dst_var, symm_dst_ptr_expr, modified_stmt);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Add defensive null checks for nested property access.

The code has several unsafe property accesses that can crash:

  1. Line 108: type_annotation.as<PointerTypeNode>()->storage_scope assumes type_annotation is defined and is a PointerType.
  2. Lines 130-134: Deeply nested as<CallNode>() and as<IntImmNode>() calls without checking intermediate results.

If any cast fails, dereferencing the null pointer crashes the compiler.

Apply these defensive checks:

         if (parsed_op.as<CopyNode>()->is_remote_push()) {
           // LOG(INFO) << "Found remote push";
 
           Buffer dst = parsed_op.as<CopyNode>()->dst;
           Array<Range> dst_range = parsed_op.as<CopyNode>()->dst_range;
 
           // 1. Calculate symm dst ptr
           PrimExpr symm_dst_ptr_expr =
               CalculateSymmPtr(dst->data, parsed_op.as<CopyNode>()->dst_pe);
           // LOG(INFO) << "Symm dst ptr expr: " << symm_dst_ptr_expr;
 
           // 2. Create a let binding
+          ICHECK(dst->data->type_annotation.defined()) 
+              << "Buffer data variable must have type annotation";
+          auto ptr_type = dst->data->type_annotation.as<PointerTypeNode>();
+          ICHECK(ptr_type) << "Buffer data type annotation must be PointerType";
           String storage_scope =
-              dst->data->type_annotation.as<PointerTypeNode>()->storage_scope;
+              ptr_type->storage_scope;
           Var symm_dst_var =
               Var(dst->name + "_symm_" + std::to_string(name_suffix_id++),
                   PointerType(PrimType(dst->dtype), storage_scope));
 
           // 3. Create modified dst buffer with symm var
           dst.CopyOnWrite()->data = symm_dst_var;
 
           // 4. Rebuild the destination region call with the modified buffer
           // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0,
           // extent_1, ...]
           Array<PrimExpr> dst_region_mins;
           Array<PrimExpr> dst_region_extents;
           for (const Range &r : dst_range) {
             dst_region_mins.push_back(r->min);
             dst_region_extents.push_back(r->extent);
           }
           BufferLoad dst_load(dst, dst_region_mins);
 
           Array<PrimExpr> dst_region_args;
           dst_region_args.push_back(dst_load);
+          auto dst_region_call_node = call_op->args[1].as<CallNode>();
+          ICHECK(dst_region_call_node) << "Expected Call node for dst region";
+          auto access_mask_node = dst_region_call_node->args[1].as<IntImmNode>();
+          ICHECK(access_mask_node) << "Expected IntImm for access_mask";
           dst_region_args.push_back(
-              IntImm(DataType::Int(32), call_op->args[1]
-                                            .as<CallNode>()
-                                            ->args[1]
-                                            .as<IntImmNode>()
-                                            ->value)); // access_mask
+              IntImm(DataType::Int(32), access_mask_node->value)); // access_mask
🤖 Prompt for AI Agents
In src/transform/declare_symm_buffer.cc around lines 95-162, several chained
.as<...> dereferences can return nullptr and crash; add defensive null/type
checks before dereferencing: verify dst->data->type_annotation is non-null and
its .as<PointerTypeNode>() succeeds before accessing storage_scope; verify
call_op and call_op->args size >=2, that call_op->args[1].as<CallNode>() is
non-null, that that CallNode has args size >=2 and that args[1].as<IntImmNode>()
is non-null before reading ->value; on any check failure handle gracefully
(e.g., log/ICHECK(false) with a clear message or return the original
stmt/unmodified Call) instead of dereferencing a null pointer so the compiler
won't crash.

Comment on lines 163 to 230
} else if (parsed_op.as<CopyNode>()->is_remote_pull()) {
LOG(INFO) << "Found remote pull";

Buffer src = parsed_op.as<CopyNode>()->src;
Array<Range> src_range = parsed_op.as<CopyNode>()->src_range;

// 1. Calculate symm src ptr
PrimExpr symm_src_ptr_expr =
CalculateSymmPtr(src->data, parsed_op.as<CopyNode>()->src_pe);
// LOG(INFO) << "Symm src ptr expr: " << symm_src_ptr_expr;

// 2. Create a let binding
String storage_scope =
src->data->type_annotation.as<PointerTypeNode>()->storage_scope;
Var symm_src_var =
Var(src->name + "_symm_" + std::to_string(name_suffix_id++),
PointerType(PrimType(src->dtype), storage_scope));

// 3. Create modified src buffer with symm var
src.CopyOnWrite()->data = symm_src_var;

// 4. Rebuild the source region call with the modified buffer
// RegionOp args: [BufferLoad(min_indices), access_mask, extent_0,
// extent_1, ...]
Array<PrimExpr> src_region_mins;
Array<PrimExpr> src_region_extents;
for (const Range &r : src_range) {
src_region_mins.push_back(r->min);
src_region_extents.push_back(r->extent);
}
BufferLoad src_load(src, src_region_mins);

Array<PrimExpr> src_region_args;
src_region_args.push_back(src_load);
src_region_args.push_back(
IntImm(DataType::Int(32), call_op->args[1]
.as<CallNode>()
->args[1]
.as<IntImmNode>()
->value)); // access_mask
for (const PrimExpr &extent : src_region_extents) {
src_region_args.push_back(extent);
}

// Create new Call for the source region
Call src_region_call =
Call(call_op->args[0].as<CallNode>()->dtype,
call_op->args[0].as<CallNode>()->op, src_region_args,
call_op->args[0].as<CallNode>()->span);

// 5. Rebuild the Copy call with modified args
Array<PrimExpr> new_copy_args;
new_copy_args.push_back(src_region_call); // modified src region
new_copy_args.push_back(call_op->args[1]); // dst region (unchanged)
// Copy remaining args
for (size_t i = 2; i < call_op->args.size(); i++) {
new_copy_args.push_back(call_op->args[i]);
}

// Create the modified copy call
Call new_copy_call =
Call(call_op->dtype, call_op->op, new_copy_args, call_op->span);

// Wrap it in an Evaluate statement
Stmt modified_stmt = Evaluate(new_copy_call);

// Wrap with LetStmt that defines the symm pointer
return LetStmt(symm_src_var, symm_src_ptr_expr, modified_stmt);
Copy link

@coderabbitai coderabbitai bot Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Add defensive null checks and remove debug LOG statement.

This code has the same issues as the remote_push block:

  1. Line 164: LOG(INFO) should be removed or commented out (other debug logs are commented).
  2. Line 176: type_annotation.as<PointerTypeNode>()->storage_scope needs null checks.
  3. Lines 198-202: Deeply nested as<>() calls without validation.

Apply similar defensive checks as suggested for remote_push:

         } else if (parsed_op.as<CopyNode>()->is_remote_pull()) {
-          LOG(INFO) << "Found remote pull";
+          // LOG(INFO) << "Found remote pull";
 
           Buffer src = parsed_op.as<CopyNode>()->src;
           Array<Range> src_range = parsed_op.as<CopyNode>()->src_range;
 
           // 1. Calculate symm src ptr
           PrimExpr symm_src_ptr_expr =
               CalculateSymmPtr(src->data, parsed_op.as<CopyNode>()->src_pe);
           // LOG(INFO) << "Symm src ptr expr: " << symm_src_ptr_expr;
 
           // 2. Create a let binding
+          ICHECK(src->data->type_annotation.defined()) 
+              << "Buffer data variable must have type annotation";
+          auto ptr_type = src->data->type_annotation.as<PointerTypeNode>();
+          ICHECK(ptr_type) << "Buffer data type annotation must be PointerType";
           String storage_scope =
-              src->data->type_annotation.as<PointerTypeNode>()->storage_scope;
+              ptr_type->storage_scope;
           Var symm_src_var =
               Var(src->name + "_symm_" + std::to_string(name_suffix_id++),
                   PointerType(PrimType(src->dtype), storage_scope));
 
           // 3. Create modified src buffer with symm var
           src.CopyOnWrite()->data = symm_src_var;
 
           // 4. Rebuild the source region call with the modified buffer
           // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0,
           // extent_1, ...]
           Array<PrimExpr> src_region_mins;
           Array<PrimExpr> src_region_extents;
           for (const Range &r : src_range) {
             src_region_mins.push_back(r->min);
             src_region_extents.push_back(r->extent);
           }
           BufferLoad src_load(src, src_region_mins);
 
           Array<PrimExpr> src_region_args;
           src_region_args.push_back(src_load);
+          auto src_region_call_node = call_op->args[0].as<CallNode>();
+          ICHECK(src_region_call_node) << "Expected Call node for src region";
+          auto access_mask_node = src_region_call_node->args[1].as<IntImmNode>();
+          ICHECK(access_mask_node) << "Expected IntImm for access_mask";
           src_region_args.push_back(
-              IntImm(DataType::Int(32), call_op->args[1]
-                                            .as<CallNode>()
-                                            ->args[1]
-                                            .as<IntImmNode>()
-                                            ->value)); // access_mask
+              IntImm(DataType::Int(32), access_mask_node->value)); // access_mask
🤖 Prompt for AI Agents
In src/transform/declare_symm_buffer.cc around lines 163-230, remove the debug
LOG(INFO) and add defensive null/size checks before every use of as<...> and
indexing: validate parsed_op.as<CopyNode>() is non-null, ensure
src->data->type_annotation.as<PointerTypeNode>() is non-null before accessing
storage_scope (provide a safe fallback or return original stmt on failure),
check call_op and call_op->args[0].as<CallNode>() /
call_op->args[1].as<CallNode>() are non-null before accessing their fields, and
verify call_op->args has the expected size before indexing args[1] and iterating
from index 2; if any check fails, bail out cleanly (e.g., return the original
stmt or propagate an error) instead of dereferencing nulls.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LOG(INFO) commented out

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Rachmanino, thanks for commenting out the LOG(INFO)! 👍

Just a gentle reminder that there are still two items from the original feedback that could improve robustness:

  1. Line 176: Adding null checks before accessing storage_scope from the type annotation
  2. Lines 198-202: Validating the nested as<>() calls before dereferencing

These defensive checks would help prevent potential crashes if the IR structure is unexpected. However, if you're confident the IR will always have the expected shape at this point in the pipeline, feel free to defer these checks.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (6)
src/op/copy.cc (1)

1979-1983: Fix the input count in the documentation comment.

The comment states "Takes 8 inputs" but only 7 are listed and the registration correctly uses .set_num_inputs(7). Update the comment to reflect the actual count.

Apply this diff:

-// - Takes 8 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
+// - Takes 7 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
 // eviction_policy, src_pe, dst_pe
src/transform/declare_symm_buffer.cc (5)

9-9: Fix the incorrect header include.

The header <tvm/ffi/reflection/registry.h> does not exist. Use <tvm/runtime/registry.h> instead.

-#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/registry.h>

29-30: Race condition: Move global mutable counter into class.

The global name_suffix_id is incremented without synchronization (lines 111, 179), causing data races when passes run concurrently. Move this counter into SymmBufferDeclarer as a member variable.


34-42: Fix type inconsistencies in pointer arithmetic.

Multiple issues:

  1. Line 35: local_rank is declared but never used
  2. Lines 36-40: Mixing Handle and integer types incorrectly in pointer arithmetic
  3. get_uintptr_t() should work with DataType::UInt(64), not Handle

All pointer arithmetic should be performed on UInt(64) values, then cast back to Handle for the return.


96-163: Add defensive null checks for nested property access.

Lines 108 and 130-135 contain unsafe property accesses that can crash if any cast fails:

  • type_annotation.as<PointerTypeNode>()->storage_scope assumes the type is defined and is a PointerType
  • Deeply nested .as<CallNode>() and .as<IntImmNode>() without checking intermediate results

164-231: Remove debug LOG statement and add defensive null checks.

Line 165: LOG(INFO) should be commented out (other debug logs are commented).
Lines 176-177 and 198-202 contain unsafe property accesses without null checks.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fbb103b and 37ade50.

📒 Files selected for processing (2)
  • src/op/copy.cc (2 hunks)
  • src/transform/declare_symm_buffer.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/transform/declare_symm_buffer.cc (5)
tilelang/language/proxy.py (1)
  • ptr (273-295)
tilelang/language/distributed/common.py (1)
  • get_rank (8-11)
src/tl_templates/cuda/distributed.h (1)
  • get_uintptr_t (47-49)
tilelang/language/ast/ir.py (1)
  • LetStmt (880-908)
tilelang/transform/__init__.py (1)
  • DeclareSymmBuffer (492-498)
src/op/copy.cc (2)
src/op/copy.h (1)
  • is_remote_copy (100-124)
tilelang/language/copy.py (1)
  • copy (11-96)
🪛 Clang (14.0.6)
src/transform/declare_symm_buffer.cc

[error] 9-9: 'tvm/ffi/reflection/registry.h' file not found

(clang-diagnostic-error)

🔇 Additional comments (3)
src/op/copy.cc (2)

158-179: LGTM! Remote copy parameter parsing and validation is correct.

The code properly:

  • Parses src_pe/dst_pe with bounds checking
  • Ensures at least one side is local (prevents both being remote)
  • Validates memory scope constraints for peer access (remote push requires dst in global, remote pull requires src in global)

183-195: LGTM! Helper methods correctly identify remote copy types.

The implementation properly detects:

  • Remote push when dst_pe ≠ -1
  • Remote pull when src_pe ≠ -1
  • Remote copy as the union of both
src/transform/declare_symm_buffer.cc (1)

243-255: LGTM! Pass registration follows standard TVM patterns.

The pass creation and FFI registration are implemented correctly using the standard TVM idioms.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
.github/workflows/ci.yml (1)

108-132: Consider aligning distributed test execution between examples and tests sections.

The examples section (line 119) runs distributed tests with TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" directly, while the tests section now uses a per-test script approach with random ports. This inconsistency may indicate the per-test approach is necessary to prevent test collisions, but the examples section hasn't been updated.

Verify whether the examples section would benefit from the same per-test isolation strategy, or document why it's not needed there.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8d30e4e and 5d6e9c7.

📒 Files selected for processing (1)
  • .github/workflows/ci.yml (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-nvidia

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (2)
.github/workflows/ci.yml (2)

117-125: Critical: Unaddressed issues from previous review—parallel dependency and port collision risk.

This change introduces the same concerns previously flagged on lines 149–157:

  1. Missing parallel dependency check: The workflow assumes GNU parallel is installed on the self-hosted runner without verification. If missing, the script will silently fail or produce cryptic errors.

  2. Port collision risk: Random port selection from [20000, 21000) provides only 1000 ports. With --jobs 4, multiple concurrent tests picking the same random MASTER_PORT can cause conflicts and test failures, especially with multiple distributed test files.

Apply this fix to verify parallel availability and expand the port range to reduce collisions:

        if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
+         if ! command -v parallel &> /dev/null; then
+           echo "Error: GNU parallel not installed. Aborting distributed examples."
+           exit 1
+         fi
          tmp_script=$(mktemp)
          echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:"
          printf '%s\n' "${DIST_TESTS[@]}"
          for test_file in "${DIST_TESTS[@]}"; do
-           RANDOM_PORT=$(( RANDOM % 1000 + 20000 ))  # random port [20000, 21000)
+           RANDOM_PORT=$(( RANDOM % 30000 + 20000 ))  # expanded range [20000, 50000)
            echo "MASTER_PORT=$RANDOM_PORT TILELANG_USE_DISTRIBUTED=1 python -m pytest \"$test_file\" -v -r fE" >> "$tmp_script"
          done
          parallel --jobs 4 < "$tmp_script"
          rm "$tmp_script"

149-157: Critical: Unaddressed issues from previous review—parallel dependency and port collision risk.

The same critical issues flagged on lines 117–125 apply here:

  1. Missing parallel dependency check: No verification that GNU parallel is installed.
  2. Port collision risk: Random [20000, 21000) range with --jobs 4 creates collision risk.

These concerns were raised in the previous review and remain unresolved.

Apply the same fix:

        if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
+         if ! command -v parallel &> /dev/null; then
+           echo "Error: GNU parallel not installed. Aborting distributed tests."
+           exit 1
+         fi
          tmp_script=$(mktemp)
          echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:"
          printf '%s\n' "${DIST_TESTS[@]}"
          for test_file in "${DIST_TESTS[@]}"; do
-           RANDOM_PORT=$(( RANDOM % 1000 + 20000 ))  # random port [20000, 21000)
+           RANDOM_PORT=$(( RANDOM % 30000 + 20000 ))  # expanded range [20000, 50000)
            echo "MASTER_PORT=$RANDOM_PORT TILELANG_USE_DISTRIBUTED=1 python -m pytest \"$test_file\" -v -r fE" >> "$tmp_script"
          done
          parallel --jobs 4 < "$tmp_script"
          rm "$tmp_script"
🧹 Nitpick comments (1)
.github/workflows/ci.yml (1)

108-170: Consider extracting distributed test execution into a reusable helper function.

Lines 117–125 and 149–157 implement the same pattern twice (create script, generate per-test commands, run via parallel, cleanup). This duplicates logic across the "Run examples" and "Run tests" sections.

Extract this into a shell function at the top of the step to reduce duplication and improve maintainability:

    - name: Run examples
      run: |
        source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
        cd examples
        unset PYTHONPATH
+
+       run_distributed_tests() {
+         local test_type="$1"  # "examples" or "tests"
+         mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
+         if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
+           if ! command -v parallel &> /dev/null; then
+             echo "Error: GNU parallel not installed."
+             exit 1
+           fi
+           tmp_script=$(mktemp)
+           echo "Running distributed $test_type with TILELANG_USE_DISTRIBUTED=1:"
+           printf '%s\n' "${DIST_TESTS[@]}"
+           for test_file in "${DIST_TESTS[@]}"; do
+             RANDOM_PORT=$(( RANDOM % 30000 + 20000 ))
+             echo "MASTER_PORT=$RANDOM_PORT TILELANG_USE_DISTRIBUTED=1 python -m pytest \"$test_file\" -v -r fE" >> "$tmp_script"
+           done
+           parallel --jobs 4 < "$tmp_script"
+           rm "$tmp_script"
+         else
+           echo "No distributed $test_type found."
+         fi
+       }

        # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
-       mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
-       if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
-         ...
-       fi
+       run_distributed_tests "examples"

This approach centralizes the logic, simplifies both sections, and makes it easier to fix issues (like the port collision concern) in one place.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5d6e9c7 and 60fb488.

📒 Files selected for processing (1)
  • .github/workflows/ci.yml (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-nvidia

@Rachmanino Rachmanino force-pushed the wt/dev branch 2 times, most recently from 4efa911 to 99501ef Compare November 11, 2025 06:45
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (5)
src/transform/declare_symm_buffer.cc (5)

9-9: Fix the broken registry include.
Line 9 still pulls in <tvm/ffi/reflection/registry.h>, which doesn’t exist in this tree, so the file fails to compile. Please include the runtime registry header instead.

-#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/registry.h>

29-30: Eliminate the global counter race.
The global name_suffix_id is still present and incremented from multiple pass invocations, which races under concurrent compilation. Move the counter into SymmBufferDeclarer and update the call sites to use the member.

-static int name_suffix_id =
-    0; // Avoid name collision for symm buffers, start from 0
+// Avoid name collision for symm buffers, start from 0
-class SymmBufferDeclarer : public StmtExprMutator {
+class SymmBufferDeclarer : public StmtExprMutator {
 public:
@@
 private:
   Map<Var, Buffer> buffer_data_to_buffer_;
+  int name_suffix_id_ = 0;
-          Var symm_dst_var =
-              Var(dst->name + "_symm_" + std::to_string(name_suffix_id++),
+          Var symm_dst_var =
+              Var(dst->name + "_symm_" + std::to_string(name_suffix_id_++),
                   PointerType(PrimType(dst->dtype), storage_scope));
-          Var symm_src_var =
-              Var(src->name + "_symm_" + std::to_string(name_suffix_id++),
+          Var symm_src_var =
+              Var(src->name + "_symm_" + std::to_string(name_suffix_id_++),
                   PointerType(PrimType(src->dtype), storage_scope));

Also applies to: 108-110, 176-178


34-40: Repair pointer arithmetic in CalculateSymmPtr.
This routine still mixes Handle-typed PrimExpr with integer arithmetic, which is undefined after type checking. Convert all operands to UInt(64) via tl::get_uintptr_t before subtracting/adding, then reinterpret the final sum back to a handle.

 PrimExpr CalculateSymmPtr(PrimExpr ptr, PrimExpr pe) {
-  PrimExpr local_base_ptr = Call(DataType::Handle(), tl::get_local_base(), {});
-  PrimExpr offset_to_base =
-      Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {ptr}), local_base_ptr);
-  PrimExpr result = Call(DataType::Handle(), tl::get_remote_base_ptr(), {pe}) +
-                    offset_to_base;
-  return result;
+  PrimExpr local_base_handle = Call(DataType::Handle(), tl::get_local_base(), {});
+  PrimExpr local_base_uint =
+      Call(DataType::UInt(64), tl::get_uintptr_t(), {local_base_handle});
+  PrimExpr ptr_uint = Call(DataType::UInt(64), tl::get_uintptr_t(), {ptr});
+  PrimExpr offset_to_base = Sub(ptr_uint, local_base_uint);
+  PrimExpr remote_base_handle =
+      Call(DataType::Handle(), tl::get_remote_base_ptr(), {pe});
+  PrimExpr remote_base_uint =
+      Call(DataType::UInt(64), tl::get_uintptr_t(), {remote_base_handle});
+  PrimExpr result_uint = Add(remote_base_uint, offset_to_base);
+  return Call(DataType::Handle(), builtin::reinterpret(), {result_uint});
 }

105-143: Add defensive checks before rebuilding the destination region.
This block still assumes the buffer has a pointer type annotation and that call_op->args[1] is a CallNode with an IntImm access mask. Any divergence in the IR crashes the compiler. Guard those assumptions and reuse the validated node when rebuilding the region.

-          String storage_scope =
-              dst->data->type_annotation.as<PointerTypeNode>()->storage_scope;
+          ICHECK(dst->data->type_annotation.defined())
+              << "Buffer data variable must have type annotation";
+          const auto* dst_ptr_type =
+              dst->data->type_annotation.as<PointerTypeNode>();
+          ICHECK(dst_ptr_type)
+              << "Buffer data type annotation must be PointerType";
+          String storage_scope = dst_ptr_type->storage_scope;
@@
-          Array<PrimExpr> dst_region_args;
+          ICHECK_GE(call_op->args.size(), 2U)
+              << "Copy call must provide dst region";
+          const auto* dst_region_call_node = call_op->args[1].as<CallNode>();
+          ICHECK(dst_region_call_node) << "Expected Call node for dst region";
+          Array<PrimExpr> dst_region_args;
           dst_region_args.push_back(dst_load);
-          dst_region_args.push_back(
-              IntImm(DataType::Int(32), call_op->args[1]
-                                            .as<CallNode>()
-                                            ->args[1]
-                                            .as<IntImmNode>()
-                                            ->value)); // access_mask
+          const auto* dst_access_mask =
+              dst_region_call_node->args[1].as<IntImmNode>();
+          ICHECK(dst_access_mask) << "Expected IntImm access mask for dst region";
+          dst_region_args.push_back(
+              IntImm(DataType::Int(32), dst_access_mask->value)); // access_mask
@@
-          Call dst_region_call =
-              Call(call_op->args[1].as<CallNode>()->dtype,
-                   call_op->args[1].as<CallNode>()->op, dst_region_args,
-                   call_op->args[1].as<CallNode>()->span);
+          Call dst_region_call =
+              Call(dst_region_call_node->dtype, dst_region_call_node->op,
+                   dst_region_args, dst_region_call_node->span);

174-211: Do the same safety checks for the source region.
The remote pull branch still dereferences unchecked pointer annotations and assumes call_op->args[0] yields a CallNode(IntImm). Add the corresponding ICHECKs before accessing those members.

-          String storage_scope =
-              src->data->type_annotation.as<PointerTypeNode>()->storage_scope;
+          ICHECK(src->data->type_annotation.defined())
+              << "Buffer data variable must have type annotation";
+          const auto* src_ptr_type =
+              src->data->type_annotation.as<PointerTypeNode>();
+          ICHECK(src_ptr_type)
+              << "Buffer data type annotation must be PointerType";
+          String storage_scope = src_ptr_type->storage_scope;
@@
-          Array<PrimExpr> src_region_args;
+          ICHECK_GE(call_op->args.size(), 2U)
+              << "Copy call must provide src region";
+          const auto* src_region_call_node = call_op->args[0].as<CallNode>();
+          ICHECK(src_region_call_node) << "Expected Call node for src region";
+          Array<PrimExpr> src_region_args;
           src_region_args.push_back(src_load);
-          src_region_args.push_back(
-              IntImm(DataType::Int(32), call_op->args[0]
-                                            .as<CallNode>()
-                                            ->args[1]
-                                            .as<IntImmNode>()
-                                            ->value)); // access_mask
+          const auto* src_access_mask =
+              src_region_call_node->args[1].as<IntImmNode>();
+          ICHECK(src_access_mask) << "Expected IntImm access mask for src region";
+          src_region_args.push_back(
+              IntImm(DataType::Int(32), src_access_mask->value)); // access_mask
@@
-          Call src_region_call =
-              Call(call_op->args[0].as<CallNode>()->dtype,
-                   call_op->args[0].as<CallNode>()->op, src_region_args,
-                   call_op->args[0].as<CallNode>()->span);
+          Call src_region_call =
+              Call(src_region_call_node->dtype, src_region_call_node->op,
+                   src_region_args, src_region_call_node->span);
🧹 Nitpick comments (4)
src/transform/lower_hopper_intrin.cc (1)

66-66: Consider removing unnecessary blank line.

This blank line appears to be unnecessary and could be removed for consistency.

tilelang/language/__init__.py (1)

19-19: Remove unnecessary noqa directive.

The # noqa: F401 comment is flagged as unused by Ruff. Since make_tensor_like is being explicitly exported in __all__ (implicitly via the module's public API), the import is not unused and the directive can be removed.

Apply this diff:

-    make_tensor_like,  # noqa: F401
+    make_tensor_like,
tilelang/language/copy.py (1)

25-26: Clarify PE parameter semantics in docstring.

The docstring states "Defaults to -1, which means copy from/to local" but doesn't explain what non--1 values represent. Consider adding a brief note that non--1 values specify a remote PE index.

-        src_pe (Optional[tir.PrimExpr], optional): Source PE index. Defaults to -1, which means copy from local
-        dst_pe (Optional[tir.PrimExpr], optional): Destination PE index. Defaults to -1, which means copy to local.
+        src_pe (Optional[tir.PrimExpr], optional): Source PE index. Defaults to -1 (local), otherwise specifies remote PE.
+        dst_pe (Optional[tir.PrimExpr], optional): Destination PE index. Defaults to -1 (local), otherwise specifies remote PE.
src/op/copy.cc (1)

1979-1984: Keep the tl.copy registration comment accurate.

The op now exposes seven inputs, so the comment claiming eight (with is_remote_copy) is stale. Updating it avoids the next reader scratching their head.

Apply this diff to synchronize the documentation:

-// - Takes 8 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
-// eviction_policy, src_pe, dst_pe, is_remote_copy
+// - Takes 7 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
+//   eviction_policy, src_pe, dst_pe
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 99501ef and a510e74.

📒 Files selected for processing (21)
  • examples/distributed/primitives/example_tilescale_copy.py (1 hunks)
  • examples/distributed/primitives/test_tilescale_copy.py (1 hunks)
  • src/op/copy.cc (2 hunks)
  • src/op/copy.h (1 hunks)
  • src/op/distributed.cc (1 hunks)
  • src/op/distributed.h (1 hunks)
  • src/op/remote_copy.cc (2 hunks)
  • src/op/sync.cc (1 hunks)
  • src/target/codegen_cuda.cc (2 hunks)
  • src/tl_templates/cuda/common.h (1 hunks)
  • src/tl_templates/cuda/distributed.h (1 hunks)
  • src/tl_templates/cuda/sync.h (1 hunks)
  • src/transform/declare_symm_buffer.cc (1 hunks)
  • src/transform/lower_hopper_intrin.cc (4 hunks)
  • tilelang/engine/phase.py (1 hunks)
  • tilelang/jit/adapter/utils.py (1 hunks)
  • tilelang/jit/adapter/wrapper.py (5 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/copy.py (2 hunks)
  • tilelang/language/proxy.py (1 hunks)
  • tilelang/transform/__init__.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
  • tilelang/transform/init.py
  • tilelang/jit/adapter/utils.py
  • tilelang/engine/phase.py
  • tilelang/language/proxy.py
  • src/tl_templates/cuda/distributed.h
🧰 Additional context used
🧬 Code graph analysis (11)
src/op/copy.h (1)
src/op/copy.cc (6)
  • is_remote_copy (193-195)
  • is_remote_copy (193-193)
  • is_remote_push (183-186)
  • is_remote_push (183-183)
  • is_remote_pull (188-191)
  • is_remote_pull (188-188)
src/op/copy.cc (2)
src/op/copy.h (1)
  • is_remote_copy (100-124)
tilelang/language/copy.py (1)
  • copy (11-96)
src/op/remote_copy.cc (1)
src/tl_templates/cuda/distributed.h (1)
  • get_uintptr_t (53-55)
src/op/distributed.h (1)
tilelang/language/distributed/common.py (2)
  • get_rank (8-11)
  • get_num_ranks (14-17)
src/transform/declare_symm_buffer.cc (2)
src/tl_templates/cuda/distributed.h (1)
  • get_uintptr_t (53-55)
tilelang/transform/__init__.py (1)
  • DeclareSymmBuffer (492-498)
examples/distributed/primitives/test_tilescale_copy.py (1)
tilelang/testing/__init__.py (1)
  • requires_cuda_compute_version_ge (105-106)
tilelang/language/copy.py (1)
tilelang/language/tir/op.py (1)
  • call_intrin (120-145)
examples/distributed/primitives/example_tilescale_copy.py (8)
tilelang/distributed/utils.py (1)
  • init_dist (40-62)
tilelang/jit/__init__.py (1)
  • jit (233-306)
tilelang/language/allocate.py (1)
  • alloc_shared (24-39)
tilelang/language/distributed/common.py (1)
  • get_rank (8-11)
tilelang/language/copy.py (1)
  • copy (11-96)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/utils/allocator.py (1)
  • get_allocator (236-248)
tilelang/utils/tensor.py (1)
  • tensor (43-56)
src/transform/lower_hopper_intrin.cc (3)
tilelang/language/builtin.py (1)
  • create_tma_descriptor (74-83)
src/op/builtin.cc (2)
  • cuTensorMapType (39-39)
  • cuTensorMapType (39-39)
tilelang/language/tir/op.py (1)
  • call_extern (173-195)
tilelang/jit/adapter/wrapper.py (1)
tilelang/jit/adapter/utils.py (2)
  • pythonic_expr (110-218)
  • tilescale_pythonic_expr (221-352)
tilelang/language/__init__.py (1)
tilelang/language/proxy.py (1)
  • make_tensor_like (305-313)
🪛 Clang (14.0.6)
src/transform/declare_symm_buffer.cc

[error] 9-9: 'tvm/ffi/reflection/registry.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.4)
examples/distributed/primitives/example_tilescale_copy.py

179-179: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/__init__.py

19-19: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-nvidia
🔇 Additional comments (20)
src/transform/lower_hopper_intrin.cc (3)

117-131: LGTM! Visitor pattern correctly implemented.

The LetStmt visitor correctly visits the value and body first, then checks if the variable is TMA-related and stores the binding. The structure preservation logic (lines 126-130) follows best practices.


187-189: LGTM! Appropriate data structures for TMA tracking.

The new data members use appropriate hash and equality functors (ObjectPtrHash, ObjectPtrEqual) for Var objects, which is correct for TVM IR nodes.


42-56: No changes required — the inlining is safe by design.

The inlining only replaces base pointer variables with their let-bound expressions. Base pointers—being memory addresses—are semantically constrained to be function parameters or early allocations, never depending on later-defined local variables. The code's narrow focus on only base pointers (stored in tma_related_vars_) ensures the hoisting respects scope.

Likely an incorrect or invalid review comment.

tilelang/jit/adapter/wrapper.py (5)

9-10: LGTM! Import aligns with distributed copy support.

The addition of tilescale_pythonic_expr correctly enables parsing of distributed intrinsics (e.g., tl.get_remote_base, tl.get_local_base) for TileScale-aware expression translation in TMA descriptor generation.


288-289: LGTM! Method follows established pattern.

The new helper method correctly wraps tilescale_pythonic_expr with type mapping, following the same pattern as _pythonic_expr above it.


466-466: LGTM! Correct use of tilescale-aware expression translator.

Switching to _tilescale_pythonic_expr for globalAddress properly handles distributed intrinsics (e.g., tl.get_remote_base, tl.get_local_base) that may appear in TMA descriptor global addresses for distributed copy scenarios, directly supporting the PR's unified copy lowering objectives.


65-65: No changes needed.

host_meta_data is properly declared at file scope in the generated C++ code. The declaration uint64_t* host_meta_data = nullptr; is generated in src/target/codegen_cuda.cc (line 301), and it's also extern-declared in the CUDA template headers (src/tl_templates/cuda/distributed.h, line 9). The assignment in the wrapper function will not cause compilation errors.


333-338: Pattern verification confirms no false positives—change is sound.

The symmetric buffer naming convention in src/transform/declare_symm_buffer.cc creates descriptors with the pattern {buffer_name}_symm_{numeric_id}, where the numeric ID is globally incremented to prevent collisions. The added condition (match.startswith(name + "_symm_") and match.endswith("_desc")) correctly identifies these TMA copy descriptors without risk of false positives, as the naming scheme ensures uniqueness even when buffer names have similar prefixes.

tilelang/language/copy.py (2)

11-19: LGTM! Remote copy parameters integrated correctly.

The function signature correctly extends the copy interface with src_pe and dst_pe parameters for distributed copy support. The type hints appropriately allow PrimExpr, IntImm, int, or None with a default of -1 to indicate local rank.


95-96: Remote copy parameters are correctly registered and passed.

Verification confirms the backend intrinsic registration expects 7 arguments. The tilelang/language/copy.py lines 95-96 pass all 7 arguments in the correct order expected by the Copy constructor in src/op/copy.cc:

  1. src (buffer/region)
  2. dst (buffer/region)
  3. coalesced_width
  4. disable_tma
  5. eviction_policy
  6. src_pe (new remote copy parameter)
  7. dst_pe (new remote copy parameter)

The backend registration at src/op/copy.cc confirms .set_num_inputs(7), matching the implementation. The code is correct.

src/tl_templates/cuda/common.h (1)

35-38: LGTM! Host qualifier macros added correctly.

The new TL_HOST* and TL_HOST_DEVICE* macros properly extend the device-only qualifiers to support host-side and dual host-device contexts, enabling distributed intrinsics to work on both host and device.

src/op/remote_copy.cc (2)

99-109: LGTM! Remote base intrinsic updated in PUT operation.

The PutOpNode::Lower method now correctly uses tl::get_remote_base() for both local and remote base pointer computation, maintaining the offset calculation logic.


205-215: LGTM! Remote base intrinsic updated in GET operation.

The GetOpNode::Lower method now correctly uses tl::get_remote_base() for both local and remote base pointer computation, consistent with the PUT operation changes.

examples/distributed/primitives/test_tilescale_copy.py (5)

1-7: LGTM! Test module setup is correct.

The imports and module structure properly set up the test environment with required dependencies for distributed multiprocessing tests.


9-13: LGTM! SIMT test properly configured.

The test correctly uses the requires_cuda decorator and spawns 2 processes for distributed testing of the SIMT push tile kernel.


15-20: LGTM! TMA load test has correct prerequisites.

The test properly stacks requires_cuda and requires_cuda_compute_version_ge(9, 0) decorators since TMA features require compute capability 9.0+.


22-27: LGTM! TMA store test has correct prerequisites.

The test properly stacks decorators for compute capability requirements, matching the TMA load test configuration.


29-30: LGTM! Test entry point configured correctly.

The main entry delegates to tilelang.testing.main() for standard test runner integration.

src/op/copy.h (1)

95-102: Verification confirmed — implementations are correct and properly integrated.

The three query methods are implemented soundly:

  • is_remote_push(): Checks if destination PE is remote (dst_pe ≠ -1)
  • is_remote_pull(): Checks if source PE is remote (src_pe ≠ -1)
  • is_remote_copy(): Correctly identifies any remote operation with OR logic

Type safety is ensured via IsInstance<IntImmNode>() checks before casting, and the sentinel value (-1) consistently distinguishes local from remote processing elements. The implementation properly supports distributed copy operations without impacting local copy behavior.

src/tl_templates/cuda/sync.h (1)

157-158: LGTM! Barrier pointer macro updated consistently.

The BARRIER_PTR macro now uses get_remote_base(tgt_rank) instead of get_remote_base_ptr(tgt_rank), aligning with the intrinsic renaming across the codebase. Verification confirms get_remote_base is properly defined in src/tl_templates/cuda/distributed.h and the change is consistent throughout the CUDA templates.

Comment on lines +159 to +195
// Parse remote copy params
if (args.size() >= 6) {
node->src_pe = args[5];
}
if (args.size() >= 7) {
node->dst_pe = args[6];
}

ICHECK(!(node->is_remote_push() && node->is_remote_pull()))
<< "At least one of src_pe or dst_pe must be local rank";

if (node->is_remote_push()) {
ICHECK(node->dst.scope() == "global")
<< "Can only copy to peer's global memory, but got "
<< node->dst.scope();
} else if (node->is_remote_pull()) {
ICHECK(node->src.scope() == "global")
<< "Can only pull from peer's global memory, but got "
<< node->src.scope();
}

data_ = std::move(node);
}

bool CopyNode::is_remote_push() const {
return !(dst_pe->IsInstance<IntImmNode>() &&
dst_pe.as<IntImmNode>()->value == -1);
}

bool CopyNode::is_remote_pull() const {
return !(src_pe->IsInstance<IntImmNode>() &&
src_pe.as<IntImmNode>()->value == -1);
}

bool CopyNode::is_remote_copy() const {
return is_remote_push() || is_remote_pull();
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Guard remote PE defaults before remote checks.

is_remote_push() / is_remote_pull() run immediately after the optional assignments. Legacy modules compiled before this PR still emit the 5-argument tl.copy; when we deserialize them here, args.size() is only 5, so dst_pe/src_pe stay PrimExpr() and the subsequent dst_pe->IsInstance dereferences a null handle. That’s a hard crash/regression for any pre-existing artifact. Please seed both fields with -1 before the if blocks and make the helpers resilient to an undefined PrimExpr.

Apply this diff to fix the issue:

Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<CopyNode> node = make_object<CopyNode>();
+  node->src_pe = Integer(-1);
+  node->dst_pe = Integer(-1);
   Array<Range> rgs[2];
@@
-  if (args.size() >= 6) {
-    node->src_pe = args[5];
-  }
-  if (args.size() >= 7) {
-    node->dst_pe = args[6];
-  }
+  if (args.size() >= 6) {
+    node->src_pe = args[5];
+  }
+  if (args.size() >= 7) {
+    node->dst_pe = args[6];
+  }
@@
 bool CopyNode::is_remote_push() const {
-  return !(dst_pe->IsInstance<IntImmNode>() &&
-           dst_pe.as<IntImmNode>()->value == -1);
+  if (!dst_pe.defined()) {
+    return false;
+  }
+  if (const auto *imm = dst_pe.as<IntImmNode>()) {
+    return imm->value != -1;
+  }
+  return true;
 }
@@
 bool CopyNode::is_remote_pull() const {
-  return !(src_pe->IsInstance<IntImmNode>() &&
-           src_pe.as<IntImmNode>()->value == -1);
+  if (!src_pe.defined()) {
+    return false;
+  }
+  if (const auto *imm = src_pe.as<IntImmNode>()) {
+    return imm->value != -1;
+  }
+  return true;
 }

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines 96 to 100
PrimExpr local_base_ptr =
Call(DataType::Handle(), tl::get_remote_base_ptr(), {rank});
Call(DataType::Handle(), tl::get_remote_base(), {rank});
PrimExpr offset_to_base =
Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {bar_addr}),
local_base_ptr);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify get_remote_base intrinsic compatibility.

The barrier code now calls tl::get_remote_base() instead of tl::get_remote_base_ptr(). Ensure the new intrinsic returns the same DataType::Handle() type and accepts a rank parameter.

Run the following script to verify the intrinsic definition:


🏁 Script executed:

#!/bin/bash
# Find the definition of get_remote_base intrinsic
ast-grep --pattern $'const Op& get_remote_base() {
  $$$
}'

# Also search for its registration
rg -n 'get_remote_base' --type=cpp -C3

Length of output: 6899


I need to gather more information about the IR-level type specifications for these intrinsics to determine if the type mismatch is intentional or a bug.

[/web_search]

[/run_scripts]


🏁 Script executed:

#!/bin/bash
# Search for Op definitions and their return type specifications
rg -n 'get_local_base\|get_remote_base' src/op/distributed.cc -A10 -B2

# Check if DataType::Handle() is commonly used with uint64_t returns
rg -n 'DataType::Handle().*get_remote' src/ -C2

Length of output: 2191


🌐 Web query:

TVM TIR DataType Handle uint64_t implicit conversion

💡 Result:

Short answer: TVM does not implicitly convert a uint64_t integer into a TIR DataType::Handle. DataType::Handle is a pointer/opaque "handle" type (typically 64-bit) separate from integer types, and the C API's TVMValue union does not include a uint64 field (it has int64, double, and void*), so uint64 values are not represented natively in TVMValue and won't be auto-treated as a handle — you must explicitly construct a handle (pointer) or use an integer DataType (uint64) where appropriate. [1][2]

References:

  • TVM runtime DataType documentation (Handle / constructors). [1]
  • Discussion on TVMValue not supporting uint64 (explains TVMValue fields and lack of uint64). [2]

[/run_scripts]


🏁 Script executed:

#!/bin/bash
# Search for complete Op registrations with type info
rg -n 'TIR_DEFINE_TL_BUILTIN.*get_' src/op/distributed.cc -A15 | head -80

# Also check how get_uintptr_t and get_local_base are registered
rg -n 'get_uintptr_t\|get_local_base.*TIR_DEFINE' src/op/distributed.cc -B2 -A10

Length of output: 1388


[/run_scripts]


🏁 Script executed:

#!/bin/bash
# Check the actual usage context - see if there are type conversions nearby
rg -n 'get_remote_base\(\)' src/op/sync.cc src/op/remote_copy.cc -B5 -A10

# Also check if there's a Cast op being used
rg -n 'Cast\|reinterpret_cast' src/op/sync.cc src/op/remote_copy.cc | head -30

Length of output: 3715


Critical type mismatch: get_remote_base() returns uint64_t, not Handle.

The code changes from tl::get_remote_base_ptr() (returns void*) to tl::get_remote_base() (returns uint64_t), but incorrectly asserts DataType::Handle() at the call sites. Per TVM's type system, uint64_t does not implicitly convert to Handle.

Fix required at 5 locations:

  • src/op/sync.cc line 97
  • src/op/remote_copy.cc lines 103, 108, 209, 214

Either revert to get_remote_base_ptr() or use DataType::UInt(64) and add explicit Cast to Handle if a pointer is semantically required.

🤖 Prompt for AI Agents
In src/op/sync.cc around lines 96-100 the call to tl::get_remote_base() is
incorrectly typed as DataType::Handle; get_remote_base() returns uint64_t so
replace the Handle typing with DataType::UInt(64) (or cast the uint64_t to
Handle explicitly if you truly need a pointer) and wrap with an explicit Cast
where needed; apply the same fix at src/op/remote_copy.cc lines 103, 108, 209
and 214: change the Call type to DataType::UInt(64) and if subsequent code
expects a pointer, add a Cast(expr, DataType::Handle()) after the Call rather
than leaving the Call typed as Handle.

Comment on lines +134 to +152
if (call->op.same_as(create_tma_descriptor())) {
Var var;
auto iter = desc_map_.find(GetRef<Call>(call));
if (iter != desc_map_.end()) {
var = iter->second;
} else {
String name = call->args[2].as<Var>().value()->name_hint;
var = Var(name + "_desc",
PointerType(PrimType(cuTensorMapType()), "grid_constant"));
desc_map_[GetRef<Call>(call)] = var;
prefetch_calls_.push_back(
Evaluate(Call(DataType::Handle(), builtin::call_extern(),
{StringImm("tl::prefetch_tma_descriptor"), var})));
// Mark the base pointer variable as TMA-related
if (auto base_var = call->args[2].as<Var>()) {
tma_related_vars_.insert(base_var.value());
}
}
return var;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix unsafe variable access at line 140.

Line 140 accesses call->args[2].as<Var>().value()->name_hint without checking if args[2] is actually a Var, which will crash if it's not. This is inconsistent with the safe check at lines 148-150.

Apply this diff to add a safety check:

     } else {
-      String name = call->args[2].as<Var>().value()->name_hint;
-      var = Var(name + "_desc",
-                PointerType(PrimType(cuTensorMapType()), "grid_constant"));
+      auto base_var_opt = call->args[2].as<Var>();
+      CHECK(base_var_opt) << "Expected args[2] to be a Var for create_tma_descriptor";
+      String name = base_var_opt.value()->name_hint;
+      var = Var(name + "_desc",
+                PointerType(PrimType(cuTensorMapType()), "grid_constant"));
       desc_map_[GetRef<Call>(call)] = var;
🤖 Prompt for AI Agents
In src/transform/lower_hopper_intrin.cc around lines 134 to 152, the code
unsafely dereferences call->args[2] as a Var when building the descriptor name;
first check whether call->args[2].as<Var>() is valid and only access name_hint
if it is, otherwise use a safe fallback name (e.g., "tma_desc") or derive a
unique name differently; reuse the same optional base_var for both name
construction and for inserting into tma_related_vars_ so the insertion is only
performed when base_var is valid; update the control flow so desc_map_ and
prefetch_calls_ still get created in the else branch but without unsafe
dereference.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a510e74 and a536bdd.

📒 Files selected for processing (1)
  • src/tl_templates/cuda/threadblock_swizzle.h (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tl_templates/cuda/threadblock_swizzle.h (1)
src/tl_templates/hip/threadblock_swizzle.h (1)
  • rasterization2DRow (7-8)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check
🔇 Additional comments (1)
src/tl_templates/cuda/threadblock_swizzle.h (1)

26-27: LGTM! Consistent default parameter addition.

The default argument for offset matches the rasterization2DRow function, providing consistent API across both rasterization functions.

namespace tl {

template <int panel_width, int offset> TL_DEVICE dim3 rasterization2DRow() {
template <int panel_width, int offset = 0> TL_DEVICE dim3 rasterization2DRow() {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

LGTM! Improved API usability with default parameter.

The default argument makes the offset parameter optional, improving usability for common cases where no offset is needed. The change is backward compatible.

Verify consistency with the HIP backend version to ensure both implementations have aligned signatures:


🏁 Script executed:

#!/bin/bash
# Description: Compare template signatures between CUDA and HIP versions

echo "=== CUDA version ==="
rg -n "template.*rasterization2D" src/tl_templates/cuda/threadblock_swizzle.h -A 1

echo -e "\n=== HIP version ==="
rg -n "template.*rasterization2D" src/tl_templates/hip/threadblock_swizzle.h -A 1

Length of output: 470


🏁 Script executed:

#!/bin/bash
# Get complete function signatures for both functions in CUDA and HIP

echo "=== CUDA: rasterization2DRow (full signature) ==="
sed -n '7,9p' src/tl_templates/cuda/threadblock_swizzle.h

echo -e "\n=== CUDA: rasterization2DColumn (full signature) ==="
sed -n '26,28p' src/tl_templates/cuda/threadblock_swizzle.h

echo -e "\n=== HIP: rasterization2DRow (full signature) ==="
sed -n '7,9p' src/tl_templates/hip/threadblock_swizzle.h

echo -e "\n=== HIP: rasterization2DColumn (full signature) ==="
sed -n '26,28p' src/tl_templates/hip/threadblock_swizzle.h

Length of output: 968


🏁 Script executed:

#!/bin/bash
# Check if offset parameter is used in HIP implementations

echo "=== HIP: Check for 'offset' usage in rasterization functions ==="
sed -n '7,45p' src/tl_templates/hip/threadblock_swizzle.h | grep -n "offset"

echo -e "\n=== Search for calls to rasterization2DRow/2DColumn with template arguments ==="
rg -n "rasterization2D(Row|Column)\s*<" --type=cc --type=h

Length of output: 745


🏁 Script executed:

#!/bin/bash
# View full CUDA implementations to see offset usage

echo "=== CUDA rasterization2DRow implementation ==="
sed -n '7,24p' src/tl_templates/cuda/threadblock_swizzle.h

echo -e "\n=== CUDA rasterization2DColumn implementation ==="
sed -n '26,45p' src/tl_templates/cuda/threadblock_swizzle.h

echo -e "\n=== Search for function calls to these templates ==="
rg -n "rasterization2D(Row|Column)" --type=cpp --type=hpp --type=h -C 2

Length of output: 2071


🏁 Script executed:

#!/bin/bash
# Broader search for function calls without file type restrictions

echo "=== Search for rasterization2DRow/Column calls ==="
rg -n "rasterization2D" -g "*.cu" -g "*.cuh" -g "*.h" -C 1

echo -e "\n=== Search with simpler pattern ==="
rg -n "rasterization2D(Row|Column)" -C 1

Length of output: 3189


🏁 Script executed:

#!/bin/bash
# Check git history for recent changes to HIP version

echo "=== Recent commits touching HIP version ==="
git log --oneline -n 20 -- src/tl_templates/hip/threadblock_swizzle.h

echo -e "\n=== Recent commits touching CUDA version ==="
git log --oneline -n 20 -- src/tl_templates/cuda/threadblock_swizzle.h

echo -e "\n=== Check if HIP version was modified in this PR ==="
git diff HEAD~1 -- src/tl_templates/hip/threadblock_swizzle.h

Length of output: 717


🏁 Script executed:

#!/bin/bash
# View the commit that introduced offset to understand original intent

echo "=== Commit 4b38278 details ==="
git show 4b38278 --stat

echo -e "\n=== Check if that commit touched HIP file ==="
git show 4b38278 -- src/tl_templates/hip/threadblock_swizzle.h

Length of output: 710


Backend signature mismatch: HIP implementation not updated.

CUDA now has template <int panel_width, int offset = 0> while HIP remains template <int panel_width>, creating incompatible signatures between backends. The offset parameter introduced in commit 4b38278 was added only to CUDA and not propagated to HIP.

Either update HIP implementations to include the offset parameter with the same default value, or reconsider adding it to CUDA if HIP-specific logic prevents it.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/threadblock_swizzle.h around line 7 the CUDA template
signature was changed to "template <int panel_width, int offset = 0>" while HIP
implementations still use "template <int panel_width>", creating a backend
signature mismatch; update the HIP header(s)/implementations to add the same
"int offset = 0" template parameter to the corresponding template
declarations/definitions so signatures match (ensure default value is 0 and
propagate the new parameter into any HIP-specific uses), or if HIP cannot accept
the parameter, revert the CUDA change so both backends share the same signature.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants