From 56309753c11edfbd1f3f56346d31c8778f28d15a Mon Sep 17 00:00:00 2001
From: Rachmanino <18805904201@163.com>
Date: Wed, 12 Nov 2025 03:37:42 +0000
Subject: [PATCH 1/3] init internode folder, add doc and refactor launch script

---
 benchmark/distributed/ipc_impls/README.md | 2 +-
 .../ipc_impls/benchmark_nvshmem_p2p.py | 2 +-
 examples/distributed/internode/README.md | 21 ++++++++++++++++++++
 .../internode/example_allgather_internode.py | 8 ++++++++
 .../example_overlapping_allgather.py | 4 ++--
 tilelang/distributed/launch.sh | 14 ++++++-------
 6 files changed, 40 insertions(+), 11 deletions(-)
 create mode 100644 examples/distributed/internode/README.md
 create mode 100644 examples/distributed/internode/example_allgather_internode.py
 rename examples/distributed/{ => internode}/example_overlapping_allgather.py (95%)

diff --git a/benchmark/distributed/ipc_impls/README.md b/benchmark/distributed/ipc_impls/README.md
index d89d00956..8a3c8c7cd 100644
--- a/benchmark/distributed/ipc_impls/README.md
+++ b/benchmark/distributed/ipc_impls/README.md
@@ -5,7 +5,7 @@ We launch only one block on each rank to avoid NVLink bandwidth as the bottlenec
 ## NVSHMEM-based push/pull
 
 ```bash
-GPUS=2 bash tilelang/distributed/launch.sh benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
+NPROC_PER_NODE=2 bash tilelang/distributed/launch.sh benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
 ```
 
 ## Unrolled-copy implemented in TileScale (*ours*)
diff --git a/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py b/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
index 5ab6265ae..1585558a4 100644
--- a/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
+++ b/benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
@@ -1,7 +1,7 @@
 # This benchmark aims to measure the bandwidth of NVHSMEM-based communication.
 # We launch only one block on each rank to avoid NVLink bandwidth as the bottleneck.
 
-# Usage: GPUS=2 bash tilelang/distributed/launch.sh benchmark/distributed/benchmark_nvshmem_p2p.py
+# Usage: NPROC_PER_NODE=2 bash tilelang/distributed/launch.sh benchmark/distributed/benchmark_nvshmem_p2p.py
 
 import os
 import tilelang
diff --git a/examples/distributed/internode/README.md b/examples/distributed/internode/README.md
new file mode 100644
index 000000000..5c825c6bf
--- /dev/null
+++ b/examples/distributed/internode/README.md
@@ -0,0 +1,21 @@
+# Inter-node Examples
+
+Examples in this folder demonstrate the inter-node communication capabilities of TileScale.
+
+- The intra-node examples can use either NVSHMEM APIs or the native communication primitives (e.g. `T.put/get_block`, `T.copy`) provided by TileScale.
+- Inter-node RDMA communication, however, currently relies on NVSHMEM's IBRC/IBGDA implementation, so NVSHMEM and pynvshmem must be installed.
+  - For a detailed installation guide, refer to [this](../../../docs/get_started/Installation.md#to-use-nvshmem-apis)
+
+## Example Usage
+
+To run an inter-node distributed program, run the launch script simultaneously on every participating node.
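+The launch script reads its configuration from environment variables: `NODES` (number of nodes), `NODE_RANK` (rank of the current node), `MASTER_IP` and `MASTER_PORT` (rendezvous endpoint; a random port is chosen when `MASTER_PORT` is unset), and `NPROC_PER_NODE` (number of GPUs per node, auto-detected via `nvidia-smi` by default).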
+
+Example:
+```bash
+# master 0
+NODES=2 NODE_RANK=0 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_overlapping_allgather.py
+# worker 1
+NODES=2 NODE_RANK=1 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_overlapping_allgather.py
+```
+
diff --git a/examples/distributed/internode/example_allgather_internode.py b/examples/distributed/internode/example_allgather_internode.py
new file mode 100644
index 000000000..02bd7f7a1
--- /dev/null
+++ b/examples/distributed/internode/example_allgather_internode.py
@@ -0,0 +1,8 @@
+import torch
+import torch.distributed as dist
+import pynvshmem
+import tilelang
+import tilelang.language as T
+from tilelang.distributed import init_distributed, dtype_map
+import argparse
+
diff --git a/examples/distributed/example_overlapping_allgather.py b/examples/distributed/internode/example_overlapping_allgather.py
similarity index 95%
rename from examples/distributed/example_overlapping_allgather.py
rename to examples/distributed/internode/example_overlapping_allgather.py
index 13c3e6dac..e4756b78a 100644
--- a/examples/distributed/example_overlapping_allgather.py
+++ b/examples/distributed/internode/example_overlapping_allgather.py
@@ -14,8 +14,8 @@
     from cuda.bindings import runtime as cudart
 else:
     from cuda import cudart
-# NODES=2 NODE_RANK=0 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
-# NODES=2 NODE_RANK=1 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
+# NODES=2 NODE_RANK=0 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
+# NODES=2 NODE_RANK=1 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
 
 
 def internode_gather(M, local_world_size, block_M, threads):
diff --git a/tilelang/distributed/launch.sh b/tilelang/distributed/launch.sh
index 024b777b2..2c75d42c7 100755
--- a/tilelang/distributed/launch.sh
+++ b/tilelang/distributed/launch.sh
@@ -16,17 +16,17 @@ export NCCL_DEBUG=${NCCL_DEBUG:="WARN"} # set env var. `NCCL_DEBUG` to expected
 #                                         Choices: [VERSION, WARN(default), INFO, TRACE],
 
 # set launch configurations
-nproc_per_node=${GPUS:=$(nvidia-smi --list-gpus | wc -l)} # set env var. `GPUS` to # of GPUs per node
-nnodes=${NODES:=1} # set env var. `NODES` to # of nodes
+nproc_per_node=${NPROC_PER_NODE:=$(nvidia-smi --list-gpus | wc -l)} # set env var. `NPROC_PER_NODE` to # of GPUs per node
+nnodes=${NNODES:=1} # set env var. `NNODES` to # of nodes
 node_rank=${NODE_RANK:=0} # set env var. `NODE_RANK` to the rank of current node
-master_addr=${ARNOLD_WORKER_0_HOST:="127.0.0.1"}
-if [ -z ${ARNOLD_WORKER_0_PORT} ]; then
-    master_port="8361"
+master_ip=${MASTER_IP:="127.0.0.1"}
+if [ -z "${MASTER_PORT}" ]; then
+    master_port="$(( RANDOM % 1000 + 20000))" # random port in [20000, 21000)
 else
-    master_port=$(echo "$ARNOLD_WORKER_0_PORT" | cut -d "," -f 1)
+    master_port=$(echo "$MASTER_PORT" | cut -d "," -f 1)
 fi
-additional_args="--rdzv_endpoint=${master_addr}:${master_port}"
+additional_args="--rdzv_endpoint=${master_ip}:${master_port}"
 
 IB_HCA=mlx5

From b59c9acae157a8241141fa0aa708da1f00ea8e98 Mon Sep 17 00:00:00 2001
From: Rachmanino <18805904201@163.com>
Date: Wed, 12 Nov 2025 12:08:32 +0000
Subject: [PATCH 2/3] add draft internode overlapped allgather

---
 benchmark/distributed/benchmark_ag_gemm.py | 2 +-
 .../example_overlapping_allgather.py | 0
 examples/distributed/internode/README.md | 4 +-
 .../internode/example_allgather_internode.py | 210 +++++++++++++++++-
 src/target/codegen_cuda.cc | 18 ++
 tilelang/distributed/launch.sh | 2 +-
 tilelang/distributed/utils.py | 20 +-
 .../distributed/multi_device/nvshmem.py | 16 +-
 8 files changed, 260 insertions(+), 12 deletions(-)
 rename examples/distributed/{internode => }/example_overlapping_allgather.py (100%)

diff --git a/benchmark/distributed/benchmark_ag_gemm.py b/benchmark/distributed/benchmark_ag_gemm.py
index a4b0bd785..505b8479f 100644
--- a/benchmark/distributed/benchmark_ag_gemm.py
+++ b/benchmark/distributed/benchmark_ag_gemm.py
@@ -18,7 +18,7 @@
 import tilelang
 import tilelang.language as T
 from tilelang.carver.arch import driver
-from tilelang.distributed import init_distributed, dtype_map, perf_fn
+from tilelang.distributed import init_distributed, dtype_map, perf_fn, wait_eq
 from triton_dist.kernels.nvidia.allgather_gemm import ag_gemm, create_ag_gemm_context
 from functools import partial
 
diff --git a/examples/distributed/internode/example_overlapping_allgather.py b/examples/distributed/example_overlapping_allgather.py
similarity index 100%
rename from examples/distributed/internode/example_overlapping_allgather.py
rename to examples/distributed/example_overlapping_allgather.py
diff --git a/examples/distributed/internode/README.md b/examples/distributed/internode/README.md
index 5c825c6bf..86847e5fd 100644
--- a/examples/distributed/internode/README.md
+++ b/examples/distributed/internode/README.md
@@ -14,8 +14,8 @@
 Example:
 ```bash
 # master 0
-NODES=2 NODE_RANK=0 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_overlapping_allgather.py
+NODES=2 NODE_RANK=0 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_allgather_internode.py
 # worker 1
-NODES=2 NODE_RANK=1 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_overlapping_allgather.py
+NODES=2 NODE_RANK=1 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_allgather_internode.py
 ```
 
diff --git a/examples/distributed/internode/example_allgather_internode.py b/examples/distributed/internode/example_allgather_internode.py
index 02bd7f7a1..ec70dd499 100644
--- a/examples/distributed/internode/example_allgather_internode.py
+++ b/examples/distributed/internode/example_allgather_internode.py
@@ -1,8 +1,210 @@
-import torch
-import torch.distributed as dist
-import pynvshmem
+# NODES=2 NODE_RANK=0 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_allgather_internode.py
+# NODES=2 NODE_RANK=1 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_allgather_internode.py
+
+# TODO: add a benchmark and a wait_eq implementation for u64, plus stricter tests
+
+import os
 import tilelang
 import tilelang.language as T
-from tilelang.distributed import init_distributed, dtype_map
 import argparse
+import torch
+import torch.distributed as dist
+from tilelang.distributed import init_distributed, dtype_map, perf_fn, wait_eq
+import pynvshmem
+from dataclasses import dataclass, field
+
+from cuda import cudart, cuda
+
+
+os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log
+
+@dataclass
+class AllGatherInternodeContext:
+    # tensor info
+    M_per_rank: int
+    M: int = field(init=False)
+    N: int
+    dtype: str
+    torch_dtype: torch.dtype = field(init=False)
+
+    # rank info
+    rank: int
+    num_local_ranks: int
+    num_ranks: int
+    local_rank: int = field(init=False)
+    num_nodes: int = field(init=False)
+    node_rank: int = field(init=False)
+
+    # workspace
+    barriers: list[torch.Tensor] = field(init=False)
+    barrier: torch.Tensor = field(init=False)
+    internode_comm_bufs: list[torch.Tensor] = field(init=False)
+    # internode_comm_buf: torch.Tensor = field(init=False)
+
+    # streams
+    internode_stream: torch.cuda.Stream = field(init=False)
+    intranode_stream: torch.cuda.Stream = field(init=False)
+
+    def __post_init__(self):
+        self.M = self.M_per_rank * self.num_ranks
+        self.local_rank = self.rank % self.num_local_ranks
+        self.num_nodes = self.num_ranks // self.num_local_ranks
+        self.node_rank = self.rank // self.num_local_ranks
+        self.torch_dtype = dtype_map[self.dtype]
+
+        self.create_workspace()
+
+        self.internode_stream = torch.cuda.Stream()
+        self.intranode_stream = torch.cuda.Stream()
+
+        pynvshmem.nvshmem_barrier_all()
+        torch.cuda.synchronize()
+
+    def create_workspace(self):
+        self.barriers = pynvshmem.nvshmem_create_tensor_list_intra_node([self.num_nodes,], torch.uint64)
+        self.barrier = self.barriers[self.local_rank]
+        self.barrier.fill_(0)
+
+
+@tilelang.jit
+def put_internode_kernel(
+    num_nodes: int,
+    num_local_ranks: int,
+    M_per_rank: int,
+    M: int,
+    N: int,
+    dtype: str,
+    threads: int = 256
+):
+
+    @T.prim_func
+    def main(
+        dst: T.Tensor([M, N], dtype), # type: ignore
+        barrier: T.Tensor([num_nodes], "uint64"), # type: ignore
+    ):
+        with T.Kernel(num_nodes-1, threads=threads) as (bx):
+            rank = T.get_pe()
+            node_rank = rank // num_local_ranks
+            peer = (rank+ (bx+1) * num_local_ranks) % (num_nodes * num_local_ranks)
+            T.putmem_signal_nbi_block(
+                T.address_of(dst[rank * M_per_rank, 0]),
+                T.address_of(dst[rank * M_per_rank, 0]),
+                M_per_rank * N * dtype_map[dtype].itemsize,
+                T.address_of(barrier[node_rank]),
+                1,
+                T.Amo.SIGNAL_SET,
+                peer
+            )
+    return main
+
+
+def tl_allgather_internode(
+    src: torch.Tensor,
+    dst: list[torch.Tensor],
+    ctx: AllGatherInternodeContext,
+    debug: bool = False,
+):
+    # 0. local copy and barrier
+    cudart.cudaMemcpy(
+        dst[ctx.local_rank][ctx.rank * ctx.M_per_rank, 0].data_ptr(),
+        src.data_ptr(),
+        ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
+        cudart.cudaMemcpyKind.cudaMemcpyDefault
+    )
+    pynvshmem.nvshmem_barrier_all()
+    dist.barrier()
+    torch.cuda.synchronize()
+    
+    # 1. perform inter-node comm
+    # push to all peers with same local rank and signal on barrier
+    with torch.cuda.stream(ctx.internode_stream):
+        kernel = put_internode_kernel(ctx.num_nodes, ctx.num_local_ranks, ctx.M_per_rank, ctx.M, ctx.N, ctx.dtype)
+        if debug and ctx.rank == 0:
+            print(kernel.get_kernel_source())
+        kernel(dst[ctx.local_rank], ctx.barrier)
+
+    with torch.cuda.stream(ctx.intranode_stream):
+        # 2. perform intra-node cp-engine based gather to overlap with inter-node comm
+        for i in range(ctx.num_local_ranks-1):
+            tgt_local_rank = (ctx.local_rank + i + 1) % ctx.num_local_ranks
+            tgt_rank = tgt_local_rank + ctx.node_rank * ctx.num_local_ranks
+            cudart.cudaMemcpyAsync(
+                dst[ctx.local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
+                dst[tgt_local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
+                ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
+                cudart.cudaMemcpyKind.cudaMemcpyDefault,
+                ctx.intranode_stream.cuda_stream
+            )
+
+        # 3. wait for data from other nodes sent to intra-node peers and gather
+        for i in range(ctx.num_nodes-1):
+            tgt_node_rank = (ctx.node_rank + i + 1) % ctx.num_nodes
+            for tgt_local_rank in range (ctx.num_local_ranks):
+                tgt_rank = tgt_local_rank + tgt_node_rank * ctx.num_local_ranks
+                cuda.cuStreamWaitValue64(
+                    ctx.intranode_stream.cuda_stream,
+                    ctx.barriers[tgt_local_rank][tgt_node_rank].data_ptr(),
+                    1,
+                    cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
+                )
+                cudart.cudaMemcpyAsync(
+                    dst[ctx.local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
+                    dst[tgt_local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
+                    ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
+                    cudart.cudaMemcpyKind.cudaMemcpyDefault,
+                    ctx.intranode_stream.cuda_stream
+                )
+
+    ctx.intranode_stream.wait_stream(ctx.internode_stream)
+
+
+def main(M_per_rank: int, N: int, dtype: str, debug: bool = False):
+    WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
+        return_tp_group=True, return_lc_group=True)
+    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE'))
+    assert WORLD_SIZE % local_world_size == 0
+    nodes: int = WORLD_SIZE // local_world_size
+    assert nodes >= 2, "This example is for inter-node allgather"
+    node_rank = RANK // local_world_size
+    
+    # gather WORLD_SIZE*[M_per_rank, N]->[M, N]
+    if debug:
+        dtype = 'int32'
+        torch_dtype = torch.int32
+        src = torch.full([M_per_rank, N], RANK, dtype=torch.int32, device='cuda')
+        dst = pynvshmem.nvshmem_create_tensor_list_intra_node([M_per_rank * WORLD_SIZE, N], torch.int32)
+        dst[LOCAL_RANK].fill_(-1)
+    else:
+        torch_dtype = dtype_map[dtype]
+        src = torch.randn([M_per_rank, N], dtype=torch_dtype, device='cuda')
+        dst = pynvshmem.nvshmem_create_tensor_list_intra_node([M_per_rank * WORLD_SIZE, N], torch_dtype)
+    ctx = AllGatherInternodeContext(M_per_rank, N, dtype, RANK, local_world_size, WORLD_SIZE)
+
+    pynvshmem.nvshmem_barrier_all()
+    dist.barrier(TP_GROUP)
+    tl_allgather_internode(src, dst, ctx, debug)
+    pynvshmem.nvshmem_barrier_all()
+    dist.barrier(TP_GROUP)
+    
+    if debug:
+        print(dst[LOCAL_RANK])
+
+    # torch ref
+    ref_dst = torch.empty_like(dst[LOCAL_RANK])
+    dist.barrier(TP_GROUP)
+    dist.all_gather_into_tensor(ref_dst, src, TP_GROUP)
+    dist.barrier(TP_GROUP)
+    assert torch.allclose(dst[LOCAL_RANK], ref_dst)
+    print(f'[node={node_rank}, local_rank={LOCAL_RANK}] All checks passed.✅')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--M_per_rank', type=int, default=1024, help='Number of rows of the local tensor')
+    parser.add_argument('--N', type=int, default=1024, help='Number of columns of the local tensor')
+    parser.add_argument('--dtype', type=str, default='float32', help='Data type')
+    parser.add_argument('-debug', action='store_true', default=False, help='Enable debug mode')
+    args = parser.parse_args()
+    main(args.M_per_rank, args.N, args.dtype, args.debug)
+    
\ No newline at end of file
diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc
index e93b6fc4e..7e6214336 100644
--- a/src/target/codegen_cuda.cc
+++ b/src/target/codegen_cuda.cc
@@ -2021,6 +2021,24 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     os << ", ";
     this->PrintExpr(op->args[3], os);
     os << ")";
+  } else if (op->op.same_as(tl::PutmemSignalBlock())) {
+    this->use_distributed_ = true;
+    this->use_nvshmem_ = true;
+    os << "nvshmemx_putmem_signal_block(";
+    this->PrintExpr(op->args[0], os);
+    os << ", ";
+    this->PrintExpr(op->args[1], os);
+    os << ", ";
+    this->PrintExpr(op->args[2], os);
+    os << ", ";
+    this->PrintExpr(op->args[3], os);
+    os << ", ";
+    this->PrintExpr(op->args[4], os);
+    os << ", ";
+    this->PrintExpr(op->args[5], os);
+    os << ", ";
+    this->PrintExpr(op->args[6], os);
+    os << ")";
   } else if (op->op.same_as(tl::PutmemNbiBlock())) {
     this->use_distributed_ = true;
     this->use_nvshmem_ = true;
diff --git a/tilelang/distributed/launch.sh b/tilelang/distributed/launch.sh
index 2c75d42c7..f5b44d287 100755
--- a/tilelang/distributed/launch.sh
+++ b/tilelang/distributed/launch.sh
@@ -17,7 +17,7 @@ export NCCL_DEBUG=${NCCL_DEBUG:="WARN"} # set env var. `NCCL_DEBUG` to expected
 
 # set launch configurations
 nproc_per_node=${NPROC_PER_NODE:=$(nvidia-smi --list-gpus | wc -l)} # set env var. `NPROC_PER_NODE` to # of GPUs per node
-nnodes=${NNODES:=1} # set env var. `NNODES` to # of nodes
+nnodes=${NODES:=1} # set env var. `NODES` to # of nodes
 node_rank=${NODE_RANK:=0} # set env var. `NODE_RANK` to the rank of current node
 master_ip=${MASTER_IP:="127.0.0.1"}
 
diff --git a/tilelang/distributed/utils.py b/tilelang/distributed/utils.py
index ae7e1bfd7..74f5ba4ae 100644
--- a/tilelang/distributed/utils.py
+++ b/tilelang/distributed/utils.py
@@ -31,8 +31,8 @@
     "float16": torch.float16,
     "float8_e4m3fn": torch.float8_e4m3fn,
     "float8_e5m2": torch.float8_e5m2,
-    "s8": torch.int8,
-    "s32": torch.int32,
+    "int8": torch.int8,
+    "int32": torch.int32,
     "float32": torch.float32,
 }
 
@@ -280,6 +280,14 @@ def set_signal(signal_tensor: torch.Tensor, signal: int, stream: torch.cuda.Stre
             cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
         )
         CUDA_CHECK(err)
+    elif signal_tensor.dtype in (torch.int64, torch.uint64):
+        (err,) = cuda.cuStreamWriteValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
+        )
+        CUDA_CHECK(err)
     else:
         raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
 
@@ -297,6 +305,14 @@ def wait_eq(signal_tensor: torch.Tensor,
             cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
         )
         CUDA_CHECK(err)
+    elif signal_tensor.dtype == torch.uint64:
+        (err,) = cuda.cuStreamWaitValue64(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
+        )
+        CUDA_CHECK(err)
     else:
         raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
 
diff --git a/tilelang/language/distributed/multi_device/nvshmem.py b/tilelang/language/distributed/multi_device/nvshmem.py
index 186a5d991..e328453d7 100644
--- a/tilelang/language/distributed/multi_device/nvshmem.py
+++ b/tilelang/language/distributed/multi_device/nvshmem.py
@@ -140,8 +140,20 @@ def putmem_signal_nbi(*args):
     return tir.call_intrin("handle", tir.op.Op.get("tl.PutmemSignalNbi"), *args)
 
 
-def putmem_signal_block(*args):
-    return tir.call_intrin("handle", tir.op.Op.get("tl.PutmemSignalBlock"), *args)
+def putmem_signal_block(dest, src, nelems, sig_addr, signal, sig_op, pe):
+    """Put data from local memory to remote memory at block granularity,
+    and update a remote flag on delivery.
+    Args:
+        dest: Symmetric address of the destination data object.
+        src: Symmetric address of the object containing the data to be copied.
+        nelems: Number of bytes to be transferred.
+        sig_addr: Symmetric address of the remote flag to be updated.
+        signal: The value used for updating the remote signal data object.
+        sig_op: The type of update to be performed on the remote signal data object.
+        pe: The PE ID of the destination PE.
+    """
+    return tir.call_intrin("handle", tir.op.Op.get("tl.PutmemSignalBlock"), dest, src,
+                           nelems, sig_addr, signal, sig_op, pe)
 
 
 def putmem_signal_nbi_block(dest, src, nelems, sig_addr, signal, sig_op, pe):

From f08bf9f11b1b1dc2b0b1d7ec6ef3cb0f41631849 Mon Sep 17 00:00:00 2001
From: Rachmanino <18805904201@163.com>
Date: Wed, 12 Nov 2025 12:09:45 +0000
Subject: [PATCH 3/3] lint

---
 benchmark/distributed/benchmark_ag_gemm.py | 2 +-
 .../internode/example_allgather_internode.py | 99 +++++++++----------
 .../distributed/multi_device/nvshmem.py | 4 +-
 3 files changed, 48 insertions(+), 57 deletions(-)

diff --git a/benchmark/distributed/benchmark_ag_gemm.py b/benchmark/distributed/benchmark_ag_gemm.py
index 505b8479f..a4b0bd785 100644
--- a/benchmark/distributed/benchmark_ag_gemm.py
+++ b/benchmark/distributed/benchmark_ag_gemm.py
@@ -18,7 +18,7 @@
 import tilelang
 import tilelang.language as T
 from tilelang.carver.arch import driver
-from tilelang.distributed import init_distributed, dtype_map, perf_fn, wait_eq
+from tilelang.distributed import init_distributed, dtype_map, perf_fn
 from triton_dist.kernels.nvidia.allgather_gemm import ag_gemm, create_ag_gemm_context
 from functools import partial
 
diff --git a/examples/distributed/internode/example_allgather_internode.py b/examples/distributed/internode/example_allgather_internode.py
index ec70dd499..d64ff0abc 100644
--- a/examples/distributed/internode/example_allgather_internode.py
+++ b/examples/distributed/internode/example_allgather_internode.py
@@ -9,15 +9,15 @@
 import argparse
 import torch
 import torch.distributed as dist
-from tilelang.distributed import init_distributed, dtype_map, perf_fn, wait_eq
+from tilelang.distributed import init_distributed, dtype_map
 import pynvshmem
 from dataclasses import dataclass, field
 
 from cuda import cudart, cuda
 
-
 os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log
 
+
 @dataclass
 class AllGatherInternodeContext:
     # tensor info
@@ -61,40 +61,36 @@ def __post_init__(self):
         torch.cuda.synchronize()
 
     def create_workspace(self):
-        self.barriers = pynvshmem.nvshmem_create_tensor_list_intra_node([self.num_nodes,], torch.uint64)
+        self.barriers = pynvshmem.nvshmem_create_tensor_list_intra_node([
+            self.num_nodes,
+        ], torch.uint64)
         self.barrier = self.barriers[self.local_rank]
         self.barrier.fill_(0)
 
 
 @tilelang.jit
-def put_internode_kernel(
-    num_nodes: int,
-    num_local_ranks: int,
-    M_per_rank: int,
-    M: int,
-    N: int,
-    dtype: str,
-    threads: int = 256
-):
+def put_internode_kernel(num_nodes: int,
+                         num_local_ranks: int,
+                         M_per_rank: int,
+                         M: int,
+                         N: int,
+                         dtype: str,
+                         threads: int = 256):
 
     @T.prim_func
     def main(
-        dst: T.Tensor([M, N], dtype), # type: ignore
-        barrier: T.Tensor([num_nodes], "uint64"), # type: ignore
+            dst: T.Tensor([M, N], dtype),  # type: ignore
+            barrier: T.Tensor([num_nodes], "uint64"),  # type: ignore
     ):
         with T.Kernel(num_nodes - 1, threads=threads) as (bx):
             rank = T.get_pe()
             node_rank = rank // num_local_ranks
             peer = (rank + (bx + 1) * num_local_ranks) % (num_nodes * num_local_ranks)
             T.putmem_signal_nbi_block(
                 T.address_of(dst[rank * M_per_rank, 0]),
-                T.address_of(dst[rank * M_per_rank, 0]),
-                M_per_rank * N * dtype_map[dtype].itemsize,
-                T.address_of(barrier[node_rank]),
-                1,
-                T.Amo.SIGNAL_SET,
-                peer
-            )
+                T.address_of(dst[rank * M_per_rank, 0]), M_per_rank * N * dtype_map[dtype].itemsize,
+                T.address_of(barrier[node_rank]), 1, T.Amo.SIGNAL_SET, peer)
+
     return main
 
 
 def tl_allgather_internode(
@@ -105,41 +101,37 @@ def tl_allgather_internode(
     debug: bool = False,
 ):
     # 0. local copy and barrier
-    cudart.cudaMemcpy(
-        dst[ctx.local_rank][ctx.rank * ctx.M_per_rank, 0].data_ptr(),
-        src.data_ptr(),
-        ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
-        cudart.cudaMemcpyKind.cudaMemcpyDefault
-    )
+    cudart.cudaMemcpy(dst[ctx.local_rank][ctx.rank * ctx.M_per_rank, 0].data_ptr(), src.data_ptr(),
+                      ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
+                      cudart.cudaMemcpyKind.cudaMemcpyDefault)
     pynvshmem.nvshmem_barrier_all()
     dist.barrier()
     torch.cuda.synchronize()
-    
+
     # 1. perform inter-node comm
     # push to all peers with same local rank and signal on barrier
     with torch.cuda.stream(ctx.internode_stream):
-        kernel = put_internode_kernel(ctx.num_nodes, ctx.num_local_ranks, ctx.M_per_rank, ctx.M, ctx.N, ctx.dtype)
+        kernel = put_internode_kernel(ctx.num_nodes, ctx.num_local_ranks, ctx.M_per_rank, ctx.M,
+                                      ctx.N, ctx.dtype)
         if debug and ctx.rank == 0:
             print(kernel.get_kernel_source())
         kernel(dst[ctx.local_rank], ctx.barrier)
 
     with torch.cuda.stream(ctx.intranode_stream):
         # 2. perform intra-node cp-engine based gather to overlap with inter-node comm
-        for i in range(ctx.num_local_ranks-1):
+        for i in range(ctx.num_local_ranks - 1):
             tgt_local_rank = (ctx.local_rank + i + 1) % ctx.num_local_ranks
             tgt_rank = tgt_local_rank + ctx.node_rank * ctx.num_local_ranks
-            cudart.cudaMemcpyAsync(
-                dst[ctx.local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
-                dst[tgt_local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
-                ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
-                cudart.cudaMemcpyKind.cudaMemcpyDefault,
-                ctx.intranode_stream.cuda_stream
-            )
+            cudart.cudaMemcpyAsync(dst[ctx.local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
+                                   dst[tgt_local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
+                                   ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
+                                   cudart.cudaMemcpyKind.cudaMemcpyDefault,
+                                   ctx.intranode_stream.cuda_stream)
 
         # 3. wait for data from other nodes sent to intra-node peers and gather
-        for i in range(ctx.num_nodes-1):
+        for i in range(ctx.num_nodes - 1):
             tgt_node_rank = (ctx.node_rank + i + 1) % ctx.num_nodes
-            for tgt_local_rank in range (ctx.num_local_ranks):
+            for tgt_local_rank in range(ctx.num_local_ranks):
                 tgt_rank = tgt_local_rank + tgt_node_rank * ctx.num_local_ranks
                 cuda.cuStreamWaitValue64(
                     ctx.intranode_stream.cuda_stream,
                     ctx.barriers[tgt_local_rank][tgt_node_rank].data_ptr(),
                     1,
                     cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
                 )
-                cudart.cudaMemcpyAsync(
-                    dst[ctx.local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
-                    dst[tgt_local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
-                    ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
-                    cudart.cudaMemcpyKind.cudaMemcpyDefault,
-                    ctx.intranode_stream.cuda_stream
-                )
-
+                cudart.cudaMemcpyAsync(dst[ctx.local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
+                                       dst[tgt_local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
+                                       ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
+                                       cudart.cudaMemcpyKind.cudaMemcpyDefault,
+                                       ctx.intranode_stream.cuda_stream)
+
     ctx.intranode_stream.wait_stream(ctx.internode_stream)
 
 
@@ -166,13 +156,14 @@ def main(M_per_rank: int, N: int, dtype: str, debug: bool = False):
     nodes: int = WORLD_SIZE // local_world_size
     assert nodes >= 2, "This example is for inter-node allgather"
     node_rank = RANK // local_world_size
-    
+
     # gather WORLD_SIZE*[M_per_rank, N]->[M, N]
     if debug:
         dtype = 'int32'
         torch_dtype = torch.int32
         src = torch.full([M_per_rank, N], RANK, dtype=torch.int32, device='cuda')
-        dst = pynvshmem.nvshmem_create_tensor_list_intra_node([M_per_rank * WORLD_SIZE, N], torch.int32)
+        dst = pynvshmem.nvshmem_create_tensor_list_intra_node([M_per_rank * WORLD_SIZE, N],
+                                                              torch.int32)
         dst[LOCAL_RANK].fill_(-1)
     else:
         torch_dtype = dtype_map[dtype]
@@ -185,7 +176,7 @@ def main(M_per_rank: int, N: int, dtype: str, debug: bool = False):
     tl_allgather_internode(src, dst, ctx, debug)
     pynvshmem.nvshmem_barrier_all()
     dist.barrier(TP_GROUP)
-    
+
     if debug:
         print(dst[LOCAL_RANK])
 
@@ -200,11 +191,11 @@ def main(M_per_rank: int, N: int, dtype: str, debug: bool = False):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--M_per_rank', type=int, default=1024, help='Number of rows of the local tensor')
+    parser.add_argument(
+        '--M_per_rank', type=int, default=1024, help='Number of rows of the local tensor')
     parser.add_argument('--N', type=int, default=1024, help='Number of columns of the local tensor')
     parser.add_argument('--dtype', type=str, default='float32', help='Data type')
     parser.add_argument('-debug', action='store_true', default=False, help='Enable debug mode')
     args = parser.parse_args()
     main(args.M_per_rank, args.N, args.dtype, args.debug)
-    
\ No newline at end of file
diff --git a/tilelang/language/distributed/multi_device/nvshmem.py b/tilelang/language/distributed/multi_device/nvshmem.py
index e328453d7..80fac7a96 100644
--- a/tilelang/language/distributed/multi_device/nvshmem.py
+++ b/tilelang/language/distributed/multi_device/nvshmem.py
@@ -152,8 +152,8 @@ def putmem_signal_block(dest, src, nelems, sig_addr, signal, sig_op, pe):
         sig_op: The type of update to be performed on the remote signal data object.
         pe: The PE ID of the destination PE.
""" - return tir.call_intrin("handle", tir.op.Op.get("tl.PutmemSignalBlock"), dest, src, - nelems, sig_addr, signal, sig_op, pe) + return tir.call_intrin("handle", tir.op.Op.get("tl.PutmemSignalBlock"), dest, src, nelems, + sig_addr, signal, sig_op, pe) def putmem_signal_nbi_block(dest, src, nelems, sig_addr, signal, sig_op, pe):