2 changes: 1 addition & 1 deletion benchmark/distributed/ipc_impls/README.md
@@ -5,7 +5,7 @@ We launch only one block on each rank to avoid NVLink bandwidth as the bottleneck

## NVSHMEM-based push/pull
```bash
-GPUS=2 bash tilelang/distributed/launch.sh benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
+NPROC_PER_NODE=2 bash tilelang/distributed/launch.sh benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
```

## Unrolled-copy implemented in TileScale (*ours*)
2 changes: 1 addition & 1 deletion benchmark/distributed/ipc_impls/benchmark_nvshmem_p2p.py
@@ -1,7 +1,7 @@
# This benchmark aims to measure the bandwidth of NVSHMEM-based communication.
# We launch only one block on each rank to avoid NVLink bandwidth as the bottleneck.

-# Usage: GPUS=2 bash tilelang/distributed/launch.sh benchmark/distributed/benchmark_nvshmem_p2p.py
+# Usage: NPROC_PER_NODE=2 bash tilelang/distributed/launch.sh benchmark/distributed/benchmark_nvshmem_p2p.py

import os
import tilelang
4 changes: 2 additions & 2 deletions examples/distributed/example_overlapping_allgather.py
@@ -14,8 +14,8 @@
from cuda.bindings import runtime as cudart
else:
from cuda import cudart
-# NODES=2 NODE_RANK=0 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
-# NODES=2 NODE_RANK=1 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
+# NODES=2 NODE_RANK=0 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
+# NODES=2 NODE_RANK=1 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py


def internode_gather(M, local_world_size, block_M, threads):
20 changes: 20 additions & 0 deletions examples/distributed/internode/README.md
@@ -0,0 +1,20 @@
# Inter-node Examples

Examples in this folder aim to demonstrate the inter-node communication capabilities of TileScale.

- The earlier intra-node examples can use either NVSHMEM APIs or the native communication primitives (e.g. `T.put/get_block`, `T.copy`) provided by TileScale.
- However, inter-node RDMA communication currently relies on NVSHMEM's IBRC/IBGDA implementation, so both NVSHMEM and pynvshmem must be installed.
- For a detailed installation guide, please refer to [this document](../../../docs/get_started/Installation.md#to-use-nvshmem-apis); a quick import check is sketched below.
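
Before running these examples, it can help to verify that the NVSHMEM Python bindings are importable in the current environment. A minimal sketch, assuming `pynvshmem` was installed per the guide above:

```python
# Quick sanity check: raises ImportError if the NVSHMEM Python bindings
# (pynvshmem) are missing from the current environment.
import pynvshmem

print("pynvshmem is importable:", pynvshmem.__name__)
```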

## Example Usage

To run inter-node distributed programs, run the launch script simultaneously on every participating node.

Example:
```bash
# node 0 (master)
NODES=2 NODE_RANK=0 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_allgather_internode.py
# node 1 (worker)
NODES=2 NODE_RANK=1 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_allgather_internode.py
```

201 changes: 201 additions & 0 deletions examples/distributed/internode/example_allgather_internode.py
@@ -0,0 +1,201 @@
# NODES=2 NODE_RANK=0 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_allgather_internode.py
# NODES=2 NODE_RANK=1 MASTER_IP=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_allgather_internode.py

# todo: add benchmark and impl for wait_eq u64, also stricter test

import os
import tilelang
import tilelang.language as T
import argparse
import torch
import torch.distributed as dist
from tilelang.distributed import init_distributed, dtype_map
import pynvshmem
from dataclasses import dataclass, field

from cuda import cudart, cuda

os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log


@dataclass
class AllGatherInternodeContext:
# tensor info
M_per_rank: int
M: int = field(init=False)
N: int
dtype: str
torch_dtype: torch.dtype = field(init=False)

# rank info
rank: int
num_local_ranks: int
num_ranks: int
local_rank: int = field(init=False)
num_nodes: int = field(init=False)
node_rank: int = field(init=False)

# workspace
barriers: list[torch.Tensor] = field(init=False)
barrier: torch.Tensor = field(init=False)
internode_comm_bufs: list[torch.Tensor] = field(init=False)
# internode_comm_buf: torch.Tensor = field(init=False)

# streams
internode_stream: torch.cuda.Stream = field(init=False)
intranode_stream: torch.cuda.Stream = field(init=False)

def __post_init__(self):
self.M = self.M_per_rank * self.num_ranks
self.local_rank = self.rank % self.num_local_ranks
self.num_nodes = self.num_ranks // self.num_local_ranks
self.node_rank = self.rank // self.num_local_ranks
self.torch_dtype = dtype_map[self.dtype]

self.create_workspace()

self.internode_stream = torch.cuda.Stream()
self.intranode_stream = torch.cuda.Stream()

pynvshmem.nvshmem_barrier_all()
torch.cuda.synchronize()

def create_workspace(self):
self.barriers = pynvshmem.nvshmem_create_tensor_list_intra_node([
self.num_nodes,
], torch.uint64)
self.barrier = self.barriers[self.local_rank]
self.barrier.fill_(0)


@tilelang.jit
def put_internode_kernel(num_nodes: int,
num_local_ranks: int,
M_per_rank: int,
M: int,
N: int,
dtype: str,
threads: int = 256):

@T.prim_func
def main(
dst: T.Tensor([M, N], "int32"), # type: ignore
barrier: T.Tensor([num_nodes], "uint64"), # type: ignore
):
with T.Kernel(num_nodes - 1, threads=threads) as (bx):
rank = T.get_pe()
node_rank = rank // num_local_ranks
peer = (rank + (bx + 1) * num_local_ranks) % (num_nodes * num_local_ranks)
T.putmem_signal_nbi_block(
T.address_of(dst[rank * M_per_rank, 0]),
T.address_of(dst[rank * M_per_rank, 0]), M_per_rank * N * dtype_map[dtype].itemsize,
T.address_of(barrier[node_rank]), 1, T.Amo.SIGNAL_SET, peer)

return main


def tl_allgather_internode(
src: torch.Tensor,
dst: list[torch.Tensor],
ctx: AllGatherInternodeContext,
debug: bool = False,
):
# 0. local copy and barrier
cudart.cudaMemcpy(dst[ctx.local_rank][ctx.rank * ctx.M_per_rank, 0].data_ptr(), src.data_ptr(),
ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
cudart.cudaMemcpyKind.cudaMemcpyDefault)
pynvshmem.nvshmem_barrier_all()
dist.barrier()
torch.cuda.synchronize()

# 1. perform inter-node comm
# push to all peers with same local rank and signal on barrier
with torch.cuda.stream(ctx.internode_stream):
kernel = put_internode_kernel(ctx.num_nodes, ctx.num_local_ranks, ctx.M_per_rank, ctx.M,
ctx.N, ctx.dtype)
if debug and ctx.rank == 0:
print(kernel.get_kernel_source())
kernel(dst[ctx.local_rank], ctx.barrier)

with torch.cuda.stream(ctx.intranode_stream):
# 2. perform intra-node cp-engine based gather to overlap with inter-node comm
for i in range(ctx.num_local_ranks - 1):
tgt_local_rank = (ctx.local_rank + i + 1) % ctx.num_local_ranks
tgt_rank = tgt_local_rank + ctx.node_rank * ctx.num_local_ranks
cudart.cudaMemcpyAsync(dst[ctx.local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
dst[tgt_local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
cudart.cudaMemcpyKind.cudaMemcpyDefault,
ctx.intranode_stream.cuda_stream)

# 3. wait for data from other nodes sent to intra-node peers and gather
for i in range(ctx.num_nodes - 1):
tgt_node_rank = (ctx.node_rank + i + 1) % ctx.num_nodes
for tgt_local_rank in range(ctx.num_local_ranks):
tgt_rank = tgt_local_rank + tgt_node_rank * ctx.num_local_ranks
cuda.cuStreamWaitValue64(
ctx.intranode_stream.cuda_stream,
ctx.barriers[tgt_local_rank][tgt_node_rank].data_ptr(),
1,
cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
)
cudart.cudaMemcpyAsync(dst[ctx.local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
dst[tgt_local_rank][tgt_rank * ctx.M_per_rank, 0].data_ptr(),
ctx.M_per_rank * ctx.N * ctx.torch_dtype.itemsize,
cudart.cudaMemcpyKind.cudaMemcpyDefault,
ctx.intranode_stream.cuda_stream)

ctx.intranode_stream.wait_stream(ctx.internode_stream)


def main(M_per_rank: int, N: int, dtype: str, debug: bool = False):
WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
return_tp_group=True, return_lc_group=True)
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE'))
assert WORLD_SIZE % local_world_size == 0
nodes: int = WORLD_SIZE // local_world_size
assert nodes >= 2, "This example is for inter-node allgather"
node_rank = RANK // local_world_size

# gather WORLD_SIZE*[M_per_rank, N]->[M, N]
if debug:
dtype = 'int32'
torch_dtype = torch.int32
src = torch.full([M_per_rank, N], RANK, dtype=torch.int32, device='cuda')
dst = pynvshmem.nvshmem_create_tensor_list_intra_node([M_per_rank * WORLD_SIZE, N],
torch.int32)
dst[LOCAL_RANK].fill_(-1)
else:
torch_dtype = dtype_map[dtype]
src = torch.randn([M_per_rank, N], dtype=torch_dtype, device='cuda')
dst = pynvshmem.nvshmem_create_tensor_list_intra_node([M_per_rank * WORLD_SIZE, N], torch_dtype)
ctx = AllGatherInternodeContext(M_per_rank, N, "int32", RANK, local_world_size, WORLD_SIZE)

pynvshmem.nvshmem_barrier_all()
dist.barrier(TP_GROUP)
tl_allgather_internode(src, dst, ctx, debug)
pynvshmem.nvshmem_barrier_all()
dist.barrier(TP_GROUP)

if debug:
print(dst[LOCAL_RANK])

# torch ref
ref_dst = torch.empty_like(dst[LOCAL_RANK])
dist.barrier(TP_GROUP)
dist.all_gather_into_tensor(ref_dst, src, TP_GROUP)
dist.barrier(TP_GROUP)
assert torch.allclose(dst[LOCAL_RANK], ref_dst)
print(f'[node={node_rank}, local_rank={LOCAL_RANK}] All checks passed. ✅')


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--M_per_rank', type=int, default=1024, help='Number of rows of the local tensor')
parser.add_argument('--N', type=int, default=1024, help='Number of columns of the local tensor')
parser.add_argument('--dtype', type=str, default='float32', help='Data type')
parser.add_argument('-debug', action='store_true', default=False, help='Enable debug mode')
args = parser.parse_args()

main(args.M_per_rank, args.N, args.dtype, args.debug)
12 changes: 12 additions & 0 deletions src/target/codegen_cuda.cc
@@ -2021,6 +2021,18 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
os << ", ";
this->PrintExpr(op->args[3], os);
os << ")";
} else if (op->op.same_as(tl::PutmemSignalBlock())) {
this->use_distributed_ = true;
this->use_nvshmem_ = true;
os << "nvshmemx_putmem_signal_block(";
this->PrintExpr(op->args[0], os);
os << ", ";
this->PrintExpr(op->args[1], os);
os << ", ";
this->PrintExpr(op->args[2], os);
os << ", ";
this->PrintExpr(op->args[3], os);
os << ", ";
this->PrintExpr(op->args[4], os);
os << ", ";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[6], os);
os << ")";
} else if (op->op.same_as(tl::PutmemNbiBlock())) {
this->use_distributed_ = true;
this->use_nvshmem_ = true;
12 changes: 6 additions & 6 deletions tilelang/distributed/launch.sh
@@ -16,17 +16,17 @@ export NCCL_DEBUG=${NCCL_DEBUG:="WARN"} # set env var. `NCCL_DEBUG` to expected
# Choices: [VERSION, WARN(default), INFO, TRACE],

# set launch configurations
-nproc_per_node=${GPUS:=$(nvidia-smi --list-gpus | wc -l)} # set env var. `GPUS` to # of GPUs per node
+nproc_per_node=${NPROC_PER_NODE:=$(nvidia-smi --list-gpus | wc -l)} # set env var. `NPROC_PER_NODE` to # of GPUs per node
nnodes=${NODES:=1} # set env var. `NODES` to # of nodes
node_rank=${NODE_RANK:=0} # set env var. `NODE_RANK` to the rank of current node

-master_addr=${ARNOLD_WORKER_0_HOST:="127.0.0.1"}
-if [ -z ${ARNOLD_WORKER_0_PORT} ]; then
-master_port="8361"
+master_ip=${MASTER_IP:="127.0.0.1"}
+if [ -z ${MASTER_PORT} ]; then
+master_port="$(( RANDOM % 1000 + 20000))" # random port between 20000 and 21000
else
-master_port=$(echo "$ARNOLD_WORKER_0_PORT" | cut -d "," -f 1)
+master_port=$(echo "$MASTER_PORT" | cut -d "," -f 1)
fi
-additional_args="--rdzv_endpoint=${master_addr}:${master_port}"
+additional_args="--rdzv_endpoint=${master_ip}:${master_port}"
IB_HCA=mlx5


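With the renamed variables above, a hedged two-node launch sketch looks like the following (the IP address, port, and GPU count are placeholders; `MASTER_PORT` may be omitted to let the script pick a random port in the 20000–20999 range, and `NPROC_PER_NODE` defaults to the local GPU count):

```bash
# Node 0 (hosts the rendezvous endpoint)
NODES=2 NODE_RANK=0 MASTER_IP=10.0.0.1 MASTER_PORT=23456 NPROC_PER_NODE=8 \
  bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_allgather_internode.py
# Node 1
NODES=2 NODE_RANK=1 MASTER_IP=10.0.0.1 MASTER_PORT=23456 NPROC_PER_NODE=8 \
  bash tilelang/distributed/launch.sh ./examples/distributed/internode/example_allgather_internode.py
```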
20 changes: 18 additions & 2 deletions tilelang/distributed/utils.py
@@ -31,8 +31,8 @@
"float16": torch.float16,
"float8_e4m3fn": torch.float8_e4m3fn,
"float8_e5m2": torch.float8_e5m2,
-"s8": torch.int8,
-"s32": torch.int32,
+"int8": torch.int8,
+"int32": torch.int32,
"float32": torch.float32,
}

@@ -280,6 +280,14 @@ def set_signal(signal_tensor: torch.Tensor, signal: int, stream: torch.cuda.Stream
cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
)
CUDA_CHECK(err)
elif signal_tensor.dtype in (torch.int64, torch.uint64):
(err,) = cuda.cuStreamWriteValue64(
stream.cuda_stream,
signal_tensor.data_ptr(),
signal,
cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
)
CUDA_CHECK(err)
else:
raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")

@@ -297,6 +305,14 @@ def wait_eq(signal_tensor: torch.Tensor,
cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
)
CUDA_CHECK(err)
elif signal_tensor.dtype == torch.uint64:
(err,) = cuda.cuStreamWaitValue64(
stream.cuda_stream,
signal_tensor.data_ptr(),
signal,
cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
)
CUDA_CHECK(err)
else:
raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")

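A minimal usage sketch of the 64-bit signal path added above. The import path, the `(tensor, value, stream)` argument order, and the use of a plain `torch.uint64` device tensor are assumptions based on this file and the new inter-node example (which creates and `fill_`s uint64 barrier tensors); CUDA stream memory operations must be supported by the GPU/driver:

```python
import torch

# Assumed import path: these helpers live in tilelang/distributed/utils.py.
from tilelang.distributed.utils import set_signal, wait_eq

# One-element 64-bit flag in device memory, zero-initialized via fill_()
# (the same pattern the inter-node example uses for its uint64 barriers).
flag = torch.empty(1, dtype=torch.uint64, device="cuda")
flag.fill_(0)

producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()

wait_eq(flag, 1, consumer)     # consumer stream blocks until *flag == 1 (cuStreamWaitValue64)
set_signal(flag, 1, producer)  # producer stream writes the flag (cuStreamWriteValue64)
torch.cuda.synchronize()       # both streams drain once the flag is observed
```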
16 changes: 14 additions & 2 deletions tilelang/language/distributed/multi_device/nvshmem.py
@@ -140,8 +140,20 @@ def putmem_signal_nbi(*args):
return tir.call_intrin("handle", tir.op.Op.get("tl.PutmemSignalNbi"), *args)


-def putmem_signal_block(*args):
-return tir.call_intrin("handle", tir.op.Op.get("tl.PutmemSignalBlock"), *args)
+def putmem_signal_block(dest, src, nelems, sig_addr, signal, sig_op, pe):
+"""Put data from local memory to remote memory at block granularity,
+and update a remote flag on delivery.
+Args:
+dest: Symmetric address of the destination data object.
+src: Symmetric address of the object containing the data to be copied.
+nelems: Number of elements to be transferred (in bytes).
+sig_addr: Symmetric address of the remote flag to be updated.
+signal: The value used for updating the remote signal data object.
+sig_op: The type of update to be performed on the remote signal data object.
+pe: The PE ID of the destination PE.
+"""
+return tir.call_intrin("handle", tir.op.Op.get("tl.PutmemSignalBlock"), dest, src, nelems,
+sig_addr, signal, sig_op, pe)


def putmem_signal_nbi_block(dest, src, nelems, sig_addr, signal, sig_op, pe):
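
To show how the intrinsic documented above could be used from TileScale, here is a hedged kernel sketch modeled on `put_internode_kernel` in the new inter-node example, swapping in the blocking variant. It assumes the intrinsic is exposed as `T.putmem_signal_block` (mirroring `T.putmem_signal_nbi_block`) and is illustrative rather than part of this PR:

```python
import tilelang
import tilelang.language as T
from tilelang.distributed import dtype_map


@tilelang.jit
def put_signal_block_kernel(num_pes: int, M_per_rank: int, N: int,
                            dtype: str = "float32", threads: int = 256):

    @T.prim_func
    def main(
            dst: T.Tensor([M_per_rank * num_pes, N], dtype),  # type: ignore
            barrier: T.Tensor([num_pes], "uint64"),  # type: ignore
    ):
        # One block per remote PE: push this rank's row slice and set the
        # peer's flag (indexed by the sender's rank) on delivery.
        with T.Kernel(num_pes - 1, threads=threads) as (bx):
            rank = T.get_pe()
            peer = (rank + bx + 1) % num_pes
            T.putmem_signal_block(
                T.address_of(dst[rank * M_per_rank, 0]),  # remote destination (symmetric address)
                T.address_of(dst[rank * M_per_rank, 0]),  # local source (same symmetric offset)
                M_per_rank * N * dtype_map[dtype].itemsize,  # transfer size in bytes
                T.address_of(barrier[rank]),  # remote flag to update
                1, T.Amo.SIGNAL_SET, peer)

    return main
```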