[Feature] Support unified T.copy lowering to both SIMT and TMA for intra-node copy
#36
base: main
Conversation
Walkthrough

Adds remote-PE support and symmetric-buffer wiring for distributed TileLang copies: new tilescale intrinsics (get_remote_base/get_local_base), host-visible metadata, compiler passes to declare symm buffers, JIT/descriptor updates, CUDA codegen changes, and an example/tests demonstrating remote tile-scale copy.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as User/Main
    participant Lang as Language API
    participant Transform as Transform Passes
    participant Codegen as CUDA Codegen
    participant Runtime as Runtime (host/device)
    User->>Lang: call copy(src,dst,src_pe,dst_pe)
    Lang->>Lang: emit tl.copy intrinsic with src_pe/dst_pe
    Lang->>Transform: LowerAndLegalize
    Transform->>Transform: DeclareSymmBuffer detects remote push/pull
    alt remote_push
        Transform->>Transform: compute symm_ptr via get_remote_base(dst_pe)
        Transform->>Transform: wrap copy with LetStmt(symm_ptr)
    else remote_pull
        Transform->>Transform: compute symm_ptr via get_remote_base(src_pe)
        Transform->>Transform: wrap copy with LetStmt(symm_ptr)
    end
    Transform->>Codegen: transformed IR
    Codegen->>Codegen: emit tl::get_remote_base / get_local_base calls
    Codegen->>Runtime: produced CUDA + host init (host_meta_data)
    Runtime->>Runtime: host/device branch returns meta_data or host_meta_data
```
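The push/pull branch in the diagram can be mimicked with a short Python sketch (an illustrative mock of the pass's decision logic, not TileLang code; `LOCAL`, `classify_copy`, and the sentinel convention are assumptions based on the summary above):

```python
LOCAL = -1  # sentinel meaning "this endpoint is the local rank"

def classify_copy(src_pe: int, dst_pe: int) -> str:
    """Mock of the DeclareSymmBuffer branch: a copy whose dst lives on
    a remote PE is a push; one whose src is remote is a pull."""
    if dst_pe != LOCAL and src_pe != LOCAL:
        raise ValueError("at least one endpoint must be local")
    if dst_pe != LOCAL:
        return "remote_push"   # symm_ptr = get_remote_base(dst_pe)
    if src_pe != LOCAL:
        return "remote_pull"   # symm_ptr = get_remote_base(src_pe)
    return "local"

kinds = [classify_copy(-1, 1), classify_copy(0, -1), classify_copy(-1, -1)]
```

Both remote cases end up wrapped in a `LetStmt` binding the symmetric pointer, which is why only the classification differs between the `alt` branches.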
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Force-pushed: 0874b4f to 7aaa99e
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/transform/lower_hopper_intrin.cc (1)
154-166: Also mark the im2col base pointer as TMA-related.

We only record the base pointer var for `create_tma_descriptor`. When the same pattern occurs with `create_tma_im2col_descriptor`, the hoisted `tvm_call_packed` still references the Let-bound var, so the var is out of scope and the descriptor init sees a dangling symbol. Please add the same `tma_related_vars_` bookkeeping in the im2col branch so its base pointer gets inlined as well.

```diff
 } else if (call->op.same_as(create_tma_im2col_descriptor())) {
@@
-  prefetch_calls_.push_back(
-      Evaluate(Call(DataType::Handle(), builtin::call_extern(),
-                    {StringImm("tl::prefetch_tma_descriptor"), var})));
+  prefetch_calls_.push_back(
+      Evaluate(Call(DataType::Handle(), builtin::call_extern(),
+                    {StringImm("tl::prefetch_tma_descriptor"), var})));
+  if (auto base_var = call->args[2].as<Var>()) {
+    tma_related_vars_.insert(base_var.value());
+  }
 }
```

src/op/copy.cc (1)
133-176: Initialize remote PE defaults before calling is_remote_*.

If any existing `tl.copy` call site still uses the legacy 5-argument form, `args.size()` stays ≤ 5, so `src_pe`/`dst_pe` remain undefined. The subsequent `is_remote_push()`/`is_remote_pull()` immediately dereference those unset `PrimExpr`s, triggering an `ICHECK` failure. Defaulting them to `Integer(-1)` (and then overriding when extra args are present) keeps the legacy path safe.

```diff
 ObjectPtr<CopyNode> node = make_object<CopyNode>();
+node->src_pe = Integer(-1);
+node->dst_pe = Integer(-1);
 Array<Range> rgs[2];
 Buffer bf[2];
```
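The defaulting pattern above can be sketched in plain Python (a mock of the argument-count handling, not the actual TVM parser; all names are illustrative):

```python
def parse_copy_args(args):
    """Mock of tl.copy argument parsing: legacy calls pass 5 args,
    extended calls append src_pe and dst_pe. Default both to the
    local sentinel -1 so the predicates never see an unset value."""
    src_pe, dst_pe = -1, -1          # safe defaults for the legacy path
    if len(args) > 6:                # extended form carries two extra args
        src_pe, dst_pe = args[5], args[6]
    return {"src_pe": src_pe, "dst_pe": dst_pe,
            "is_remote": src_pe != -1 or dst_pe != -1}

legacy = parse_copy_args(["src", "dst", 0, 0, 0])
extended = parse_copy_args(["src", "dst", 0, 0, 0, -1, 1])
```

With the defaults in place, the legacy call is classified as a purely local copy instead of tripping an assertion.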
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (17)
- examples/distributed/primitives/example_tilescale_copy.py (1 hunks)
- src/op/copy.cc (2 hunks)
- src/op/copy.h (1 hunks)
- src/op/distributed.cc (1 hunks)
- src/op/distributed.h (1 hunks)
- src/target/codegen_cuda.cc (2 hunks)
- src/tl_templates/cuda/common.h (1 hunks)
- src/tl_templates/cuda/distributed.h (1 hunks)
- src/transform/declare_symm_buffer.cc (1 hunks)
- src/transform/lower_hopper_intrin.cc (4 hunks)
- tilelang/engine/phase.py (1 hunks)
- tilelang/jit/adapter/utils.py (1 hunks)
- tilelang/jit/adapter/wrapper.py (5 hunks)
- tilelang/language/__init__.py (1 hunks)
- tilelang/language/copy.py (2 hunks)
- tilelang/language/proxy.py (2 hunks)
- tilelang/transform/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (12)
tilelang/engine/phase.py (2)
- src/transform/declare_symm_buffer.cc (2): DeclareSymmBuffer(220-227), DeclareSymmBuffer(220-220)
- tilelang/transform/__init__.py (1): DeclareSymmBuffer(492-498)

src/transform/declare_symm_buffer.cc (5)
- tilelang/language/proxy.py (1): ptr(273-295)
- tilelang/language/distributed/common.py (1): get_rank(8-11)
- src/tl_templates/cuda/distributed.h (1): get_uintptr_t(47-49)
- tilelang/language/ast/ir.py (1): LetStmt(880-908)
- tilelang/transform/__init__.py (1): DeclareSymmBuffer(492-498)

examples/distributed/primitives/example_tilescale_copy.py (8)
- src/tl_templates/cuda/reduce.h (1): T(208-280)
- tilelang/distributed/utils.py (1): init_dist(40-62)
- tilelang/language/distributed/common.py (1): get_rank(8-11)
- tilelang/language/copy.py (1): copy(11-98)
- tilelang/layout/swizzle.py (1): make_swizzled_layout(10-18)
- tilelang/utils/allocator.py (1): get_allocator(236-248)
- tilelang/jit/kernel.py (1): initialize(407-416)
- tilelang/utils/tensor.py (1): tensor(43-56)

src/op/copy.cc (2)
- src/op/copy.h (1): is_remote_copy(100-124)
- tilelang/language/copy.py (1): copy(11-98)

src/tl_templates/cuda/distributed.h (2)
- src/op/distributed.h (1): tl(11-238)
- tilelang/language/distributed/common.py (2): get_rank(8-11), get_num_ranks(14-17)

tilelang/language/__init__.py (1)
- tilelang/language/proxy.py (1): make_tensor_like(305-313)

src/op/copy.h (1)
- src/op/copy.cc (6): is_remote_copy(190-192), is_remote_copy(190-190), is_remote_push(180-183), is_remote_push(180-180), is_remote_pull(185-188), is_remote_pull(185-185)

tilelang/transform/__init__.py (1)
- src/transform/declare_symm_buffer.cc (2): DeclareSymmBuffer(220-227), DeclareSymmBuffer(220-220)

tilelang/language/copy.py (2)
- tilelang/utils/language.py (1): get_buffer_region_from_load(124-146)
- tilelang/language/tir/op.py (1): call_intrin(120-145)

tilelang/jit/adapter/utils.py (1)
- src/op/distributed.h (1): tvm(10-239)

tilelang/jit/adapter/wrapper.py (2)
- tilelang/jit/adapter/utils.py (2): pythonic_expr(110-218), tilescale_pythonic_expr(221-346)
- src/op/distributed.h (1): tvm(10-239)

src/transform/lower_hopper_intrin.cc (4)
- tilelang/language/ast/ir.py (1): LetStmt(880-908)
- tilelang/language/builtin.py (1): create_tma_descriptor(74-83)
- src/op/builtin.cc (2): cuTensorMapType(39-39), cuTensorMapType(39-39)
- tilelang/language/tir/op.py (1): call_extern(173-195)
🪛 Clang (14.0.6)
src/transform/declare_symm_buffer.cc
[error] 9-9: 'tvm/ffi/reflection/registry.h' file not found
(clang-diagnostic-error)
🪛 GitHub Actions: CI
tilelang/language/proxy.py
[error] 307-309: UP045 Use X | None for type annotations in function signature: 'shape: Optional[tuple[PrimExpr, ...]]', 'dtype: Optional[str]', 'strides: Optional[tuple[PrimExpr, ...]]'.
[error] 308-309: UP045 Use X | None for type annotations in function signature: 'dtype' and 'strides'.
[error] 307-309: Ruff check: Multiple UP045 suggestions; convert annotations to use '| None' where appropriate.
tilelang/engine/phase.py
[error] 1-1: Pre-commit yapf: code formatting changes were made. Please re-run the pipeline or commit formatted files.
tilelang/language/__init__.py
[error] 1-1: Pre-commit yapf: code formatting changes were made. Please re-run the pipeline or commit formatted files.
tilelang/transform/__init__.py
[error] 1-1: Pre-commit yapf: code formatting changes were made. Please re-run the pipeline or commit formatted files.
tilelang/language/copy.py
[error] 14-15: UP045 Use X | None for type annotations in 'src_pe' and 'dst_pe' parameters.
[error] 30-30: W293 Blank line contains whitespace in docstring/header area.
[error] 14-16: Ruff check: Use X | None for type annotations (UP045) and related fixes suggested by the linter.
tilelang/jit/adapter/utils.py
[error] 1-1: Pre-commit yapf: code formatting changes were made. Please re-run the pipeline or commit formatted files.
tilelang/jit/adapter/wrapper.py
[error] 1-1: Pre-commit yapf: code formatting changes were made. Please re-run the pipeline or commit formatted files.
🪛 Ruff (0.14.2)
examples/distributed/primitives/example_tilescale_copy.py
15-15: Unused function argument: rank
(ARG001)
182-182: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/__init__.py
19-19: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
```cpp
 * on peer's symm heap
 */

#include <tvm/ffi/reflection/registry.h>
```
Fix the registry header include
`<tvm/ffi/reflection/registry.h>` does not exist, so this include breaks the build. Please pull in the registry helpers from `<tvm/runtime/registry.h>` instead.

```diff
-#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/registry.h>
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/registry.h>
```
🧰 Tools
🪛 Clang (14.0.6)
[error] 9-9: 'tvm/ffi/reflection/registry.h' file not found
(clang-diagnostic-error)
🤖 Prompt for AI Agents
In src/transform/declare_symm_buffer.cc around line 9, the include currently
references a non-existent header <tvm/ffi/reflection/registry.h>; replace it
with the correct registry header <tvm/runtime/registry.h> so the file pulls in
the registry helpers used by the code and the build can succeed.
Actionable comments posted: 0
♻️ Duplicate comments (3)
tilelang/language/copy.py (2)
14-15: Fix type annotations to match the default values.

The type hints specify `tir.PrimExpr | tir.IntImm | None` but the default value is `-1` (a Python `int`), creating a type mismatch.

Apply this diff to include `int` in the type union:

```diff
-    src_pe: tir.PrimExpr | tir.IntImm | None = -1,
-    dst_pe: tir.PrimExpr | tir.IntImm | None = -1,
+    src_pe: tir.PrimExpr | tir.IntImm | int | None = -1,
+    dst_pe: tir.PrimExpr | tir.IntImm | int | None = -1,
```

Alternatively, change the defaults to `None` and normalize inside the function:

```diff
-    src_pe: tir.PrimExpr | tir.IntImm | None = -1,
-    dst_pe: tir.PrimExpr | tir.IntImm | None = -1,
+    src_pe: tir.PrimExpr | tir.IntImm | None = None,
+    dst_pe: tir.PrimExpr | tir.IntImm | None = None,
 ...
+    if src_pe is None:
+        src_pe = -1
+    if dst_pe is None:
+        dst_pe = -1
```
95-95: Assertion breaks with symbolic PrimExpr arguments.

The assertion `src_pe == -1 or dst_pe == -1` attempts to evaluate symbolic `PrimExpr` values in a boolean context (e.g., `src_pe=1 - T.get_rank()` as used in the TMA example), which raises `TypeError: Cannot convert PrimExpr to bool` at runtime.

Apply this diff to check only constant values:

```diff
-    assert src_pe == -1 or dst_pe == -1, "At least one of src_pe or dst_pe must be local rank"
+    def _is_const_remote(pe):
+        if isinstance(pe, int):
+            return pe != -1
+        if isinstance(pe, tir.IntImm):
+            return pe.value != -1
+        return False
+
+    if _is_const_remote(src_pe) and _is_const_remote(dst_pe):
+        raise ValueError("At least one of src_pe or dst_pe must be local rank")
```

examples/distributed/primitives/example_tilescale_copy.py (1)
15-15: Remove the unused `rank` parameter.

The `rank` parameter is never referenced in the function body. While the comment at line 158 suggests it's for TMA kernels, those implementations use `T.get_rank()` directly (lines 98, 128) instead of a compile-time constant.

Apply this diff:

```diff
-def get_kernel(M, N, block_M, block_N, threads, kernel='simt_push_tile', rank=None):
+def get_kernel(M, N, block_M, block_N, threads, kernel='simt_push_tile'):
```

And update the call site at line 158:

```diff
 kernel = get_kernel(
-    M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel,
-    rank=local_rank)  # only TMA kernels need compile-time aware peer rank
+    M, N, BLOCK_M, BLOCK_N, threads, kernel=args.kernel)
```
🧹 Nitpick comments (1)
examples/distributed/primitives/example_tilescale_copy.py (1)
163-183: Validation logic correctly verifies peer-to-peer copy.

The test correctly:
- Allocates tensors on the distributed allocator
- Synchronizes before and after kernel execution
- Uses `all_gather` to collect each rank's source tensor
- Validates that local `dst` matches the peer's `src` using `local_rank ^ 1` indexing

Minor: Line 181 triggers a style warning (TRY003) for the inline exception message. Consider defining a custom exception or using a simpler message:

```diff
-    raise ValueError("Test failed")
+    raise ValueError
```
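The constant-vs-symbolic distinction behind the earlier `_is_const_remote` suggestion can be mocked in plain Python (the `IntImm`/`PrimExpr` stand-ins below are illustrative, not the actual TVM classes):

```python
class IntImm:
    """Minimal stand-in for tir.IntImm (illustrative only)."""
    def __init__(self, value):
        self.value = value

class PrimExpr:
    """Stand-in for a symbolic expression such as 1 - T.get_rank()."""

def is_const_remote(pe):
    # Only constant values can be rejected at trace time; symbolic
    # expressions must be deferred to the C++ validation layer.
    if isinstance(pe, bool):
        return False
    if isinstance(pe, int):
        return pe != -1
    if isinstance(pe, IntImm):
        return pe.value != -1
    return False  # symbolic: cannot decide here, do not assert

# A symbolic peer id is never rejected up front, while constant
# sentinels are classified immediately.
symbolic_ok = not is_const_remote(PrimExpr())
```

The key design point: evaluating `pe == -1` on a symbolic expression builds another expression rather than a `bool`, so any eager `assert` on it is a bug.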
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/distributed/primitives/example_tilescale_copy.py (1 hunks)
- tilelang/language/copy.py (2 hunks)
- tilelang/language/proxy.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/language/copy.py (1)
- tilelang/language/tir/op.py (1): call_intrin(120-145)

examples/distributed/primitives/example_tilescale_copy.py (6)
- tilelang/distributed/utils.py (1): init_dist(40-62)
- tilelang/language/distributed/common.py (1): get_rank(8-11)
- tilelang/language/copy.py (1): copy(11-98)
- tilelang/layout/swizzle.py (1): make_swizzled_layout(10-18)
- tilelang/utils/allocator.py (1): get_allocator(236-248)
- tilelang/utils/tensor.py (1): tensor(43-56)
🪛 Ruff (0.14.2)
examples/distributed/primitives/example_tilescale_copy.py
15-15: Unused function argument: rank
(ARG001)
181-181: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-nvidia
🔇 Additional comments (5)
tilelang/language/proxy.py (1)
305-313: LGTM! Clean helper function.

The `make_tensor_like` function provides a clean interface for creating tensor-like buffers with optional property overrides. The None-check fallback pattern is clear and the type hints appropriately reflect the optional nature of the parameters.

tilelang/language/copy.py (1)

97-98: LGTM! Intrinsic call correctly extended.

The `tir.call_intrin` invocation correctly passes the new `src_pe` and `dst_pe` parameters, maintaining consistency with the updated function signature and the underlying intrinsic's expanded interface.

examples/distributed/primitives/example_tilescale_copy.py (3)

17-136: LGTM! Comprehensive demonstration of remote copy patterns.

The five kernel implementations effectively demonstrate different distributed copy strategies:
- SIMT push/pull patterns with explicit per-endpoint signaling
- TMA load/store patterns using runtime rank expressions
- Proper staging through shared memory with swizzled layouts

All kernels correctly use `1 - rank` to target the peer in a 2-rank configuration.
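For a two-rank job, `1 - rank` and `rank ^ 1` name the same peer; a quick sanity check in plain Python (the helper name is illustrative):

```python
def peer_of(rank: int, num_ranks: int = 2) -> int:
    """Return the partner rank in a 2-rank pairing.

    Both 1 - rank and rank ^ 1 map 0 <-> 1; the XOR form is the one
    the validation code uses when indexing gathered tensors.
    """
    assert num_ranks == 2, "this helper only handles the 2-rank case"
    return rank ^ 1  # identical to 1 - rank for rank in {0, 1}

for rank in (0, 1):
    assert peer_of(rank) == 1 - rank       # the two forms agree
    assert peer_of(peer_of(rank)) == rank  # pairing is symmetric
```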
139-162: LGTM! Proper distributed setup and kernel initialization.

The initialization sequence correctly:
- Validates the 2-rank constraint required by the peer-signaling logic
- Configures the distributed allocator with appropriate parameters
- Compiles the selected kernel and prints source on rank 0 only

186-194: LGTM! Clean script entrypoint.

The argument parser and multiprocessing spawn setup is clean and appropriate for the 2-rank distributed test scenario.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- examples/distributed/primitives/example_tilescale_copy.py (1 hunks)
- src/op/copy.cc (2 hunks)
- src/tl_templates/cuda/distributed.h (1 hunks)
- tilelang/language/copy.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/tl_templates/cuda/distributed.h
🧰 Additional context used
🧬 Code graph analysis (3)
src/op/copy.cc (2)
- src/op/copy.h (1): is_remote_copy(100-124)
- tilelang/language/copy.py (1): copy(11-96)

examples/distributed/primitives/example_tilescale_copy.py (10)
- tilelang/distributed/utils.py (1): init_dist(40-62)
- tilelang/env.py (1): disable_cache(286-287)
- tilelang/jit/__init__.py (1): jit(233-306)
- tilelang/language/allocate.py (1): alloc_shared(24-39)
- tilelang/language/distributed/common.py (1): get_rank(8-11)
- tilelang/language/copy.py (1): copy(11-96)
- tilelang/layout/swizzle.py (1): make_swizzled_layout(10-18)
- tilelang/utils/allocator.py (1): get_allocator(236-248)
- tilelang/jit/kernel.py (1): initialize(407-416)
- tilelang/utils/tensor.py (1): tensor(43-56)

tilelang/language/copy.py (1)
- tilelang/language/tir/op.py (1): call_intrin(120-145)
🪛 Ruff (0.14.2)
examples/distributed/primitives/example_tilescale_copy.py
179-179: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-nvidia
🔇 Additional comments (9)
src/op/copy.cc (2)
158-178: LGTM! Remote copy parameter parsing and validation is correct.

The parameter parsing and validation logic is well-structured:
- Correctly parses `src_pe` and `dst_pe` from args[5] and args[6]
- Enforces mutual exclusion (at least one endpoint must be local)
- Validates memory scopes appropriately (remote push requires global dst, remote pull requires global src)

183-195: LGTM! Query helper implementations are correct.

The query helpers correctly identify remote copy operations by checking whether src_pe or dst_pe differ from the local sentinel value (-1).

tilelang/language/copy.py (3)
tilelang/language/copy.py (3)
11-19: LGTM! Function signature is reasonable.The type hints
tir.PrimExpr | tir.IntImm | None = -1correctly include the implicitinttype that Python's-1literal represents. While the hint could explicitly listint, the current form is acceptable since Python int literals are compatible with the PrimExpr type system in TVM.
20-36: LGTM! Documentation is clear and accurate.The docstring properly documents the new
src_peanddst_peparameters, explaining that -1 indicates a local endpoint.
94-96: LGTM! Parameter propagation is correct.The function correctly passes
src_peanddst_pethrough to thetl.copyintrinsic. Validation is appropriately handled at the C++ layer (insrc/op/copy.cc), avoiding issues with symbolic PrimExpr values that would arise from Python-level assertions.examples/distributed/primitives/example_tilescale_copy.py (4)
1-12: LGTM! Imports and setup are appropriate.

The imports and environment configuration are well-suited for a distributed copy example. Disabling the cache and silencing NCCL logs improve the testing experience.

14-136: LGTM! Kernel variants demonstrate remote copy patterns effectively.

The `get_kernel` function provides a comprehensive set of examples:
- simt_push_buffer: Full-buffer remote push
- simt_push_tile: Tiled push via shared memory
- simt_pull_tile: Tiled pull via shared memory
- tma_load_tile: TMA-based remote load
- tma_store_tile: TMA-based remote store

The distinction between using `rank[0]` (SIMT kernels) and `T.get_rank()` (TMA kernels) is noted in the comment on line 99, which helps users understand the current limitation.
139-181: LGTM! Main function demonstrates proper distributed workflow.

The main function correctly implements:
- Distributed initialization with barrier synchronization
- Allocator configuration for multi-rank execution
- Kernel compilation and initialization
- Tensor allocation on distributed allocator
- Proper synchronization before and after kernel execution
- Result validation using all_gather and peer data comparison (via XOR to get opposite rank)
- Cleanup of distributed process group

The validation logic on line 172 (`dst_torchs[local_rank ^ 1]`) cleverly uses XOR to select the peer's source data for comparison.
184-192: LGTM! Script entrypoint is well-structured.

The argument parser provides sensible defaults and the use of `torch.multiprocessing.spawn` correctly launches the two required processes for this distributed example.
Actionable comments posted: 4
♻️ Duplicate comments (1)
src/transform/declare_symm_buffer.cc (1)
9-9: Fix the registry header include.

This issue was already flagged in a previous review. The header `<tvm/ffi/reflection/registry.h>` does not exist and breaks the build. Replace it with `<tvm/runtime/registry.h>`.

```diff
-#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/registry.h>
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/transform/declare_symm_buffer.cc(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/declare_symm_buffer.cc (4)
- tilelang/language/distributed/common.py (1): get_rank(8-11)
- src/tl_templates/cuda/distributed.h (1): get_uintptr_t(47-49)
- tilelang/language/ast/ir.py (1): LetStmt(880-908)
- tilelang/transform/__init__.py (1): DeclareSymmBuffer(492-498)
🪛 Clang (14.0.6)
src/transform/declare_symm_buffer.cc
[error] 9-9: 'tvm/ffi/reflection/registry.h' file not found
(clang-diagnostic-error)
🔇 Additional comments (3)
src/transform/declare_symm_buffer.cc (3)
50-67: LGTM!

The `Apply` method correctly implements the standard TVM pass pattern: early return for undefined bodies, buffer map initialization, and stmt visitation.

70-79: LGTM!

The block visitor correctly populates the buffer map from both `alloc_buffers` and `match_buffers`, then delegates to the parent visitor.

242-254: Pass creation and FFI registration look correct.

The pass creation follows the standard TVM pattern, and the FFI registration is properly structured. However, this code depends on fixing the header include at line 9 to function correctly.

Ensure that after fixing the header include at line 9, the FFI registration compiles and the pass is accessible from Python as `tl.transform.DeclareSymmBuffer()`.
```cpp
static int name_suffix_id =
    0; // Avoid name collision for symm buffers, start from 0
```
Race condition: global mutable state in transformation pass.
The global `name_suffix_id` is incremented without synchronization (lines 110, 178), causing data races when multiple passes run concurrently. Move this counter into the `SymmBufferDeclarer` class as a member variable.

```diff
-static int name_suffix_id =
-    0; // Avoid name collision for symm buffers, start from 0
```

Add to the `SymmBufferDeclarer` class private members:

```diff
 Map<Var, Buffer> buffer_data_to_buffer_;
+int name_suffix_id_ = 0;
```

Then update lines 110 and 178 to use `name_suffix_id_++` instead of `name_suffix_id++`.
🤖 Prompt for AI Agents
In src/transform/declare_symm_buffer.cc around lines 29-30 (and usages at ~110
and ~178), a global mutable variable `name_suffix_id` is used and incremented,
causing race conditions when passes run concurrently; move this counter into the
SymmBufferDeclarer class as a private member (e.g., add int name_suffix_id_ = 0
alongside buffer_data_to_buffer_), then replace usages of the global
`name_suffix_id++` at the noted locations with `name_suffix_id_++` and remove
the global variable to eliminate the shared mutable state.
```cpp
if (parsed_op.as<CopyNode>()->is_remote_push()) {
  // LOG(INFO) << "Found remote push";

  Buffer dst = parsed_op.as<CopyNode>()->dst;
  Array<Range> dst_range = parsed_op.as<CopyNode>()->dst_range;

  // 1. Calculate symm dst ptr
  PrimExpr symm_dst_ptr_expr =
      CalculateSymmPtr(dst->data, parsed_op.as<CopyNode>()->dst_pe);
  // LOG(INFO) << "Symm dst ptr expr: " << symm_dst_ptr_expr;

  // 2. Create a let binding
  String storage_scope =
      dst->data->type_annotation.as<PointerTypeNode>()->storage_scope;
  Var symm_dst_var =
      Var(dst->name + "_symm_" + std::to_string(name_suffix_id++),
          PointerType(PrimType(dst->dtype), storage_scope));

  // 3. Create modified dst buffer with symm var
  dst.CopyOnWrite()->data = symm_dst_var;

  // 4. Rebuild the destination region call with the modified buffer
  // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0,
  // extent_1, ...]
  Array<PrimExpr> dst_region_mins;
  Array<PrimExpr> dst_region_extents;
  for (const Range &r : dst_range) {
    dst_region_mins.push_back(r->min);
    dst_region_extents.push_back(r->extent);
  }
  BufferLoad dst_load(dst, dst_region_mins);

  Array<PrimExpr> dst_region_args;
  dst_region_args.push_back(dst_load);
  dst_region_args.push_back(
      IntImm(DataType::Int(32), call_op->args[1]
                                    .as<CallNode>()
                                    ->args[1]
                                    .as<IntImmNode>()
                                    ->value)); // access_mask
  for (const PrimExpr &extent : dst_region_extents) {
    dst_region_args.push_back(extent);
  }

  // Create new Call for the destination region
  Call dst_region_call =
      Call(call_op->args[1].as<CallNode>()->dtype,
           call_op->args[1].as<CallNode>()->op, dst_region_args,
           call_op->args[1].as<CallNode>()->span);

  // 5. Rebuild the Copy call with modified args
  Array<PrimExpr> new_copy_args;
  new_copy_args.push_back(call_op->args[0]); // src region (unchanged)
  new_copy_args.push_back(dst_region_call);  // modified dst region
  // Copy remaining args
  for (size_t i = 2; i < call_op->args.size(); i++) {
    new_copy_args.push_back(call_op->args[i]);
  }

  // Create the modified copy call
  Call new_copy_call =
      Call(call_op->dtype, call_op->op, new_copy_args, call_op->span);

  // Wrap it in an Evaluate statement
  Stmt modified_stmt = Evaluate(new_copy_call);

  // Wrap with LetStmt that defines the symm pointer
  return LetStmt(symm_dst_var, symm_dst_ptr_expr, modified_stmt);
```
Add defensive null checks for nested property access.
The code has several unsafe property accesses that can crash:
- Line 108: `type_annotation.as<PointerTypeNode>()->storage_scope` assumes type_annotation is defined and is a PointerType.
- Lines 130-134: Deeply nested `as<CallNode>()` and `as<IntImmNode>()` calls without checking intermediate results.
If any cast fails, dereferencing the null pointer crashes the compiler.
Apply these defensive checks:
```diff
 if (parsed_op.as<CopyNode>()->is_remote_push()) {
   // LOG(INFO) << "Found remote push";
   Buffer dst = parsed_op.as<CopyNode>()->dst;
   Array<Range> dst_range = parsed_op.as<CopyNode>()->dst_range;
   // 1. Calculate symm dst ptr
   PrimExpr symm_dst_ptr_expr =
       CalculateSymmPtr(dst->data, parsed_op.as<CopyNode>()->dst_pe);
   // LOG(INFO) << "Symm dst ptr expr: " << symm_dst_ptr_expr;
   // 2. Create a let binding
+  ICHECK(dst->data->type_annotation.defined())
+      << "Buffer data variable must have type annotation";
+  auto ptr_type = dst->data->type_annotation.as<PointerTypeNode>();
+  ICHECK(ptr_type) << "Buffer data type annotation must be PointerType";
   String storage_scope =
-      dst->data->type_annotation.as<PointerTypeNode>()->storage_scope;
+      ptr_type->storage_scope;
   Var symm_dst_var =
       Var(dst->name + "_symm_" + std::to_string(name_suffix_id++),
           PointerType(PrimType(dst->dtype), storage_scope));
   // 3. Create modified dst buffer with symm var
   dst.CopyOnWrite()->data = symm_dst_var;
   // 4. Rebuild the destination region call with the modified buffer
   // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0,
   // extent_1, ...]
   Array<PrimExpr> dst_region_mins;
   Array<PrimExpr> dst_region_extents;
   for (const Range &r : dst_range) {
     dst_region_mins.push_back(r->min);
     dst_region_extents.push_back(r->extent);
   }
   BufferLoad dst_load(dst, dst_region_mins);
   Array<PrimExpr> dst_region_args;
   dst_region_args.push_back(dst_load);
+  auto dst_region_call_node = call_op->args[1].as<CallNode>();
+  ICHECK(dst_region_call_node) << "Expected Call node for dst region";
+  auto access_mask_node = dst_region_call_node->args[1].as<IntImmNode>();
+  ICHECK(access_mask_node) << "Expected IntImm for access_mask";
   dst_region_args.push_back(
-      IntImm(DataType::Int(32), call_op->args[1]
-                                    .as<CallNode>()
-                                    ->args[1]
-                                    .as<IntImmNode>()
-                                    ->value)); // access_mask
+      IntImm(DataType::Int(32), access_mask_node->value)); // access_mask
```

🤖 Prompt for AI Agents
In src/transform/declare_symm_buffer.cc around lines 95-162, several chained
.as<...> dereferences can return nullptr and crash; add defensive null/type
checks before dereferencing: verify dst->data->type_annotation is non-null and
its .as<PointerTypeNode>() succeeds before accessing storage_scope; verify
call_op and call_op->args size >=2, that call_op->args[1].as<CallNode>() is
non-null, that that CallNode has args size >=2 and that args[1].as<IntImmNode>()
is non-null before reading ->value; on any check failure handle gracefully
(e.g., log/ICHECK(false) with a clear message or return the original
stmt/unmodified Call) instead of dereferencing a null pointer so the compiler
won't crash.
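The fail-fast extraction the prompt describes can be sketched in plain Python (a mock of defensively walking a nested node, not the TVM API; the dict shape is an assumption):

```python
def get_access_mask(copy_call: dict) -> int:
    """Mock of defensively extracting access_mask from a nested IR
    node: every intermediate lookup is checked before use, so a
    malformed node raises a clear error instead of crashing."""
    args = copy_call.get("args")
    if not isinstance(args, list) or len(args) < 2:
        raise ValueError("Copy call must have at least 2 args")
    region = args[1]
    if not isinstance(region, dict) or "args" not in region:
        raise ValueError("Expected Call node for dst region")
    region_args = region["args"]
    if len(region_args) < 2 or not isinstance(region_args[1], int):
        raise ValueError("Expected IntImm for access_mask")
    return region_args[1]

good = {"args": ["src_region", {"args": ["load", 2]}]}
mask = get_access_mask(good)
```

The chain `call_op->args[1].as<CallNode>()->args[1].as<IntImmNode>()->value` collapses all of these checks into dereferences that crash on the first failed cast; the checked version surfaces which assumption broke.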
```cpp
} else if (parsed_op.as<CopyNode>()->is_remote_pull()) {
  LOG(INFO) << "Found remote pull";

  Buffer src = parsed_op.as<CopyNode>()->src;
  Array<Range> src_range = parsed_op.as<CopyNode>()->src_range;

  // 1. Calculate symm src ptr
  PrimExpr symm_src_ptr_expr =
      CalculateSymmPtr(src->data, parsed_op.as<CopyNode>()->src_pe);
  // LOG(INFO) << "Symm src ptr expr: " << symm_src_ptr_expr;

  // 2. Create a let binding
  String storage_scope =
      src->data->type_annotation.as<PointerTypeNode>()->storage_scope;
  Var symm_src_var =
      Var(src->name + "_symm_" + std::to_string(name_suffix_id++),
          PointerType(PrimType(src->dtype), storage_scope));

  // 3. Create modified src buffer with symm var
  src.CopyOnWrite()->data = symm_src_var;

  // 4. Rebuild the source region call with the modified buffer
  // RegionOp args: [BufferLoad(min_indices), access_mask, extent_0,
  // extent_1, ...]
  Array<PrimExpr> src_region_mins;
  Array<PrimExpr> src_region_extents;
  for (const Range &r : src_range) {
    src_region_mins.push_back(r->min);
    src_region_extents.push_back(r->extent);
  }
  BufferLoad src_load(src, src_region_mins);

  Array<PrimExpr> src_region_args;
  src_region_args.push_back(src_load);
  src_region_args.push_back(
      IntImm(DataType::Int(32), call_op->args[1]
                                    .as<CallNode>()
                                    ->args[1]
                                    .as<IntImmNode>()
                                    ->value)); // access_mask
  for (const PrimExpr &extent : src_region_extents) {
    src_region_args.push_back(extent);
  }

  // Create new Call for the source region
  Call src_region_call =
      Call(call_op->args[0].as<CallNode>()->dtype,
           call_op->args[0].as<CallNode>()->op, src_region_args,
           call_op->args[0].as<CallNode>()->span);

  // 5. Rebuild the Copy call with modified args
  Array<PrimExpr> new_copy_args;
  new_copy_args.push_back(src_region_call);  // modified src region
  new_copy_args.push_back(call_op->args[1]); // dst region (unchanged)
  // Copy remaining args
  for (size_t i = 2; i < call_op->args.size(); i++) {
    new_copy_args.push_back(call_op->args[i]);
  }

  // Create the modified copy call
  Call new_copy_call =
      Call(call_op->dtype, call_op->op, new_copy_args, call_op->span);

  // Wrap it in an Evaluate statement
  Stmt modified_stmt = Evaluate(new_copy_call);

  // Wrap with LetStmt that defines the symm pointer
  return LetStmt(symm_src_var, symm_src_ptr_expr, modified_stmt);
```
Add defensive null checks and remove debug LOG statement.
This code has the same issues as the remote_push block:
- Line 164: LOG(INFO) should be removed or commented out (other debug logs are commented).
- Line 176: type_annotation.as<PointerTypeNode>()->storage_scope needs null checks.
- Lines 198-202: Deeply nested as<>() calls without validation.
Apply similar defensive checks as suggested for remote_push:
} else if (parsed_op.as<CopyNode>()->is_remote_pull()) {
- LOG(INFO) << "Found remote pull";
+ // LOG(INFO) << "Found remote pull";
Buffer src = parsed_op.as<CopyNode>()->src;
Array<Range> src_range = parsed_op.as<CopyNode>()->src_range;
// 1. Calculate symm src ptr
PrimExpr symm_src_ptr_expr =
CalculateSymmPtr(src->data, parsed_op.as<CopyNode>()->src_pe);
// LOG(INFO) << "Symm src ptr expr: " << symm_src_ptr_expr;
// 2. Create a let binding
+ ICHECK(src->data->type_annotation.defined())
+ << "Buffer data variable must have type annotation";
+ auto ptr_type = src->data->type_annotation.as<PointerTypeNode>();
+ ICHECK(ptr_type) << "Buffer data type annotation must be PointerType";
String storage_scope =
- src->data->type_annotation.as<PointerTypeNode>()->storage_scope;
+ ptr_type->storage_scope;
Var symm_src_var =
Var(src->name + "_symm_" + std::to_string(name_suffix_id++),
PointerType(PrimType(src->dtype), storage_scope));
// 3. Create modified src buffer with symm var
src.CopyOnWrite()->data = symm_src_var;
// 4. Rebuild the source region call with the modified buffer
// RegionOp args: [BufferLoad(min_indices), access_mask, extent_0,
// extent_1, ...]
Array<PrimExpr> src_region_mins;
Array<PrimExpr> src_region_extents;
for (const Range &r : src_range) {
src_region_mins.push_back(r->min);
src_region_extents.push_back(r->extent);
}
BufferLoad src_load(src, src_region_mins);
Array<PrimExpr> src_region_args;
src_region_args.push_back(src_load);
+ auto src_region_call_node = call_op->args[0].as<CallNode>();
+ ICHECK(src_region_call_node) << "Expected Call node for src region";
+ auto access_mask_node = src_region_call_node->args[1].as<IntImmNode>();
+ ICHECK(access_mask_node) << "Expected IntImm for access_mask";
src_region_args.push_back(
- IntImm(DataType::Int(32), call_op->args[1]
- .as<CallNode>()
- ->args[1]
- .as<IntImmNode>()
- ->value)); // access_mask
+ IntImm(DataType::Int(32), access_mask_node->value)); // access_mask
🤖 Prompt for AI Agents
In src/transform/declare_symm_buffer.cc around lines 163-230, remove the debug
LOG(INFO) and add defensive null/size checks before every use of as<...> and
indexing: validate parsed_op.as<CopyNode>() is non-null, ensure
src->data->type_annotation.as<PointerTypeNode>() is non-null before accessing
storage_scope (provide a safe fallback or return original stmt on failure),
check call_op and call_op->args[0].as<CallNode>() /
call_op->args[1].as<CallNode>() are non-null before accessing their fields, and
verify call_op->args has the expected size before indexing args[1] and iterating
from index 2; if any check fails, bail out cleanly (e.g., return the original
stmt or propagate an error) instead of dereferencing nulls.
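The "validate before dereference" pattern the review keeps asking for can be sketched outside TVM too. The following is a hypothetical Python analog (the class names are stand-ins, not the PR's IR types): each cast is checked and fails with a clear message instead of crashing, mirroring what the suggested ICHECKs do in C++.

```python
# Hypothetical analog of the ICHECK-style "validate before dereference"
# pattern the review asks for; these class names are illustrative stand-ins.

def checked_cast(node, expected_type, context):
    """Fail loudly with a clear message instead of crashing on a bad cast."""
    if not isinstance(node, expected_type):
        raise TypeError(f"{context}: expected {expected_type.__name__}, "
                        f"got {type(node).__name__}")
    return node

class CallNode:  # stand-in for an IR call node
    def __init__(self, args):
        self.args = args

class IntImmNode:  # stand-in for an integer-immediate node
    def __init__(self, value):
        self.value = value

def read_access_mask(call):
    # Mirrors the nested call_op->args[1].as<CallNode>()->args[1]
    # .as<IntImmNode>()->value chain, with every step validated.
    call = checked_cast(call, CallNode, "copy op")
    if len(call.args) < 2:
        raise ValueError("copy op: expected at least 2 args")
    region = checked_cast(call.args[1], CallNode, "dst region")
    mask = checked_cast(region.args[1], IntImmNode, "access mask")
    return mask.value
```

The point is that a malformed input surfaces as a descriptive error at the offending step rather than as a null dereference several calls later.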
LOG(INFO) commented out
@Rachmanino, thanks for commenting out the LOG(INFO)! 👍
Just a gentle reminder that there are still two items from the original feedback that could improve robustness:
- Line 176: Adding null checks before accessing storage_scope from the type annotation
- Lines 198-202: Validating the nested as<>() calls before dereferencing
These defensive checks would help prevent potential crashes if the IR structure is unexpected. However, if you're confident the IR will always have the expected shape at this point in the pipeline, feel free to defer these checks.
Actionable comments posted: 1
♻️ Duplicate comments (6)
src/op/copy.cc (1)
1979-1983: Fix the input count in the documentation comment. The comment states "Takes 8 inputs" but only 7 are listed and the registration correctly uses .set_num_inputs(7). Update the comment to reflect the actual count. Apply this diff:
-// - Takes 8 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
+// - Takes 7 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
 //   eviction_policy, src_pe, dst_pe
src/transform/declare_symm_buffer.cc (5)
9-9: Fix the incorrect header include. The header <tvm/ffi/reflection/registry.h> does not exist. Use <tvm/runtime/registry.h> instead.
-#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/registry.h>
29-30: Race condition: Move global mutable counter into class. The global name_suffix_id is incremented without synchronization (lines 111, 179), causing data races when passes run concurrently. Move this counter into SymmBufferDeclarer as a member variable.
34-42: Fix type inconsistencies in pointer arithmetic. Multiple issues:
- Line 35: local_rank is declared but never used
- Lines 36-40: Mixing Handle and integer types incorrectly in pointer arithmetic
- get_uintptr_t() should work with DataType::UInt(64), not Handle
All pointer arithmetic should be performed on UInt(64) values, then cast back to Handle for the return.
96-163: Add defensive null checks for nested property access. Lines 108 and 130-135 contain unsafe property accesses that can crash if any cast fails:
- type_annotation.as<PointerTypeNode>()->storage_scope assumes the type is defined and is a PointerType
- Deeply nested .as<CallNode>() and .as<IntImmNode>() calls without checking intermediate results
164-231: Remove debug LOG statement and add defensive null checks. Line 165: LOG(INFO) should be commented out (other debug logs are commented).
Lines 176-177 and 198-202 contain unsafe property accesses without null checks.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/op/copy.cc (2 hunks)
src/transform/declare_symm_buffer.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/transform/declare_symm_buffer.cc (5)
tilelang/language/proxy.py (1)
ptr (273-295)
tilelang/language/distributed/common.py (1)
get_rank (8-11)
src/tl_templates/cuda/distributed.h (1)
get_uintptr_t (47-49)
tilelang/language/ast/ir.py (1)
LetStmt (880-908)
tilelang/transform/__init__.py (1)
DeclareSymmBuffer (492-498)
src/op/copy.cc (2)
src/op/copy.h (1)
is_remote_copy (100-124)
tilelang/language/copy.py (1)
copy (11-96)
🪛 Clang (14.0.6)
src/transform/declare_symm_buffer.cc
[error] 9-9: 'tvm/ffi/reflection/registry.h' file not found
(clang-diagnostic-error)
🔇 Additional comments (3)
src/op/copy.cc (2)
158-179: LGTM! Remote copy parameter parsing and validation is correct. The code properly:
- Parses src_pe/dst_pe with bounds checking
- Ensures at least one side is local (prevents both being remote)
- Validates memory scope constraints for peer access (remote push requires dst in global, remote pull requires src in global)
183-195: LGTM! Helper methods correctly identify remote copy types. The implementation properly detects:
- Remote push when dst_pe ≠ -1
- Remote pull when src_pe ≠ -1
- Remote copy as the union of both
src/transform/declare_symm_buffer.cc (1)
243-255: LGTM! Pass registration follows standard TVM patterns. The pass creation and FFI registration are implemented correctly using the standard TVM idioms.
Actionable comments posted: 1
🧹 Nitpick comments (1)
.github/workflows/ci.yml (1)
108-132: Consider aligning distributed test execution between examples and tests sections. The examples section (line 119) runs distributed tests with TILELANG_USE_DISTRIBUTED=1 python -m pytest -n 4 "${DIST_TESTS[@]}" directly, while the tests section now uses a per-test script approach with random ports. This inconsistency may indicate the per-test approach is necessary to prevent test collisions, but the examples section hasn't been updated. Verify whether the examples section would benefit from the same per-test isolation strategy, or document why it's not needed there.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
.github/workflows/ci.yml (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-nvidia
Actionable comments posted: 0
♻️ Duplicate comments (2)
.github/workflows/ci.yml (2)
117-125: Critical: Unaddressed issues from previous review (parallel dependency and port collision risk). This change introduces the same concerns previously flagged on lines 149–157:
- Missing parallel dependency check: The workflow assumes GNU parallel is installed on the self-hosted runner without verification. If missing, the script will silently fail or produce cryptic errors.
- Port collision risk: Random port selection from [20000, 21000) provides only 1000 ports. With --jobs 4, multiple concurrent tests picking the same random MASTER_PORT can cause conflicts and test failures, especially with multiple distributed test files.
Apply this fix to verify parallel availability and expand the port range to reduce collisions:
 if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
+  if ! command -v parallel &> /dev/null; then
+    echo "Error: GNU parallel not installed. Aborting distributed examples."
+    exit 1
+  fi
   tmp_script=$(mktemp)
   echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:"
   printf '%s\n' "${DIST_TESTS[@]}"
   for test_file in "${DIST_TESTS[@]}"; do
-    RANDOM_PORT=$(( RANDOM % 1000 + 20000 )) # random port [20000, 21000)
+    RANDOM_PORT=$(( RANDOM % 30000 + 20000 )) # expanded range [20000, 50000)
     echo "MASTER_PORT=$RANDOM_PORT TILELANG_USE_DISTRIBUTED=1 python -m pytest \"$test_file\" -v -r fE" >> "$tmp_script"
   done
   parallel --jobs 4 < "$tmp_script"
   rm "$tmp_script"
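A quick back-of-the-envelope calculation (not part of the review itself) shows why widening the range matters. With 4 concurrent jobs each drawing a uniform random port, the birthday-style collision probability drops by more than an order of magnitude when the pool grows from 1000 to 30000 ports:

```python
# Collision probability for `jobs` independent uniform draws from `ports`
# values, computed via the complement of the all-distinct probability.

def collision_probability(jobs: int, ports: int) -> float:
    """P(at least two of `jobs` random ports coincide)."""
    p_distinct = 1.0
    for i in range(jobs):
        p_distinct *= (ports - i) / ports
    return 1.0 - p_distinct

p_narrow = collision_probability(4, 1000)    # range [20000, 21000): ~0.6%
p_wide = collision_probability(4, 30000)     # range [20000, 50000): ~0.02%
```

The narrow range collides roughly once in 170 CI runs, which is frequent enough to produce flaky failures; the wide range makes that rare.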
149-157: Critical: Unaddressed issues from previous review (parallel dependency and port collision risk). The same critical issues flagged on lines 117–125 apply here:
- Missing parallel dependency check: No verification that GNU parallel is installed.
- Port collision risk: Random [20000, 21000) range with --jobs 4 creates collision risk.
These concerns were raised in the previous review and remain unresolved. Apply the same fix:
 if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
+  if ! command -v parallel &> /dev/null; then
+    echo "Error: GNU parallel not installed. Aborting distributed tests."
+    exit 1
+  fi
   tmp_script=$(mktemp)
   echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:"
   printf '%s\n' "${DIST_TESTS[@]}"
   for test_file in "${DIST_TESTS[@]}"; do
-    RANDOM_PORT=$(( RANDOM % 1000 + 20000 )) # random port [20000, 21000)
+    RANDOM_PORT=$(( RANDOM % 30000 + 20000 )) # expanded range [20000, 50000)
     echo "MASTER_PORT=$RANDOM_PORT TILELANG_USE_DISTRIBUTED=1 python -m pytest \"$test_file\" -v -r fE" >> "$tmp_script"
   done
   parallel --jobs 4 < "$tmp_script"
   rm "$tmp_script"
🧹 Nitpick comments (1)
.github/workflows/ci.yml (1)
108-170: Consider extracting distributed test execution into a reusable helper function. Lines 117–125 and 149–157 implement the same pattern twice (create script, generate per-test commands, run via parallel, cleanup). This duplicates logic across the "Run examples" and "Run tests" sections. Extract this into a shell function at the top of the step to reduce duplication and improve maintainability:
 - name: Run examples
   run: |
     source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
     cd examples
     unset PYTHONPATH
+
+    run_distributed_tests() {
+      local test_type="$1"  # "examples" or "tests"
+      mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
+      if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
+        if ! command -v parallel &> /dev/null; then
+          echo "Error: GNU parallel not installed."
+          exit 1
+        fi
+        tmp_script=$(mktemp)
+        echo "Running distributed $test_type with TILELANG_USE_DISTRIBUTED=1:"
+        printf '%s\n' "${DIST_TESTS[@]}"
+        for test_file in "${DIST_TESTS[@]}"; do
+          RANDOM_PORT=$(( RANDOM % 30000 + 20000 ))
+          echo "MASTER_PORT=$RANDOM_PORT TILELANG_USE_DISTRIBUTED=1 python -m pytest \"$test_file\" -v -r fE" >> "$tmp_script"
+        done
+        parallel --jobs 4 < "$tmp_script"
+        rm "$tmp_script"
+      else
+        echo "No distributed $test_type found."
+      fi
+    }
     # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1
-    mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true)
-    if [ "${#DIST_TESTS[@]}" -gt 0 ]; then
-      ...
-    fi
+    run_distributed_tests "examples"
- name: Run examples run: | source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate" cd examples unset PYTHONPATH + + run_distributed_tests() { + local test_type="$1" # "examples" or "tests" + mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true) + if [ "${#DIST_TESTS[@]}" -gt 0 ]; then + if ! command -v parallel &> /dev/null; then + echo "Error: GNU parallel not installed." + exit 1 + fi + tmp_script=$(mktemp) + echo "Running distributed $test_type with TILELANG_USE_DISTRIBUTED=1:" + printf '%s\n' "${DIST_TESTS[@]}" + for test_file in "${DIST_TESTS[@]}"; do + RANDOM_PORT=$(( RANDOM % 30000 + 20000 )) + echo "MASTER_PORT=$RANDOM_PORT TILELANG_USE_DISTRIBUTED=1 python -m pytest \"$test_file\" -v -r fE" >> "$tmp_script" + done + parallel --jobs 4 < "$tmp_script" + rm "$tmp_script" + else + echo "No distributed $test_type found." + fi + } # find and run distributed tests with TILELANG_USE_DISTRIBUTED=1 - mapfile -t DIST_TESTS < <(find . -type f -path '*/distributed/*' -name 'test*.py' 2>/dev/null || true) - if [ "${#DIST_TESTS[@]}" -gt 0 ]; then - ... - fi + run_distributed_tests "examples"This approach centralizes the logic, simplifies both sections, and makes it easier to fix issues (like the port collision concern) in one place.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
.github/workflows/ci.yml (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-nvidia
4efa911 to 99501ef
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Actionable comments posted: 3
♻️ Duplicate comments (5)
src/transform/declare_symm_buffer.cc (5)
9-9: Fix the broken registry include.
Line 9 still pulls in <tvm/ffi/reflection/registry.h>, which doesn't exist in this tree, so the file fails to compile. Please include the runtime registry header instead.
-#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/registry.h>
29-30: Eliminate the global counter race.
The global name_suffix_id is still present and incremented from multiple pass invocations, which races under concurrent compilation. Move the counter into SymmBufferDeclarer and update the call sites to use the member.
-static int name_suffix_id =
-    0; // Avoid name collision for symm buffers, start from 0
+// Avoid name collision for symm buffers, start from 0
 class SymmBufferDeclarer : public StmtExprMutator {
  public:
@@
  private:
   Map<Var, Buffer> buffer_data_to_buffer_;
+  int name_suffix_id_ = 0;
-      Var symm_dst_var =
-          Var(dst->name + "_symm_" + std::to_string(name_suffix_id++),
+      Var symm_dst_var =
+          Var(dst->name + "_symm_" + std::to_string(name_suffix_id_++),
               PointerType(PrimType(dst->dtype), storage_scope));
-      Var symm_src_var =
-          Var(src->name + "_symm_" + std::to_string(name_suffix_id++),
+      Var symm_src_var =
+          Var(src->name + "_symm_" + std::to_string(name_suffix_id_++),
               PointerType(PrimType(src->dtype), storage_scope));
Also applies to: 108-110, 176-178
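The per-instance counter fix is easiest to see in miniature. This is a minimal Python sketch (assumed names, not the PR's C++ code): each pass run owns its own suffix counter, so concurrent runs cannot race on shared mutable state, and naming stays deterministic per run.

```python
# Minimal sketch of the "move the counter into the class" fix: a member
# counter instead of a global, so every declarer instance names its
# symmetric buffers independently and deterministically.

class SymmBufferDeclarer:
    def __init__(self):
        self._name_suffix_id = 0  # member, not a module-level global

    def fresh_symm_name(self, base: str) -> str:
        name = f"{base}_symm_{self._name_suffix_id}"
        self._name_suffix_id += 1
        return name
```

With a global counter, the names produced by one compilation would depend on how many other compilations ran first; with a member counter they depend only on that run's own history.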
34-40: Repair pointer arithmetic in CalculateSymmPtr.
This routine still mixes Handle-typed PrimExpr with integer arithmetic, which is undefined after type checking. Convert all operands to UInt(64) via tl::get_uintptr_t before subtracting/adding, then reinterpret the final sum back to a handle.
 PrimExpr CalculateSymmPtr(PrimExpr ptr, PrimExpr pe) {
-  PrimExpr local_base_ptr = Call(DataType::Handle(), tl::get_local_base(), {});
-  PrimExpr offset_to_base =
-      Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {ptr}), local_base_ptr);
-  PrimExpr result = Call(DataType::Handle(), tl::get_remote_base_ptr(), {pe}) +
-                    offset_to_base;
-  return result;
+  PrimExpr local_base_handle = Call(DataType::Handle(), tl::get_local_base(), {});
+  PrimExpr local_base_uint =
+      Call(DataType::UInt(64), tl::get_uintptr_t(), {local_base_handle});
+  PrimExpr ptr_uint = Call(DataType::UInt(64), tl::get_uintptr_t(), {ptr});
+  PrimExpr offset_to_base = Sub(ptr_uint, local_base_uint);
+  PrimExpr remote_base_handle =
+      Call(DataType::Handle(), tl::get_remote_base_ptr(), {pe});
+  PrimExpr remote_base_uint =
+      Call(DataType::UInt(64), tl::get_uintptr_t(), {remote_base_handle});
+  PrimExpr result_uint = Add(remote_base_uint, offset_to_base);
+  return Call(DataType::Handle(), builtin::reinterpret(), {result_uint});
 }
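Stripped of IR machinery, the translation CalculateSymmPtr performs is plain integer arithmetic over a symmetric heap: the remote address is the remote PE's heap base plus the pointer's offset from the local base. A small integer model (the base addresses below are assumed values for illustration, not real runtime layout):

```python
# Integer model of symmetric-heap address translation:
#   remote_addr = remote_base(pe) + (local_ptr - local_base)
# All arithmetic happens on plain unsigned integers, which is exactly why
# the review wants the IR version done on UInt(64) rather than Handle.

LOCAL_BASE = 0x7000_0000_0000            # assumed local symmetric-heap base
REMOTE_BASES = {0: 0x7000_0000_0000,     # assumed per-PE heap bases
                1: 0x7400_0000_0000}

def symm_ptr(local_ptr: int, pe: int) -> int:
    offset = local_ptr - LOCAL_BASE      # offset within the symmetric heap
    return REMOTE_BASES[pe] + offset
```

When pe is the local rank the translation is the identity, which matches the -1-means-local convention used elsewhere in the PR.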
105-143: Add defensive checks before rebuilding the destination region.
This block still assumes the buffer has a pointer type annotation and that call_op->args[1] is a CallNode with an IntImm access mask. Any divergence in the IR crashes the compiler. Guard those assumptions and reuse the validated node when rebuilding the region.
-        String storage_scope =
-            dst->data->type_annotation.as<PointerTypeNode>()->storage_scope;
+        ICHECK(dst->data->type_annotation.defined())
+            << "Buffer data variable must have type annotation";
+        const auto* dst_ptr_type =
+            dst->data->type_annotation.as<PointerTypeNode>();
+        ICHECK(dst_ptr_type)
+            << "Buffer data type annotation must be PointerType";
+        String storage_scope = dst_ptr_type->storage_scope;
@@
-        Array<PrimExpr> dst_region_args;
+        ICHECK_GE(call_op->args.size(), 2U)
+            << "Copy call must provide dst region";
+        const auto* dst_region_call_node = call_op->args[1].as<CallNode>();
+        ICHECK(dst_region_call_node) << "Expected Call node for dst region";
+        Array<PrimExpr> dst_region_args;
         dst_region_args.push_back(dst_load);
-        dst_region_args.push_back(
-            IntImm(DataType::Int(32), call_op->args[1]
-                                          .as<CallNode>()
-                                          ->args[1]
-                                          .as<IntImmNode>()
-                                          ->value)); // access_mask
+        const auto* dst_access_mask =
+            dst_region_call_node->args[1].as<IntImmNode>();
+        ICHECK(dst_access_mask) << "Expected IntImm access mask for dst region";
+        dst_region_args.push_back(
+            IntImm(DataType::Int(32), dst_access_mask->value)); // access_mask
@@
-        Call dst_region_call =
-            Call(call_op->args[1].as<CallNode>()->dtype,
-                 call_op->args[1].as<CallNode>()->op, dst_region_args,
-                 call_op->args[1].as<CallNode>()->span);
+        Call dst_region_call =
+            Call(dst_region_call_node->dtype, dst_region_call_node->op,
+                 dst_region_args, dst_region_call_node->span);
174-211: Do the same safety checks for the source region.
The remote pull branch still dereferences unchecked pointer annotations and assumes call_op->args[0] yields a CallNode / IntImm. Add the corresponding ICHECKs before accessing those members.
-        String storage_scope =
-            src->data->type_annotation.as<PointerTypeNode>()->storage_scope;
+        ICHECK(src->data->type_annotation.defined())
+            << "Buffer data variable must have type annotation";
+        const auto* src_ptr_type =
+            src->data->type_annotation.as<PointerTypeNode>();
+        ICHECK(src_ptr_type)
+            << "Buffer data type annotation must be PointerType";
+        String storage_scope = src_ptr_type->storage_scope;
@@
-        Array<PrimExpr> src_region_args;
+        ICHECK_GE(call_op->args.size(), 2U)
+            << "Copy call must provide src region";
+        const auto* src_region_call_node = call_op->args[0].as<CallNode>();
+        ICHECK(src_region_call_node) << "Expected Call node for src region";
+        Array<PrimExpr> src_region_args;
         src_region_args.push_back(src_load);
-        src_region_args.push_back(
-            IntImm(DataType::Int(32), call_op->args[0]
-                                          .as<CallNode>()
-                                          ->args[1]
-                                          .as<IntImmNode>()
-                                          ->value)); // access_mask
+        const auto* src_access_mask =
+            src_region_call_node->args[1].as<IntImmNode>();
+        ICHECK(src_access_mask) << "Expected IntImm access mask for src region";
+        src_region_args.push_back(
+            IntImm(DataType::Int(32), src_access_mask->value)); // access_mask
@@
-        Call src_region_call =
-            Call(call_op->args[0].as<CallNode>()->dtype,
-                 call_op->args[0].as<CallNode>()->op, src_region_args,
-                 call_op->args[0].as<CallNode>()->span);
+        Call src_region_call =
+            Call(src_region_call_node->dtype, src_region_call_node->op,
+                 src_region_args, src_region_call_node->span);
🧹 Nitpick comments (4)
src/transform/lower_hopper_intrin.cc (1)
66-66: Consider removing the unnecessary blank line. This blank line serves no purpose and could be removed for consistency.
tilelang/language/__init__.py (1)
19-19: Remove unnecessary noqa directive. The # noqa: F401 comment is flagged as unused by Ruff. Since make_tensor_like is being explicitly exported in __all__ (implicitly via the module's public API), the import is not unused and the directive can be removed. Apply this diff:
-    make_tensor_like,  # noqa: F401
+    make_tensor_like,
tilelang/language/copy.py (1)
25-26: Clarify PE parameter semantics in docstring. The docstring states "Defaults to -1, which means copy from/to local" but doesn't explain what non--1 values represent. Consider adding a brief note that non--1 values specify a remote PE index.
-        src_pe (Optional[tir.PrimExpr], optional): Source PE index. Defaults to -1, which means copy from local
-        dst_pe (Optional[tir.PrimExpr], optional): Destination PE index. Defaults to -1, which means copy to local.
+        src_pe (Optional[tir.PrimExpr], optional): Source PE index. Defaults to -1 (local), otherwise specifies remote PE.
+        dst_pe (Optional[tir.PrimExpr], optional): Destination PE index. Defaults to -1 (local), otherwise specifies remote PE.
src/op/copy.cc (1)
1979-1984: Keep the tl.copy registration comment accurate. The op now exposes seven inputs, so the comment claiming eight (with is_remote_copy) is stale. Updating it avoids the next reader scratching their head. Apply this diff to synchronize the documentation:
-// - Takes 8 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
-//   eviction_policy, src_pe, dst_pe, is_remote_copy
+// - Takes 7 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
+//   eviction_policy, src_pe, dst_pe
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
examples/distributed/primitives/example_tilescale_copy.py (1 hunks)
examples/distributed/primitives/test_tilescale_copy.py (1 hunks)
src/op/copy.cc (2 hunks)
src/op/copy.h (1 hunks)
src/op/distributed.cc (1 hunks)
src/op/distributed.h (1 hunks)
src/op/remote_copy.cc (2 hunks)
src/op/sync.cc (1 hunks)
src/target/codegen_cuda.cc (2 hunks)
src/tl_templates/cuda/common.h (1 hunks)
src/tl_templates/cuda/distributed.h (1 hunks)
src/tl_templates/cuda/sync.h (1 hunks)
src/transform/declare_symm_buffer.cc (1 hunks)
src/transform/lower_hopper_intrin.cc (4 hunks)
tilelang/engine/phase.py (1 hunks)
tilelang/jit/adapter/utils.py (1 hunks)
tilelang/jit/adapter/wrapper.py (5 hunks)
tilelang/language/__init__.py (1 hunks)
tilelang/language/copy.py (2 hunks)
tilelang/language/proxy.py (1 hunks)
tilelang/transform/__init__.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- tilelang/transform/__init__.py
- tilelang/jit/adapter/utils.py
- tilelang/engine/phase.py
- tilelang/language/proxy.py
- src/tl_templates/cuda/distributed.h
🧰 Additional context used
🧬 Code graph analysis (11)
src/op/copy.h (1)
src/op/copy.cc (6)
is_remote_copy (193-195), is_remote_copy (193-193), is_remote_push (183-186), is_remote_push (183-183), is_remote_pull (188-191), is_remote_pull (188-188)
src/op/copy.cc (2)
src/op/copy.h (1)
is_remote_copy (100-124)
tilelang/language/copy.py (1)
copy (11-96)
src/op/remote_copy.cc (1)
src/tl_templates/cuda/distributed.h (1)
get_uintptr_t (53-55)
src/op/distributed.h (1)
tilelang/language/distributed/common.py (2)
get_rank (8-11), get_num_ranks (14-17)
src/transform/declare_symm_buffer.cc (2)
src/tl_templates/cuda/distributed.h (1)
get_uintptr_t (53-55)
tilelang/transform/__init__.py (1)
DeclareSymmBuffer (492-498)
examples/distributed/primitives/test_tilescale_copy.py (1)
tilelang/testing/__init__.py (1)
requires_cuda_compute_version_ge (105-106)
tilelang/language/copy.py (1)
tilelang/language/tir/op.py (1)
call_intrin (120-145)
examples/distributed/primitives/example_tilescale_copy.py (8)
tilelang/distributed/utils.py (1)
init_dist (40-62)
tilelang/jit/__init__.py (1)
jit (233-306)
tilelang/language/allocate.py (1)
alloc_shared (24-39)
tilelang/language/distributed/common.py (1)
get_rank (8-11)
tilelang/language/copy.py (1)
copy (11-96)
tilelang/layout/swizzle.py (1)
make_swizzled_layout (10-18)
tilelang/utils/allocator.py (1)
get_allocator (236-248)
tilelang/utils/tensor.py (1)
tensor (43-56)
src/transform/lower_hopper_intrin.cc (3)
tilelang/language/builtin.py (1)
create_tma_descriptor (74-83)
src/op/builtin.cc (2)
cuTensorMapType (39-39), cuTensorMapType (39-39)
tilelang/language/tir/op.py (1)
call_extern (173-195)
tilelang/jit/adapter/wrapper.py (1)
tilelang/jit/adapter/utils.py (2)
pythonic_expr (110-218), tilescale_pythonic_expr (221-352)
tilelang/language/__init__.py (1)
tilelang/language/proxy.py (1)
make_tensor_like (305-313)
🪛 Clang (14.0.6)
src/transform/declare_symm_buffer.cc
[error] 9-9: 'tvm/ffi/reflection/registry.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.4)
examples/distributed/primitives/example_tilescale_copy.py
179-179: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/__init__.py
19-19: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-nvidia
🔇 Additional comments (20)
src/transform/lower_hopper_intrin.cc (3)
117-131: LGTM! Visitor pattern correctly implemented. The LetStmt visitor correctly visits the value and body first, then checks whether the variable is TMA-related and stores the binding. The structure-preservation logic (lines 126-130) follows best practices.
187-189: LGTM! Appropriate data structures for TMA tracking. The new data members use appropriate hash and equality functors (ObjectPtrHash, ObjectPtrEqual) for Var objects, which is correct for TVM IR nodes.
42-56: No changes required: the inlining is safe by design. The inlining only replaces base pointer variables with their let-bound expressions. Base pointers, being memory addresses, are semantically constrained to be function parameters or early allocations, never depending on later-defined local variables. The code's narrow focus on only base pointers (stored in tma_related_vars_) ensures the hoisting respects scope.
Likely an incorrect or invalid review comment.
tilelang/jit/adapter/wrapper.py (5)
9-10: LGTM! Import aligns with distributed copy support. The addition of tilescale_pythonic_expr correctly enables parsing of distributed intrinsics (e.g., tl.get_remote_base, tl.get_local_base) for TileScale-aware expression translation in TMA descriptor generation.
288-289: LGTM! Method follows established pattern. The new helper method correctly wraps tilescale_pythonic_expr with type mapping, following the same pattern as _pythonic_expr above it.
466-466: LGTM! Correct use of tilescale-aware expression translator. Switching to _tilescale_pythonic_expr for globalAddress properly handles distributed intrinsics (e.g., tl.get_remote_base, tl.get_local_base) that may appear in TMA descriptor global addresses for distributed copy scenarios, directly supporting the PR's unified copy lowering objectives.
65-65: No changes needed. host_meta_data is properly declared at file scope in the generated C++ code. The declaration uint64_t* host_meta_data = nullptr; is generated in src/target/codegen_cuda.cc (line 301), and it's also extern-declared in the CUDA template headers (src/tl_templates/cuda/distributed.h, line 9). The assignment in the wrapper function will not cause compilation errors.
333-338: Pattern verification confirms no false positives; the change is sound. The symmetric buffer naming convention in src/transform/declare_symm_buffer.cc creates descriptors with the pattern {buffer_name}_symm_{numeric_id}, where the numeric ID is globally incremented to prevent collisions. The added condition (match.startswith(name + "_symm_") and match.endswith("_desc")) correctly identifies these TMA copy descriptors without risk of false positives, as the naming scheme ensures uniqueness even when buffer names have similar prefixes.
tilelang/language/copy.py (2)
11-19: LGTM! Remote copy parameters integrated correctly. The function signature correctly extends the copy interface with src_pe and dst_pe parameters for distributed copy support. The type hints appropriately allow PrimExpr, IntImm, int, or None with a default of -1 to indicate local rank.
95-96: Remote copy parameters are correctly registered and passed. Verification confirms the backend intrinsic registration expects 7 arguments. The tilelang/language/copy.py lines 95-96 pass all 7 arguments in the correct order expected by the Copy constructor in src/op/copy.cc:
- src (buffer/region)
- dst (buffer/region)
- coalesced_width
- disable_tma
- eviction_policy
- src_pe (new remote copy parameter)
- dst_pe (new remote copy parameter)
The backend registration at src/op/copy.cc confirms .set_num_inputs(7), matching the implementation. The code is correct.
src/tl_templates/cuda/common.h (1)
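The argument ordering above is the contract between the Python frontend and the C++ constructor. A sketch of it (not the real FFI call, just the ordering the review verifies) makes the fragility explicit: reordering any entry would desynchronize the two sides without a type error.

```python
# Sketch of the 7-argument order checked above for the tl.copy intrinsic.
# Positional order is the contract; swapping, say, src_pe and dst_pe would
# silently turn a remote push into a remote pull.

def build_tl_copy_args(src, dst, coalesced_width=None, disable_tma=False,
                       eviction_policy=0, src_pe=-1, dst_pe=-1):
    return [src, dst, coalesced_width, disable_tma,
            eviction_policy, src_pe, dst_pe]
```

This is why the backend's .set_num_inputs(7) check matters: it at least catches an arity mismatch, though not a reordering.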
35-38: LGTM! Host qualifier macros added correctly. The new TL_HOST* and TL_HOST_DEVICE* macros properly extend the device-only qualifiers to support host-side and dual host-device contexts, enabling distributed intrinsics to work on both host and device.
src/op/remote_copy.cc (2)
99-109: LGTM! Remote base intrinsic updated in PUT operation. The PutOpNode::Lower method now correctly uses tl::get_remote_base() for both local and remote base pointer computation, maintaining the offset calculation logic.
205-215: LGTM! Remote base intrinsic updated in GET operation. The GetOpNode::Lower method now correctly uses tl::get_remote_base() for both local and remote base pointer computation, consistent with the PUT operation changes.
examples/distributed/primitives/test_tilescale_copy.py (5)
1-7: LGTM! Test module setup is correct.The imports and module structure properly set up the test environment with required dependencies for distributed multiprocessing tests.
9-13: LGTM! SIMT test properly configured.The test correctly uses the
requires_cudadecorator and spawns 2 processes for distributed testing of the SIMT push tile kernel.
15-20: LGTM! TMA load test has correct prerequisites.The test properly stacks
requires_cudaandrequires_cuda_compute_version_ge(9, 0)decorators since TMA features require compute capability 9.0+.
22-27: LGTM! TMA store test has correct prerequisites.The test properly stacks decorators for compute capability requirements, matching the TMA load test configuration.
29-30: LGTM! Test entry point configured correctly.The main entry delegates to
tilelang.testing.main()for standard test runner integration.src/op/copy.h (1)
95-102: Verification confirmed — implementations are correct and properly integrated. The three query methods are implemented soundly:
- is_remote_push(): checks whether the destination PE is remote (dst_pe ≠ -1)
- is_remote_pull(): checks whether the source PE is remote (src_pe ≠ -1)
- is_remote_copy(): identifies any remote operation with OR logic

Type safety is ensured via IsInstance<IntImmNode>() checks before casting, and the sentinel value (-1) consistently distinguishes local from remote processing elements. The implementation properly supports distributed copy operations without impacting local copy behavior.

src/tl_templates/cuda/sync.h (1)
157-158: LGTM! Barrier pointer macro updated consistently. The BARRIER_PTR macro now uses get_remote_base(tgt_rank) instead of get_remote_base_ptr(tgt_rank), aligning with the intrinsic renaming across the codebase. Verification confirms get_remote_base is properly defined in src/tl_templates/cuda/distributed.h and the change is consistent throughout the CUDA templates.
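The -1 sentinel semantics described in the copy.h comment above can be modelled in a few lines of Python (a hedged sketch — the real checks live in C++ on CopyNode, and the helper names here are illustrative stand-ins):

```python
# Model of the -1 sentinel: a PE index of -1 means "local rank".
LOCAL_PE = -1

def is_remote_push(dst_pe: int) -> bool:
    # Remote push: the destination PE is some rank other than the local sentinel.
    return dst_pe != LOCAL_PE

def is_remote_pull(src_pe: int) -> bool:
    # Remote pull: the source PE is some rank other than the local sentinel.
    return src_pe != LOCAL_PE

def is_remote_copy(src_pe: int, dst_pe: int) -> bool:
    # Any copy with at least one remote endpoint is a remote copy.
    return is_remote_push(dst_pe) or is_remote_pull(src_pe)

# A purely local copy leaves both PEs at the sentinel.
assert not is_remote_copy(LOCAL_PE, LOCAL_PE)
# Pushing to PE 2 is remote; the source stays local.
assert is_remote_copy(LOCAL_PE, 2) and is_remote_push(2)
```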
// Parse remote copy params
if (args.size() >= 6) {
  node->src_pe = args[5];
}
if (args.size() >= 7) {
  node->dst_pe = args[6];
}

ICHECK(!(node->is_remote_push() && node->is_remote_pull()))
    << "At least one of src_pe or dst_pe must be local rank";

if (node->is_remote_push()) {
  ICHECK(node->dst.scope() == "global")
      << "Can only copy to peer's global memory, but got "
      << node->dst.scope();
} else if (node->is_remote_pull()) {
  ICHECK(node->src.scope() == "global")
      << "Can only pull from peer's global memory, but got "
      << node->src.scope();
}

data_ = std::move(node);
}

bool CopyNode::is_remote_push() const {
  return !(dst_pe->IsInstance<IntImmNode>() &&
           dst_pe.as<IntImmNode>()->value == -1);
}

bool CopyNode::is_remote_pull() const {
  return !(src_pe->IsInstance<IntImmNode>() &&
           src_pe.as<IntImmNode>()->value == -1);
}
bool CopyNode::is_remote_copy() const {
  return is_remote_push() || is_remote_pull();
}
Guard remote PE defaults before remote checks.
is_remote_push() / is_remote_pull() run immediately after the optional assignments. Legacy modules compiled before this PR still emit the 5-argument tl.copy; when we deserialize them here, args.size() is only 5, so dst_pe/src_pe stay PrimExpr() and the subsequent dst_pe->IsInstance dereferences a null handle. That’s a hard crash/regression for any pre-existing artifact. Please seed both fields with -1 before the if blocks and make the helpers resilient to an undefined PrimExpr.
Apply this diff to fix the issue:
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<CopyNode> node = make_object<CopyNode>();
+ node->src_pe = Integer(-1);
+ node->dst_pe = Integer(-1);
Array<Range> rgs[2];
@@
- if (args.size() >= 6) {
- node->src_pe = args[5];
- }
- if (args.size() >= 7) {
- node->dst_pe = args[6];
- }
+ if (args.size() >= 6) {
+ node->src_pe = args[5];
+ }
+ if (args.size() >= 7) {
+ node->dst_pe = args[6];
+ }
@@
bool CopyNode::is_remote_push() const {
- return !(dst_pe->IsInstance<IntImmNode>() &&
- dst_pe.as<IntImmNode>()->value == -1);
+ if (!dst_pe.defined()) {
+ return false;
+ }
+ if (const auto *imm = dst_pe.as<IntImmNode>()) {
+ return imm->value != -1;
+ }
+ return true;
}
@@
bool CopyNode::is_remote_pull() const {
- return !(src_pe->IsInstance<IntImmNode>() &&
- src_pe.as<IntImmNode>()->value == -1);
+ if (!src_pe.defined()) {
+ return false;
+ }
+ if (const auto *imm = src_pe.as<IntImmNode>()) {
+ return imm->value != -1;
+ }
+ return true;
}

Committable suggestion skipped: line range outside the PR's diff.
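The null-handle hazard and the proposed defined() guard can be illustrated with a small Python model, where None stands in for the undefined PrimExpr left by a legacy 5-argument tl.copy (the helper is hypothetical, for illustration only):

```python
def is_remote_push(dst_pe) -> bool:
    # Undefined (legacy 5-argument tl.copy): treat as local instead of crashing.
    if dst_pe is None:
        return False
    # IntImm analogue: the -1 sentinel marks the local rank.
    if isinstance(dst_pe, int):
        return dst_pe != -1
    # Any other (dynamic) expression is conservatively treated as remote.
    return True

assert is_remote_push(None) is False   # legacy module deserializes safely
assert is_remote_push(-1) is False     # explicit local sentinel
assert is_remote_push(1) is True       # remote PE
```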
PrimExpr local_base_ptr =
-    Call(DataType::Handle(), tl::get_remote_base_ptr(), {rank});
+    Call(DataType::Handle(), tl::get_remote_base(), {rank});
PrimExpr offset_to_base =
    Sub(Call(DataType::Handle(), tl::get_uintptr_t(), {bar_addr}),
        local_base_ptr);
🧩 Analysis chain
Verify get_remote_base intrinsic compatibility.
The barrier code now calls tl::get_remote_base() instead of tl::get_remote_base_ptr(). Ensure the new intrinsic returns the same DataType::Handle() type and accepts a rank parameter.
Run the following script to verify the intrinsic definition:
🏁 Script executed:
#!/bin/bash
# Find the definition of get_remote_base intrinsic
ast-grep --pattern $'const Op& get_remote_base() {
$$$
}'
# Also search for its registration
rg -n 'get_remote_base' --type=cpp -C3

Length of output: 6899
I need to gather more information about the IR-level type specifications for these intrinsics to determine if the type mismatch is intentional or a bug.
🏁 Script executed:
#!/bin/bash
# Search for Op definitions and their return type specifications
rg -n 'get_local_base\|get_remote_base' src/op/distributed.cc -A10 -B2
# Check if DataType::Handle() is commonly used with uint64_t returns
rg -n 'DataType::Handle().*get_remote' src/ -C2

Length of output: 2191
🌐 Web query:
TVM TIR DataType Handle uint64_t implicit conversion
💡 Result:
Short answer: TVM does not implicitly convert a uint64_t integer into a TIR DataType::Handle. DataType::Handle is a pointer/opaque "handle" type (typically 64-bit) separate from integer types, and the C API's TVMValue union does not include a uint64 field (it has int64, double, and void*), so uint64 values are not represented natively in TVMValue and won't be auto-treated as a handle — you must explicitly construct a handle (pointer) or use an integer DataType (uint64) where appropriate. [1][2]
References:
- TVM runtime DataType documentation (Handle / constructors). [1]
- Discussion on TVMValue not supporting uint64 (explains TVMValue fields and lack of uint64). [2]
🏁 Script executed:
#!/bin/bash
# Search for complete Op registrations with type info
rg -n 'TIR_DEFINE_TL_BUILTIN.*get_' src/op/distributed.cc -A15 | head -80
# Also check how get_uintptr_t and get_local_base are registered
rg -n 'get_uintptr_t\|get_local_base.*TIR_DEFINE' src/op/distributed.cc -B2 -A10

Length of output: 1388
🏁 Script executed:
#!/bin/bash
# Check the actual usage context - see if there are type conversions nearby
rg -n 'get_remote_base\(\)' src/op/sync.cc src/op/remote_copy.cc -B5 -A10
# Also check if there's a Cast op being used
rg -n 'Cast\|reinterpret_cast' src/op/sync.cc src/op/remote_copy.cc | head -30

Length of output: 3715
Critical type mismatch: get_remote_base() returns uint64_t, not Handle.
The code changes from tl::get_remote_base_ptr() (returns void*) to tl::get_remote_base() (returns uint64_t), but incorrectly asserts DataType::Handle() at the call sites. Per TVM's type system, uint64_t does not implicitly convert to Handle.
Fix required at 5 locations:
- src/op/sync.cc line 97
- src/op/remote_copy.cc lines 103, 108, 209, 214
Either revert to get_remote_base_ptr() or use DataType::UInt(64) and add explicit Cast to Handle if a pointer is semantically required.
🤖 Prompt for AI Agents
In src/op/sync.cc around lines 96-100 the call to tl::get_remote_base() is
incorrectly typed as DataType::Handle; get_remote_base() returns uint64_t so
replace the Handle typing with DataType::UInt(64) (or cast the uint64_t to
Handle explicitly if you truly need a pointer) and wrap with an explicit Cast
where needed; apply the same fix at src/op/remote_copy.cc lines 103, 108, 209
and 214: change the Call type to DataType::UInt(64) and if subsequent code
expects a pointer, add a Cast(expr, DataType::Handle()) after the Call rather
than leaving the Call typed as Handle.
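The cast discipline the prompt recommends — carry the raw base as a 64-bit integer and only reinterpret it as a pointer at the use site — can be sketched with ctypes (an illustration of the pattern, not the TVM IR itself):

```python
import ctypes

buf = ctypes.create_string_buffer(b"tilelang")
# The raw 64-bit address plays the role of the UInt(64)-typed Call result.
raw_base = ctypes.addressof(buf)
assert isinstance(raw_base, int)
# An explicit cast recovers a usable pointer where one is semantically required.
ptr = ctypes.cast(raw_base, ctypes.c_char_p)
assert ptr.value == b"tilelang"
```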
if (call->op.same_as(create_tma_descriptor())) {
  Var var;
  auto iter = desc_map_.find(GetRef<Call>(call));
  if (iter != desc_map_.end()) {
    var = iter->second;
  } else {
    String name = call->args[2].as<Var>().value()->name_hint;
    var = Var(name + "_desc",
              PointerType(PrimType(cuTensorMapType()), "grid_constant"));
    desc_map_[GetRef<Call>(call)] = var;
    prefetch_calls_.push_back(
        Evaluate(Call(DataType::Handle(), builtin::call_extern(),
                      {StringImm("tl::prefetch_tma_descriptor"), var})));
    // Mark the base pointer variable as TMA-related
    if (auto base_var = call->args[2].as<Var>()) {
      tma_related_vars_.insert(base_var.value());
    }
  }
  return var;
Fix unsafe variable access at line 140.
Line 140 accesses call->args[2].as<Var>().value()->name_hint without checking if args[2] is actually a Var, which will crash if it's not. This is inconsistent with the safe check at lines 148-150.
Apply this diff to add a safety check:
} else {
- String name = call->args[2].as<Var>().value()->name_hint;
- var = Var(name + "_desc",
- PointerType(PrimType(cuTensorMapType()), "grid_constant"));
+ auto base_var_opt = call->args[2].as<Var>();
+ CHECK(base_var_opt) << "Expected args[2] to be a Var for create_tma_descriptor";
+ String name = base_var_opt.value()->name_hint;
+ var = Var(name + "_desc",
+ PointerType(PrimType(cuTensorMapType()), "grid_constant"));
desc_map_[GetRef<Call>(call)] = var;

🤖 Prompt for AI Agents
In src/transform/lower_hopper_intrin.cc around lines 134 to 152, the code
unsafely dereferences call->args[2] as a Var when building the descriptor name;
first check whether call->args[2].as<Var>() is valid and only access name_hint
if it is, otherwise use a safe fallback name (e.g., "tma_desc") or derive a
unique name differently; reuse the same optional base_var for both name
construction and for inserting into tma_related_vars_ so the insertion is only
performed when base_var is valid; update the control flow so desc_map_ and
prefetch_calls_ still get created in the else branch but without unsafe
dereference.
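The checked-access pattern the prompt describes — derive the name only when the argument really carries one, otherwise fall back — looks like this in a hedged Python sketch (the helper name and the "tma_desc" fallback are illustrative):

```python
def descriptor_name(arg) -> str:
    # Only read name_hint when the argument actually provides it, mirroring
    # a checked .as<Var>() instead of an unconditional .value() dereference.
    name = getattr(arg, "name_hint", None)
    return f"{name}_desc" if name is not None else "tma_desc"

class FakeVar:
    name_hint = "A"

assert descriptor_name(FakeVar()) == "A_desc"   # Var-like argument
assert descriptor_name(object()) == "tma_desc"  # safe fallback, no crash
```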
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/tl_templates/cuda/threadblock_swizzle.h (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tl_templates/cuda/threadblock_swizzle.h (1)
src/tl_templates/hip/threadblock_swizzle.h (1)
rasterization2DRow(7-8)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: format-check
🔇 Additional comments (1)
src/tl_templates/cuda/threadblock_swizzle.h (1)
26-27: LGTM! Consistent default parameter addition. The default argument for offset matches the rasterization2DRow function, providing a consistent API across both rasterization functions.
namespace tl {

-template <int panel_width, int offset> TL_DEVICE dim3 rasterization2DRow() {
+template <int panel_width, int offset = 0> TL_DEVICE dim3 rasterization2DRow() {
🧩 Analysis chain
LGTM! Improved API usability with default parameter.
The default argument makes the offset parameter optional, improving usability for common cases where no offset is needed. The change is backward compatible.
Verify consistency with the HIP backend version to ensure both implementations have aligned signatures:
🏁 Script executed:
#!/bin/bash
# Description: Compare template signatures between CUDA and HIP versions
echo "=== CUDA version ==="
rg -n "template.*rasterization2D" src/tl_templates/cuda/threadblock_swizzle.h -A 1
echo -e "\n=== HIP version ==="
rg -n "template.*rasterization2D" src/tl_templates/hip/threadblock_swizzle.h -A 1

Length of output: 470
🏁 Script executed:
#!/bin/bash
# Get complete function signatures for both functions in CUDA and HIP
echo "=== CUDA: rasterization2DRow (full signature) ==="
sed -n '7,9p' src/tl_templates/cuda/threadblock_swizzle.h
echo -e "\n=== CUDA: rasterization2DColumn (full signature) ==="
sed -n '26,28p' src/tl_templates/cuda/threadblock_swizzle.h
echo -e "\n=== HIP: rasterization2DRow (full signature) ==="
sed -n '7,9p' src/tl_templates/hip/threadblock_swizzle.h
echo -e "\n=== HIP: rasterization2DColumn (full signature) ==="
sed -n '26,28p' src/tl_templates/hip/threadblock_swizzle.h

Length of output: 968
🏁 Script executed:
#!/bin/bash
# Check if offset parameter is used in HIP implementations
echo "=== HIP: Check for 'offset' usage in rasterization functions ==="
sed -n '7,45p' src/tl_templates/hip/threadblock_swizzle.h | grep -n "offset"
echo -e "\n=== Search for calls to rasterization2DRow/2DColumn with template arguments ==="
rg -n "rasterization2D(Row|Column)\s*<" --type=cc --type=h

Length of output: 745
🏁 Script executed:
#!/bin/bash
# View full CUDA implementations to see offset usage
echo "=== CUDA rasterization2DRow implementation ==="
sed -n '7,24p' src/tl_templates/cuda/threadblock_swizzle.h
echo -e "\n=== CUDA rasterization2DColumn implementation ==="
sed -n '26,45p' src/tl_templates/cuda/threadblock_swizzle.h
echo -e "\n=== Search for function calls to these templates ==="
rg -n "rasterization2D(Row|Column)" --type=cpp --type=hpp --type=h -C 2

Length of output: 2071
🏁 Script executed:
#!/bin/bash
# Broader search for function calls without file type restrictions
echo "=== Search for rasterization2DRow/Column calls ==="
rg -n "rasterization2D" -g "*.cu" -g "*.cuh" -g "*.h" -C 1
echo -e "\n=== Search with simpler pattern ==="
rg -n "rasterization2D(Row|Column)" -C 1

Length of output: 3189
🏁 Script executed:
#!/bin/bash
# Check git history for recent changes to HIP version
echo "=== Recent commits touching HIP version ==="
git log --oneline -n 20 -- src/tl_templates/hip/threadblock_swizzle.h
echo -e "\n=== Recent commits touching CUDA version ==="
git log --oneline -n 20 -- src/tl_templates/cuda/threadblock_swizzle.h
echo -e "\n=== Check if HIP version was modified in this PR ==="
git diff HEAD~1 -- src/tl_templates/hip/threadblock_swizzle.h

Length of output: 717
🏁 Script executed:
#!/bin/bash
# View the commit that introduced offset to understand original intent
echo "=== Commit 4b38278 details ==="
git show 4b38278 --stat
echo -e "\n=== Check if that commit touched HIP file ==="
git show 4b38278 -- src/tl_templates/hip/threadblock_swizzle.h

Length of output: 710
Backend signature mismatch: HIP implementation not updated.
CUDA now has template <int panel_width, int offset = 0> while HIP remains template <int panel_width>, creating incompatible signatures between backends. The offset parameter introduced in commit 4b38278 was added only to CUDA and not propagated to HIP.
Either update HIP implementations to include the offset parameter with the same default value, or reconsider adding it to CUDA if HIP-specific logic prevents it.
🤖 Prompt for AI Agents
In src/tl_templates/cuda/threadblock_swizzle.h around line 7 the CUDA template
signature was changed to "template <int panel_width, int offset = 0>" while HIP
implementations still use "template <int panel_width>", creating a backend
signature mismatch; update the HIP header(s)/implementations to add the same
"int offset = 0" template parameter to the corresponding template
declarations/definitions so signatures match (ensure default value is 0 and
propagate the new parameter into any HIP-specific uses), or if HIP cannot accept
the parameter, revert the CUDA change so both backends share the same signature.
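The backward-compatibility property at stake — old one-parameter call sites keep working while new ones may pass an offset — is easy to demonstrate with a Python analogue of the C++ default template argument (the index math below is a stand-in, not the real swizzle arithmetic):

```python
def rasterization2d_row(panel_width: int, offset: int = 0) -> int:
    # Stand-in for the real rasterization index computation.
    return panel_width + offset

assert rasterization2d_row(8) == rasterization2d_row(8, 0)  # legacy call site
assert rasterization2d_row(8, 2) == 10                      # new call site
```

Keeping the same default on both backends preserves this property for CUDA and HIP alike.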
TODO
Summary by CodeRabbit
New Features
Tests