Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions python/sglang/srt/distributed/communication_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,13 @@

def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group.

    Args:
        input_: Tensor to reduce in place (CPU path) or via the TP group's
            collective (non-CPU path).

    Returns:
        The reduced tensor.
    """
    # NOTE: this import must sit at function top, not inside the CPU branch.
    # A function-scope import binds the name locally for the WHOLE function
    # body, so importing only inside `if input_.is_cpu:` would make the
    # non-CPU `return get_tp_group().all_reduce(...)` below raise
    # UnboundLocalError whenever the branch is not taken.
    from sglang.srt.distributed import get_tp_group

    if input_.is_cpu:
        # CPU tensors use the shared-memory all-reduce op; the reduction is
        # performed in place, so the input tensor is returned directly.
        tp_group = get_tp_group()
        tp_group.shm_comm_op.shm_allreduce(
            input_, tp_group.device_group, torch.distributed.ReduceOp.SUM
        )
        return input_

    return get_tp_group().all_reduce(input_)


def tensor_model_parallel_all_gather(
    input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
    """All-gather the input tensor across model parallel group.

    Args:
        input_: Tensor to gather from all ranks of the TP group.
        dim: Dimension along which the gathered shards are concatenated.

    Returns:
        The gathered tensor.
    """
    # NOTE: keep this import at function top. Placing it inside the
    # `if input_.is_cpu:` branch makes `get_tp_group` function-local for the
    # entire body (Python scoping is decided at compile time), so the
    # non-CPU return below would raise UnboundLocalError when the CPU
    # branch is skipped.
    from sglang.srt.distributed import get_tp_group

    if input_.is_cpu:
        # CPU tensors use the shared-memory all-gather op.
        tp_group = get_tp_group()
        return tp_group.shm_comm_op.shm_allgather(input_, tp_group.device_group, dim)

    return get_tp_group().all_gather(input_, dim)


Expand Down
15 changes: 12 additions & 3 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
return input_

if input_.is_cpu:
import intel_extension_for_pytorch as ipex

ipex.distributed.all_reduce(input_, group=self.device_group)
torch.ops.sgl_kernel_cpu.shm_allreduce(
input_,
self.device_group.group_name,
"sum",
)
return input_

if not supports_custom_op():
Expand Down Expand Up @@ -464,6 +466,13 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"

if input_.is_cpu:
return torch.ops.sgl_kernel_cpu.shm_allgather(
input_,
self.device_group.group_name,
dim,
)

# For HPUs, use HPU communicator.
hpu_comm = self.hpu_communicator
if hpu_comm is not None and not hpu_comm.disabled:
Expand Down
12 changes: 7 additions & 5 deletions python/sglang/srt/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
silu_and_mul(x, out)
return out

def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
    """SiLU-and-mul forward pass for CPU tensors.

    Dispatches to the fused AMX kernel when the CPU supports it; otherwise
    falls back to the pure-PyTorch reference implementation.

    The previous revision defined this method and then immediately shadowed
    it with a conditional `staticmethod(...)` assignment, leaving a
    commented-out duplicate behind. That froze the dispatch decision at
    class-creation time and kept dead code in the class body; a single
    method with a per-call check preserves the caller-visible contract
    (`self.forward_cpu(x)`) and removes the duplication.
    """
    if cpu_has_amx_support():
        return sgl_kernel.cpu.silu_and_mul(x)
    return self.forward_native(x)


class GeluAndMul(CustomOp):
Expand Down
9 changes: 7 additions & 2 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
# Probe AMX availability exactly once at import time and remember the answer;
# hot paths below branch on the cached flag instead of re-calling the helper.
_has_amx = cpu_has_amx_support()
if _has_amx:
    import sgl_kernel.cpu


def fused_topk_native(
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -182,7 +187,7 @@ def select_experts(
assert num_expert_group is not None
if correction_bias is None:
device = hidden_states.device
if device == torch.device("cpu") and cpu_has_amx_support():
if device == torch.device("cpu") and _has_amx:
topk_weights, topk_ids = sgl_kernel.cpu.grouped_topk(
hidden_states,
router_logits,
Expand All @@ -202,7 +207,7 @@ def select_experts(
)
else:
device = hidden_states.device
if device == torch.device("cpu") and cpu_has_amx_support():
if device == torch.device("cpu") and _has_amx:
topk_weights, topk_ids = sgl_kernel.cpu.biased_grouped_topk(
hidden_states,
router_logits,
Expand Down
7 changes: 6 additions & 1 deletion python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
# Evaluate the AMX capability check a single time at module import and cache
# the result; the fused-kernel path is only imported when it is usable.
_has_amx = cpu_has_amx_support()
if _has_amx:
    import sgl_kernel.cpu


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
Expand Down Expand Up @@ -727,7 +732,7 @@ def forward(
positions = torch.add(positions, offsets) if offsets is not None else positions

# TODO: Add scenario of self.rotary_dim < self.head_size
if positions.device == torch.device("cpu") and cpu_has_amx_support():
if positions.device == torch.device("cpu") and _has_amx:
return sgl_kernel.cpu.rotary_position_embedding(
positions, query, key, self.cos_sin_cache
)
Expand Down
Loading
Loading