From 3cacd14a82b6a37263a70a8e65b8450d07336616 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 6 May 2025 11:28:34 +0000 Subject: [PATCH 1/6] add initial CPU compiler runner --- .../srt/model_executor/cpu_compile_runner.py | 480 ++++++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 33 ++ 2 files changed, 513 insertions(+) create mode 100644 python/sglang/srt/model_executor/cpu_compile_runner.py diff --git a/python/sglang/srt/model_executor/cpu_compile_runner.py b/python/sglang/srt/model_executor/cpu_compile_runner.py new file mode 100644 index 000000000000..bc31f85ded86 --- /dev/null +++ b/python/sglang/srt/model_executor/cpu_compile_runner.py @@ -0,0 +1,480 @@ +"""Modified from cuda_graph_runner.py""" + +from __future__ import annotations + +import bisect +from contextlib import contextmanager +from typing import TYPE_CHECKING, Callable + +import torch +import tqdm + +from sglang.srt.custom_op import CustomOp +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.distributed.parallel_state import GroupCoordinator +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + + +def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): + for sub in model._modules.values(): + if isinstance(sub, CustomOp): + if reverse: + sub._forward_method = sub.forward_cuda + setattr(sub, "is_torch_compile", False) + else: + # NOTE: Temporarily workaround MoE + if "FusedMoE" in sub.__class__.__name__: + if num_tokens == 1: + # The performance of torch.compile on this layer is not always good when bs > 1, + # so we decide to only use torch.compile when bs =1 + sub._forward_method = fused_moe_forward_native + else: + sub._forward_method = sub.forward_native + setattr(sub, "is_torch_compile", True) + if isinstance(sub, torch.nn.Module): + _to_torch(sub, reverse, num_tokens) + + +@contextmanager +def patch_model( + model: torch.nn.Module, + enable_compile: bool, + num_tokens: int, + tp_group: GroupCoordinator, +): + """Patch the model to make it compatible with with torch.compile""" + backup_ca_comm = None + + try: + if enable_compile: + # _to_torch(model, reverse=False, num_tokens=num_tokens) # not sure why this is needed + backup_ca_comm = tp_group.ca_comm + # Use custom-allreduce here. + # We found the custom allreduce is much faster than the built-in allreduce in torch, + # even with ENABLE_INTRA_NODE_COMM=1. + # tp_group.ca_comm = None + yield torch.compile( + torch.no_grad()(model.forward), + mode="max-autotune-no-cudagraphs", + dynamic=False, + ) + else: + yield model.forward + finally: + if enable_compile: + # _to_torch(model, reverse=True, num_tokens=num_tokens) + tp_group.ca_comm = backup_ca_comm + + +def set_torch_compile_config(): + import torch._dynamo.config + import torch._inductor.config + + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + + # FIXME: tmp workaround + torch._dynamo.config.accumulated_cache_size_limit = 1024 + if hasattr(torch._dynamo.config, "cache_size_limit"): + torch._dynamo.config.cache_size_limit = 1024 + + +def get_batch_sizes_to_compile(model_runner: ModelRunner): + # NOTE: may want to simplify this + server_args = model_runner.server_args + capture_bs = server_args.cuda_graph_bs + + if capture_bs is None: + if server_args.speculative_algorithm is None: + if server_args.disable_cuda_graph_padding: + capture_bs = list(range(1, 33)) + [64, 96, 128, 160] + else: + capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] + else: + capture_bs = list(range(1, 33)) + + if max(capture_bs) > model_runner.req_to_token_pool.size: + # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests + # is very small. We add more values here to make sure we capture the maximum bs. + capture_bs = list( + sorted( + set( + capture_bs + + [model_runner.req_to_token_pool.size - 1] + + [model_runner.req_to_token_pool.size] + ) + ) + ) + + capture_bs = [ + bs + for bs in capture_bs + if bs <= model_runner.req_to_token_pool.size + and bs <= server_args.cuda_graph_max_bs + ] + compile_bs = ( + [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] + if server_args.enable_torch_compile + else [] + ) + return compile_bs + + +class CpuCompileRunner: + """A CpuCompileRunner runs the forward pass of a model with torch.compile.""" + + def __init__(self, model_runner: ModelRunner): + # Parse args + self.model_runner = model_runner + # self.graphs = {} + # self.output_buffers = {} + self.compiled_forwards = {} + self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder + self.enable_dp_attention = model_runner.server_args.enable_dp_attention + self.tp_size = model_runner.server_args.tp_size + self.dp_size = model_runner.server_args.dp_size + + # Batch sizes to capture + self.compile_bs = get_batch_sizes_to_compile(model_runner) + self.capture_forward_mode = ForwardMode.DECODE + self.capture_hidden_mode = CaptureHiddenMode.NULL + self.num_tokens_per_bs = 1 + if model_runner.spec_algorithm.is_eagle(): + if self.model_runner.is_draft_worker: + raise RuntimeError("This should not happen") + else: + self.capture_forward_mode = ForwardMode.TARGET_VERIFY + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_num_draft_tokens + ) + + # Attention backend + self.max_bs = max(self.compile_bs) + self.max_num_token = self.max_bs * self.num_tokens_per_bs + # self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token) + # self.seq_len_fill_value = ( + # self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() + # ) + self.seq_len_fill_value = 0 + # FIXME(lsyin): leave it here for now, I don't know whether it is necessary + self.encoder_len_fill_value = 0 + self.seq_lens_cpu = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) + + if self.enable_torch_compile: + set_torch_compile_config() + + # Graph inputs + # NOTE: we don't actually need this + with torch.device("cpu"): + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) + # cuda_graph_runner uses + # - int32 for req_pool_indices and seq_lens + # - int64 for out_cache_loc + self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int64) + self.seq_lens = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int64 + ) + self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32) + self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) + + # Speculative_inference + if model_runner.spec_algorithm.is_eagle(): + self.hidden_states = torch.zeros( + (self.max_num_token, self.model_runner.model_config.hidden_size), + dtype=self.model_runner.dtype, + ) + + if self.is_encoder_decoder: + # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch + self.encoder_lens = torch.full( + (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32 + ) + else: + self.encoder_lens = None + + if self.enable_dp_attention: + self.gathered_buffer = torch.zeros( + ( + self.max_bs * self.dp_size, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + self.global_num_tokens_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + + # Capture + try: + with self.model_capture_mode(): + self.capture() + except RuntimeError as e: + import traceback + raise Exception( + f"CPU compile failed: {e}\n" + f"{traceback.format_exc()}\n" + "Possible solutions:\n" + # "1. disable cuda graph by --disable-cuda-graph\n" + "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" + "3. disable torch compile by not using --enable-torch-compile\n" + "4. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + ) + + @contextmanager + def model_capture_mode(self): + if hasattr(self.model_runner.model, "capture_mode"): + self.model_runner.model.capture_mode = True + + yield + + if hasattr(self.model_runner.model, "capture_mode"): + self.model_runner.model.capture_mode = False + + def can_run(self, forward_batch: ForwardBatch): + if self.enable_dp_attention: + min_num_tokens, max_num_tokens = min( + forward_batch.global_num_tokens_cpu + ), max(forward_batch.global_num_tokens_cpu) + is_bs_supported = forward_batch.can_run_dp_cuda_graph and ( + (min_num_tokens == max_num_tokens and max_num_tokens in self.compiled_forwards) + if self.disable_padding + else max_num_tokens <= self.max_bs + ) + else: + is_bs_supported = ( + forward_batch.batch_size in self.compiled_forwards + if self.disable_padding + else forward_batch.batch_size <= self.max_bs + ) + + # TODO: check this + # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0) + # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph + # because the full_text_row_masked_out_mask tensor will always be ones + is_encoder_lens_supported = ( + torch.all(forward_batch.encoder_lens > 0) + if self.is_encoder_decoder + else True + ) + return is_bs_supported and is_encoder_lens_supported + + def capture(self): + # TODO: see if this can be removed + # Reverse the order to enable better memory sharing across cuda graphs. + capture_range = ( + tqdm.tqdm(list(reversed(self.compile_bs))) + if get_tensor_model_parallel_rank() == 0 + else reversed(self.compile_bs) + ) + for bs in capture_range: + with patch_model( + self.model_runner.model, + True, + num_tokens=bs * self.num_tokens_per_bs, + tp_group=self.model_runner.tp_group, + ) as forward: + self.capture_one_batch_size(bs, forward) + self.compiled_forwards[bs] = forward + # self.graphs[bs] = graph + # self.output_buffers[bs] = output_buffers + + def capture_one_batch_size(self, bs: int, forward: Callable): + # graph = torch.cuda.CUDAGraph() + # stream = self.stream + num_tokens = bs * self.num_tokens_per_bs + + # Graph inputs + input_ids = self.input_ids[:num_tokens] + req_pool_indices = self.req_pool_indices[:bs] + seq_lens = self.seq_lens[:bs] + out_cache_loc = self.out_cache_loc[:num_tokens] + positions = self.positions[:num_tokens] + if self.is_encoder_decoder: + encoder_lens = self.encoder_lens[:bs] + else: + encoder_lens = None + mrope_positions = self.mrope_positions[:, :bs] + + if self.enable_dp_attention: + global_num_tokens = [bs] * self.tp_size + gathered_buffer = self.gathered_buffer[: bs * self.tp_size] + else: + global_num_tokens = None + gathered_buffer = None + + spec_info = self.get_spec_info(num_tokens) + if self.capture_hidden_mode != CaptureHiddenMode.FULL: + self.capture_hidden_mode = ( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ) + + forward_batch = ForwardBatch( + forward_mode=self.capture_forward_mode, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens.sum(), + encoder_lens=encoder_lens, + return_logprob=False, + positions=positions, + global_num_tokens_cpu=global_num_tokens, + gathered_buffer=gathered_buffer, + mrope_positions=mrope_positions, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=self.capture_hidden_mode, + ) + + # Attention backend + self.model_runner.attn_backend.init_forward_metadata(forward_batch) + # self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( + # bs, + # num_tokens, + # req_pool_indices, + # seq_lens, + # encoder_lens, + # forward_batch.forward_mode, + # forward_batch.spec_info, + # ) + + # trigger torch.compile() + for _ in range(2): + self.model_runner.tp_group.barrier() + forward(input_ids, forward_batch.positions, forward_batch) + + def recapture_if_needed(self, forward_batch: ForwardBatch): + # If the capture_hidden_mode changes, we need to recapture the graph + hidden_mode_from_spec_info = getattr( + forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL + ) + if ( + forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL + and self.capture_hidden_mode != CaptureHiddenMode.FULL + ): + self.capture_hidden_mode = CaptureHiddenMode.FULL + self.capture() + elif ( + forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL + and self.capture_hidden_mode != hidden_mode_from_spec_info + ): + self.capture_hidden_mode = hidden_mode_from_spec_info + self.capture() + + def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False): + self.recapture_if_needed(forward_batch) + + raw_bs = forward_batch.batch_size + raw_num_token = raw_bs * self.num_tokens_per_bs + + # Pad + if self.enable_dp_attention: + index = bisect.bisect_left( + self.compile_bs, max(forward_batch.global_num_tokens_cpu) + ) + else: + index = bisect.bisect_left(self.compile_bs, raw_bs) + bs = self.compile_bs[index] + if bs != raw_bs: + self.seq_lens.fill_(1) + self.out_cache_loc.zero_() + + # NOTE: this is only for CUDA graph (copy to replay buffers) + # we can remove this. + # Common inputs + self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) + self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) + self.positions[:raw_num_token].copy_(forward_batch.positions) + if forward_batch.decode_seq_lens_cpu is not None: + if bs != raw_bs: + self.seq_lens_cpu.fill_(1) + self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu) + + if self.is_encoder_decoder: + self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) + if forward_batch.mrope_positions is not None: + self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) + + if hasattr(forward_batch.spec_info, "hidden_states"): + self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states + + # Attention backend + self.model_runner.attn_backend.init_forward_metadata(forward_batch) + # self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( + # bs, + # self.req_pool_indices, + # self.seq_lens, + # forward_batch.seq_lens_sum + (bs - raw_bs), + # self.encoder_lens, + # forward_batch.forward_mode, + # forward_batch.spec_info, + # seq_lens_cpu=self.seq_lens_cpu, + # ) + + # Replay + # self.graphs[bs].replay() + # next_token_logits, hidden_states = self.output_buffers[bs] + logits_output = self.compiled_forwards[bs]( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + ) + next_token_logits = logits_output.next_token_logits + hidden_states = logits_output.hidden_states + + logits_output = LogitsProcessorOutput( + next_token_logits=next_token_logits[:raw_num_token], + hidden_states=( + hidden_states[:raw_num_token] if hidden_states is not None else None + ), + ) + return logits_output + + def get_spec_info(self, num_tokens: int): + spec_info = None + if self.model_runner.spec_algorithm.is_eagle(): + from sglang.srt.speculative.eagle_utils import EagleVerifyInput + + if self.model_runner.is_draft_worker: + raise RuntimeError("This should not happen.") + else: + spec_info = EagleVerifyInput( + draft_token=None, + custom_mask=torch.zeros( + (num_tokens * self.model_runner.model_config.context_len), + dtype=torch.bool, + device="cuda", + ), + positions=None, + retrive_index=None, + retrive_next_token=None, + retrive_next_sibling=None, + retrive_cum_len=None, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + spec_steps=self.model_runner.server_args.speculative_num_steps, + capture_hidden_mode=CaptureHiddenMode.FULL, + ) + + return spec_info diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0ba5f8b419b9..97ad4c0c1139 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -56,6 +56,7 @@ TokenToKVPoolAllocator, ) from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner +from sglang.srt.model_executor.cpu_compile_runner import CpuCompileRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import ( @@ -213,6 +214,10 @@ def initialize(self, min_per_gpu_memory: float): self.init_cublas() self.init_attention_backend() self.init_cuda_graphs() + elif self.device == "cpu": + self.init_attention_backend() + self.init_cpu_compile() + self.cuda_graph_runner = None else: self.cuda_graph_runner = None self.init_attention_backend() @@ -914,6 +919,27 @@ def init_cuda_graphs(self): f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB." ) + def init_cpu_compile(self): + self.cpu_compile_runner = None + + if not self.is_generation: + return + + if not self.server_args.enable_torch_compile: + return + + tic = time.time() + before_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + f"CPU compile begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" + ) + self.cpu_compile_runner = CpuCompileRunner(self) + after_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + f"CPU compile end. Time elapsed: {time.time() - tic:.2f} s. " + f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB." + ) + def apply_torch_tp(self): logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") from sglang.srt.model_parallel import tensor_parallel @@ -971,6 +997,13 @@ def forward( forward_batch, skip_attn_backend_init=skip_attn_backend_init ) + if ( + forward_batch.forward_mode.is_decode() + and self.cpu_compile_runner is not None + and self.cpu_compile_runner.can_run(forward_batch) + ): + return self.cpu_compile_runner.replay(forward_batch) + if forward_batch.forward_mode.is_decode(): return self.forward_decode(forward_batch) elif forward_batch.forward_mode.is_extend(): From c6200cb9972ff9f5d4aad3c2e46d2ff6bff3c990 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 6 May 2025 15:15:20 +0000 Subject: [PATCH 2/6] working POC, though output is wrong --- python/sglang/srt/layers/activation.py | 12 +++--- .../srt/model_executor/cpu_compile_runner.py | 1 + sgl-kernel/csrc/cpu/bmm.cpp | 2 +- sgl-kernel/csrc/cpu/gemm.cpp | 2 +- sgl-kernel/csrc/cpu/gemm_fp8.cpp | 2 +- sgl-kernel/csrc/cpu/models/deepseek.cpp | 6 +-- sgl-kernel/csrc/cpu/qkv_proj.cpp | 6 +-- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 37 ++++++++++++++-- sgl-kernel/python/sgl_kernel/cpu.py | 42 +++++++++++++++---- 9 files changed, 85 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 5de053c97bd6..30fc6f05b12a 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -55,11 +55,13 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: silu_and_mul(x, out) return out - def forward_cpu(self, x: torch.Tensor) -> torch.Tensor: - if cpu_has_amx_support(): - return sgl_kernel.cpu.silu_and_mul(x) - else: - return self.forward_native(x) + forward_cpu = staticmethod(sgl_kernel.cpu.silu_and_mul) if cpu_has_amx_support() else forward_native + + # def forward_cpu(self, x: torch.Tensor) -> torch.Tensor: + # if cpu_has_amx_support(): + # return sgl_kernel.cpu.silu_and_mul(x) + # else: + # return self.forward_native(x) class GeluAndMul(CustomOp): diff --git a/python/sglang/srt/model_executor/cpu_compile_runner.py b/python/sglang/srt/model_executor/cpu_compile_runner.py index bc31f85ded86..e11116041770 100644 --- a/python/sglang/srt/model_executor/cpu_compile_runner.py +++ b/python/sglang/srt/model_executor/cpu_compile_runner.py @@ -64,6 +64,7 @@ def patch_model( # tp_group.ca_comm = None yield torch.compile( torch.no_grad()(model.forward), + fullgraph=True, mode="max-autotune-no-cudagraphs", dynamic=False, ) diff --git a/sgl-kernel/csrc/cpu/bmm.cpp b/sgl-kernel/csrc/cpu/bmm.cpp index 337d6d4c67a2..d98d762a9fb2 100644 --- a/sgl-kernel/csrc/cpu/bmm.cpp +++ b/sgl-kernel/csrc/cpu/bmm.cpp @@ -76,7 +76,7 @@ void bmm_kernel_impl( // scale: [] 0-dim tensor for per tensor quant // void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, - std::optional& scale) { + const std::optional& scale) { RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector({out, mat1, mat2})); auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); diff --git a/sgl-kernel/csrc/cpu/gemm.cpp b/sgl-kernel/csrc/cpu/gemm.cpp index 3c8be1612db5..95f90f3b7690 100644 --- a/sgl-kernel/csrc/cpu/gemm.cpp +++ b/sgl-kernel/csrc/cpu/gemm.cpp @@ -416,7 +416,7 @@ at::Tensor convert_weight_packed(at::Tensor& weight) { // out : [M, N] // at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, - std::optional& bias, bool is_vnni) { + const std::optional& bias, bool is_vnni) { RECORD_FUNCTION( "sgl-kernel::weight_packed_linear", std::vector({mat1, mat2, bias})); diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp index 0088b969afb4..0be574a9f4b3 100644 --- a/sgl-kernel/csrc/cpu/gemm_fp8.cpp +++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp @@ -442,7 +442,7 @@ INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, - std::vector block_size, std::optional& bias, + std::vector block_size, const std::optional& bias, at::ScalarType out_dtype, bool is_vnni) { RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector({mat1, mat2, scales2, block_size, bias})); diff --git a/sgl-kernel/csrc/cpu/models/deepseek.cpp b/sgl-kernel/csrc/cpu/models/deepseek.cpp index 213a1d753e6a..2048df57ced4 100644 --- a/sgl-kernel/csrc/cpu/models/deepseek.cpp +++ b/sgl-kernel/csrc/cpu/models/deepseek.cpp @@ -12,7 +12,7 @@ extern void decode_attention_cpu(at::Tensor& query, at::Tensor& k_cache, at::Ten double sm_scale, double logit_cap); extern void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, - std::optional& scale); + const std::optional& scale); extern std::tuple qkv_proj_with_rope( at::Tensor& hidden_states, @@ -34,13 +34,13 @@ extern std::tuple qkv_proj_with_rope( std::optional> block_size); extern at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, - std::optional& bias, bool is_vnni); + const std::optional& bias, bool is_vnni); extern at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, std::optional& bias, at::ScalarType out_dtype, bool is_vnni); extern at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, - std::vector block_size, std::optional& bias, + std::vector block_size, const std::optional& bias, at::ScalarType out_dtype, bool is_vnni); extern void shm_allreduce(at::Tensor& data, c10::intrusive_ptr process_group, py::object op); diff --git a/sgl-kernel/csrc/cpu/qkv_proj.cpp b/sgl-kernel/csrc/cpu/qkv_proj.cpp index cc5beed19d02..b6d45e7eb1ae 100644 --- a/sgl-kernel/csrc/cpu/qkv_proj.cpp +++ b/sgl-kernel/csrc/cpu/qkv_proj.cpp @@ -401,16 +401,16 @@ void rotary_emb_kernel_impl( } // anonymous namespace extern at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, - std::optional& bias, bool is_vnni); + const std::optional& bias, bool is_vnni); extern at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, std::optional& bias, at::ScalarType out_dtype, bool is_vnni); extern void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, - std::optional& scale); + const std::optional& scale); extern at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, - std::vector block_size, std::optional& bias, at::ScalarType out_dtype, bool is_vnni); + std::vector block_size, const std::optional& bias, at::ScalarType out_dtype, bool is_vnni); // NB: shapes in DeepDeek R1 diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index d6debe9b7962..300ca51d88be 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -105,7 +105,7 @@ std::tuple per_token_quant_int8_cpu(at::Tensor& A); // gemm at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, - std::optional& bias, bool is_vnni); + const std::optional& bias, bool is_vnni); // igemm at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, @@ -115,7 +115,7 @@ at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, // fp8 gemm at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, std::vector block_size, - std::optional& bias, at::ScalarType out_dtype, bool is_vnni); + const std::optional& bias, at::ScalarType out_dtype, bool is_vnni); // quant + igemm at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, @@ -123,7 +123,7 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Ten // bmm void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, - std::optional& scale); + const std::optional& scale); // fused moe at::Tensor fused_experts_cpu( @@ -280,3 +280,34 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // rope m.def("rotary_position_embedding_cpu", &rotary_position_embedding_cpu, "rotary position embedding for CPU"); } + +#define IMPL_CPU(op) m.impl(#op, at::kCPU, &op); + +TORCH_LIBRARY(sgl_kernel_cpu, m) { + m.def("silu_and_mul_cpu(Tensor input) -> Tensor"); + IMPL_CPU(silu_and_mul_cpu); + + m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor"); + IMPL_CPU(rmsnorm_cpu); + + m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor(a!) residual, Tensor weight, float eps) -> ()"); + IMPL_CPU(fused_add_rmsnorm_cpu); + + m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor"); + IMPL_CPU(weight_packed_linear); + + m.def("fp8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, int[] block_size, Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor"); + IMPL_CPU(fp8_scaled_mm_cpu); + + m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()"); + IMPL_CPU(bmm_cpu); + + m.def("decode_attention_cpu(Tensor query, Tensor(a!) k_cache, Tensor(b!) v_cache, Tensor(c!) output," + "Tensor key, Tensor value, Tensor loc, Tensor attn_logits," + "Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens," + "float sm_scale, float logit_cap) -> ()"); + IMPL_CPU(decode_attention_cpu); + + m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)"); + IMPL_CPU(rotary_position_embedding_cpu); +} diff --git a/sgl-kernel/python/sgl_kernel/cpu.py b/sgl-kernel/python/sgl_kernel/cpu.py index 7e21882b9244..7dce142f7246 100644 --- a/sgl-kernel/python/sgl_kernel/cpu.py +++ b/sgl-kernel/python/sgl_kernel/cpu.py @@ -203,7 +203,8 @@ def decode_attention( sm_scale, logit_cap=0.0, ): - sgl_kernel.common_ops.decode_attention_cpu( + # sgl_kernel.common_ops.decode_attention_cpu( + torch.ops.sgl_kernel_cpu.decode_attention_cpu.default( q, k_buffer, v_buffer, @@ -358,7 +359,8 @@ def weight_packed_linear( bias, is_vnni=True, ): - return sgl_kernel.common_ops.weight_packed_linear( + # return sgl_kernel.common_ops.weight_packed_linear( + return torch.ops.sgl_kernel_cpu.weight_packed_linear.default( x, weight, bias, @@ -366,6 +368,11 @@ def weight_packed_linear( ) +@torch.library.register_fake("sgl_kernel_cpu::weight_packed_linear") +def _(x, weight, bias, is_vnni): + return x.new_empty(x.shape[0], weight.shape[0]) + + def grouped_topk( hidden_states, router_logits, @@ -410,7 +417,8 @@ def fused_add_rmsnorm( weight, eps, ): - sgl_kernel.common_ops.fused_add_rmsnorm_cpu( + # sgl_kernel.common_ops.fused_add_rmsnorm_cpu( + torch.ops.sgl_kernel_cpu.fused_add_rmsnorm_cpu.default( input, residual, weight, @@ -423,12 +431,17 @@ def rmsnorm( weight, eps, ): - return sgl_kernel.common_ops.rmsnorm_cpu( + # return sgl_kernel.common_ops.rmsnorm_cpu( + return torch.ops.sgl_kernel_cpu.rmsnorm_cpu.default( input, weight, eps, ) +@torch.library.register_fake("sgl_kernel_cpu::rmsnorm_cpu") +def _(input, weight, eps): + return torch.empty_like(input) + def int8_scaled_mm( mat1, @@ -470,7 +483,8 @@ def fp8_scaled_mm( out_dtype, is_vnni=True, ): - return sgl_kernel.common_ops.fp8_scaled_mm_cpu( + # return sgl_kernel.common_ops.fp8_scaled_mm_cpu( + return torch.ops.sgl_kernel_cpu.fp8_scaled_mm_cpu.default( mat1, mat2, scales2, block_size, bias, out_dtype, is_vnni ) @@ -481,19 +495,31 @@ def rotary_position_embedding( k_pe, t_emb_pos, ): - return sgl_kernel.common_ops.rotary_position_embedding_cpu( + # return sgl_kernel.common_ops.rotary_position_embedding_cpu( + return torch.ops.sgl_kernel_cpu.rotary_position_embedding_cpu.default( t_pos, q_pe, k_pe, t_emb_pos, ) +@torch.library.register_fake("sgl_kernel_cpu::rotary_position_embedding_cpu") +def _(t_pos, q_pe, k_pe, t_emb_pos): + return torch.empty_like(q_pe), torch.empty_like(k_pe) + def silu_and_mul( input, ): - return sgl_kernel.common_ops.silu_and_mul_cpu(input) + # return sgl_kernel.common_ops.silu_and_mul_cpu(input) + return torch.ops.sgl_kernel_cpu.silu_and_mul_cpu.default(input) + + +@torch.library.register_fake("sgl_kernel_cpu::silu_and_mul_cpu") +def _(input): + return input.new_empty(input.shape[0], input.shape[1] // 2) def bmm(out, mat1, mat2, is_vnni=True, scale=None): - return sgl_kernel.common_ops.bmm_cpu(out, mat1, mat2, is_vnni, scale) + # return sgl_kernel.common_ops.bmm_cpu(out, mat1, mat2, is_vnni, scale) + return torch.ops.sgl_kernel_cpu.bmm_cpu.default(out, mat1, mat2, is_vnni, scale) From 2497683636c17e836e288c2697d1fda312a6bcbe Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 7 May 2025 08:38:55 +0000 Subject: [PATCH 3/6] fix wrong custom op mutate args --- .../srt/model_executor/cpu_compile_runner.py | 56 +------------------ sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 4 +- 2 files changed, 4 insertions(+), 56 deletions(-) diff --git a/python/sglang/srt/model_executor/cpu_compile_runner.py b/python/sglang/srt/model_executor/cpu_compile_runner.py index e11116041770..5d4183a69c0e 100644 --- a/python/sglang/srt/model_executor/cpu_compile_runner.py +++ b/python/sglang/srt/model_executor/cpu_compile_runner.py @@ -296,8 +296,6 @@ def capture(self): # self.output_buffers[bs] = output_buffers def capture_one_batch_size(self, bs: int, forward: Callable): - # graph = torch.cuda.CUDAGraph() - # stream = self.stream num_tokens = bs * self.num_tokens_per_bs # Graph inputs @@ -349,20 +347,11 @@ def capture_one_batch_size(self, bs: int, forward: Callable): # Attention backend self.model_runner.attn_backend.init_forward_metadata(forward_batch) - # self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( - # bs, - # num_tokens, - # req_pool_indices, - # seq_lens, - # encoder_lens, - # forward_batch.forward_mode, - # forward_batch.spec_info, - # ) # trigger torch.compile() for _ in range(2): self.model_runner.tp_group.barrier() - forward(input_ids, forward_batch.positions, forward_batch) + forward(input_ids, positions, forward_batch) def recapture_if_needed(self, forward_batch: ForwardBatch): # If the capture_hidden_mode changes, we need to recapture the graph @@ -397,60 +386,19 @@ def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = Fal index = bisect.bisect_left(self.compile_bs, raw_bs) bs = self.compile_bs[index] if bs != raw_bs: + raise self.seq_lens.fill_(1) self.out_cache_loc.zero_() - # NOTE: this is only for CUDA graph (copy to replay buffers) - # we can remove this. - # Common inputs - self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) - self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) - self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) - self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) - self.positions[:raw_num_token].copy_(forward_batch.positions) - if forward_batch.decode_seq_lens_cpu is not None: - if bs != raw_bs: - self.seq_lens_cpu.fill_(1) - self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu) - - if self.is_encoder_decoder: - self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) - if forward_batch.mrope_positions is not None: - self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) - - if hasattr(forward_batch.spec_info, "hidden_states"): - self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states - # Attention backend self.model_runner.attn_backend.init_forward_metadata(forward_batch) - # self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( - # bs, - # self.req_pool_indices, - # self.seq_lens, - # forward_batch.seq_lens_sum + (bs - raw_bs), - # self.encoder_lens, - # forward_batch.forward_mode, - # forward_batch.spec_info, - # seq_lens_cpu=self.seq_lens_cpu, - # ) # Replay - # self.graphs[bs].replay() - # next_token_logits, hidden_states = self.output_buffers[bs] logits_output = self.compiled_forwards[bs]( forward_batch.input_ids, forward_batch.positions, forward_batch, ) - next_token_logits = logits_output.next_token_logits - hidden_states = logits_output.hidden_states - - logits_output = LogitsProcessorOutput( - next_token_logits=next_token_logits[:raw_num_token], - hidden_states=( - hidden_states[:raw_num_token] if hidden_states is not None else None - ), - ) return logits_output def get_spec_info(self, num_tokens: int): diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index 300ca51d88be..7ea4a8f0e142 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -290,7 +290,7 @@ TORCH_LIBRARY(sgl_kernel_cpu, m) { m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor"); IMPL_CPU(rmsnorm_cpu); - m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor(a!) residual, Tensor weight, float eps) -> ()"); + m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor(b!) residual, Tensor weight, float eps) -> ()"); IMPL_CPU(fused_add_rmsnorm_cpu); m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor"); @@ -303,7 +303,7 @@ TORCH_LIBRARY(sgl_kernel_cpu, m) { IMPL_CPU(bmm_cpu); m.def("decode_attention_cpu(Tensor query, Tensor(a!) k_cache, Tensor(b!) v_cache, Tensor(c!) output," - "Tensor key, Tensor value, Tensor loc, Tensor attn_logits," + "Tensor key, Tensor value, Tensor loc, Tensor(d!) attn_logits," "Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens," "float sm_scale, float logit_cap) -> ()"); IMPL_CPU(decode_attention_cpu); From 064c40abddd24b91dc0db48eb6027a8b6f6fc9bd Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 7 May 2025 10:05:37 +0000 Subject: [PATCH 4/6] use c10d::ReduceOp --- sgl-kernel/csrc/cpu/interface.cpp | 9 ++------- sgl-kernel/csrc/cpu/models/deepseek.cpp | 8 ++++---- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 6 +++--- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/sgl-kernel/csrc/cpu/interface.cpp b/sgl-kernel/csrc/cpu/interface.cpp index 9d2e7c8d2798..3783dd3654c9 100644 --- a/sgl-kernel/csrc/cpu/interface.cpp +++ b/sgl-kernel/csrc/cpu/interface.cpp @@ -50,16 +50,11 @@ void initialize(int size, int rank) { void shm_allreduce( torch::Tensor& data, c10::intrusive_ptr process_group, - py::object op) { + c10d::ReduceOp op) { RECORD_FUNCTION( "sgl-kernel::shm_allreduce", std::vector({data})); - static py::object ReduceOp = - py::module_::import("torch.distributed").attr("ReduceOp"); - static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value")); - TORCH_CHECK( - py::int_(op.attr("value")) == ReduceOpSum, - "Only torch.distributed.ReduceOp.SUM is supported"); + TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported"); auto numel = data.numel(); diff --git a/sgl-kernel/csrc/cpu/models/deepseek.cpp b/sgl-kernel/csrc/cpu/models/deepseek.cpp index 2048df57ced4..302b908d97c1 100644 --- a/sgl-kernel/csrc/cpu/models/deepseek.cpp +++ b/sgl-kernel/csrc/cpu/models/deepseek.cpp @@ -43,7 +43,7 @@ extern at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tens std::vector block_size, const std::optional& bias, at::ScalarType out_dtype, bool is_vnni); -extern void shm_allreduce(at::Tensor& data, c10::intrusive_ptr process_group, py::object op); +extern void shm_allreduce(at::Tensor& data, c10::intrusive_ptr process_group, c10d::ReduceOp op); extern std::tuple grouped_topk_cpu( at::Tensor& hidden_states, @@ -101,7 +101,7 @@ at::Tensor row_parallel_linear_forward( int tp_size, int tp_rank, std::optional> process_group, - std::optional op, + std::optional op, bool use_int8_w8a8, bool use_fp8_w8a16, at::ScalarType out_dtype, @@ -181,7 +181,7 @@ at::Tensor forward_absorb_decode_fused_cpu( std::optional> block_size, // qkv_proj_with_rope std::optional& bmm_scale, // bmm std::optional> process_group, // o_proj - std::optional op, // o_proj + std::optional op, // o_proj std::optional& o_proj_scale, // o_proj std::optional> o_proj_block_size, // o_proj bool is_vnni // qkv_proj_with_rope, bmm, o_proj @@ -326,7 +326,7 @@ at::Tensor forward_moe_fused_cpu( std::optional& shared_expert_a1_scale, // shared_expert std::optional& shared_expert_a2_scale, // shared_expert std::optional> process_group, // all_reduce - std::optional op, // all_reduce + std::optional op, // all_reduce bool is_vnni // MoEGate, experts, shared_expert ) { RECORD_FUNCTION("sgl-kernel::forward_moe_fused_cpu", std::vector({ diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index 7ea4a8f0e142..6136cc41cf3e 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -91,7 +91,7 @@ at::Tensor forward_absorb_decode_fused_cpu( std::optional> block_size, // qkv_proj_with_rope std::optional& bmm_scale, // bmm std::optional> process_group, // o_proj - std::optional op, // o_proj + std::optional op, // o_proj std::optional& o_proj_scale, // o_proj std::optional> o_proj_block_size, // o_proj bool is_vnni // qkv_proj_with_rope, bmm, o_proj @@ -191,7 +191,7 @@ at::Tensor forward_moe_fused_cpu( std::optional& shared_expert_a1_scale, // shared_expert std::optional& shared_expert_a2_scale, // shared_expert std::optional> process_group, // all_reduce - std::optional op, // all_reduce + std::optional op, // all_reduce bool is_vnni // MoEGate, experts, shared_expert ); @@ -207,7 +207,7 @@ std::tuple qkv_proj_with_rope( at::Tensor& h void initialize(int size, int rank); // shared mmeory all_reduce -void shm_allreduce(at::Tensor& data, c10::intrusive_ptr process_group, py::object op); +void shm_allreduce(at::Tensor& data, c10::intrusive_ptr process_group, c10d::ReduceOp op); // shared memory all_gather at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr process_group, int dim); From de288779989d3170cb7a623ea35154fadcdd0e91 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 8 May 2025 10:09:28 +0000 Subject: [PATCH 5/6] torch.compile() for DeepSeekv2 --- python/sglang/srt/layers/moe/topk.py | 9 ++- python/sglang/srt/layers/rotary_embedding.py | 7 ++- python/sglang/srt/models/deepseek_v2.py | 7 ++- sgl-kernel/csrc/cpu/models/deepseek.cpp | 16 ++--- sgl-kernel/csrc/cpu/moe.cpp | 24 ++++---- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 25 +++++--- sgl-kernel/python/sgl_kernel/cpu.py | 65 +++++++++++++++++++- 7 files changed, 117 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 2c54b146ce03..3c1c639787b1 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -23,6 +23,11 @@ if cpu_has_amx_support(): import sgl_kernel.cpu + _has_amx = True + +else: + _has_amx = False + def fused_topk_native( hidden_states: torch.Tensor, @@ -182,7 +187,7 @@ def select_experts( assert num_expert_group is not None if correction_bias is None: device = hidden_states.device - if device == torch.device("cpu") and cpu_has_amx_support(): + if device == torch.device("cpu") and _has_amx: topk_weights, topk_ids = sgl_kernel.cpu.grouped_topk( hidden_states, router_logits, @@ -202,7 +207,7 @@ def select_experts( ) else: device = hidden_states.device - if device == torch.device("cpu") and cpu_has_amx_support(): + if device == torch.device("cpu") and _has_amx: topk_weights, topk_ids = sgl_kernel.cpu.biased_grouped_topk( hidden_states, router_logits, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index bb17541042d3..5f089a5da4b8 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -20,6 +20,11 @@ if cpu_has_amx_support(): import sgl_kernel.cpu + _has_amx = True + +else: + _has_amx = False + def _rotate_neox(x: torch.Tensor) -> torch.Tensor: x1 = x[..., : x.shape[-1] // 2] @@ -727,7 +732,7 @@ def forward( positions = torch.add(positions, offsets) if offsets is not None else positions # TODO: Add scenario of self.rotary_dim < self.head_size - if positions.device == torch.device("cpu") and cpu_has_amx_support(): + if positions.device == torch.device("cpu") and _has_amx: return sgl_kernel.cpu.rotary_position_embedding( positions, query, key, self.cos_sin_cache ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9ae5aed79680..c09be6f87f6a 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -272,6 +272,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self.shared_experts_down_proj, ) ) + return self.forward_normal(hidden_states) if has_shared_experts and use_intel_amx_backend: return self.forward_moe_fused_cpu(hidden_states) else: @@ -346,10 +347,12 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.gate_impl is None: self.gate_impl = self.gate.forward # router_logits: (num_tokens, n_experts) - router_logits = self.gate_impl(hidden_states) + # router_logits = self.gate_impl(hidden_states) + router_logits = self.gate(hidden_states) if self.experts_impl is None: self.experts_impl = self.experts.forward - fused_experts_out = self.experts_impl( + # fused_experts_out = self.experts_impl( + fused_experts_out = self.experts( hidden_states=hidden_states, router_logits=router_logits ) diff --git a/sgl-kernel/csrc/cpu/models/deepseek.cpp b/sgl-kernel/csrc/cpu/models/deepseek.cpp index 302b908d97c1..51d8a102510e 100644 --- a/sgl-kernel/csrc/cpu/models/deepseek.cpp +++ b/sgl-kernel/csrc/cpu/models/deepseek.cpp @@ -71,11 +71,11 @@ extern at::Tensor fused_experts_cpu( bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, - std::optional& w1_scale, - std::optional& w2_scale, + const std::optional& w1_scale, + const std::optional& w2_scale, std::optional> block_size, - std::optional& a1_scale, - std::optional& a2_scale, + const std::optional& a1_scale, + const std::optional& a2_scale, bool is_vnni); extern at::Tensor shared_expert_cpu( @@ -87,11 +87,11 @@ extern at::Tensor shared_expert_cpu( bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, - std::optional& w1_scale, - std::optional& w2_scale, + const std::optional& w1_scale, + const std::optional& w2_scale, std::optional> block_size, - std::optional& a1_scale, - std::optional& a2_scale, + const std::optional& a1_scale, + const std::optional& a2_scale, bool is_vnni); // This function implements the forward function of sglang/python/sglang/srt/layers/linear.py:RowParallelLinear diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp index 5f15d0f58c72..94c607c698b9 100644 --- a/sgl-kernel/csrc/cpu/moe.cpp +++ b/sgl-kernel/csrc/cpu/moe.cpp @@ -921,11 +921,11 @@ void shared_expert_kernel_impl( static inline void check_moe_scales( bool use_int8_w8a8, bool use_fp8_w8a16, - std::optional& w1_scale, - std::optional& w2_scale, + const std::optional& w1_scale, + const std::optional& w2_scale, std::optional> block_size, - std::optional& a1_scale, - std::optional& a2_scale) { + const std::optional& a1_scale, + const std::optional& a2_scale) { if (use_int8_w8a8) { TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8."); TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8."); @@ -966,11 +966,11 @@ at::Tensor fused_experts_cpu( bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, - std::optional& w1_scale, - std::optional& w2_scale, + const std::optional& w1_scale, + const std::optional& w2_scale, std::optional> block_size, - std::optional& a1_scale, - std::optional& a2_scale, + const std::optional& a1_scale, + const std::optional& a2_scale, bool is_vnni) { RECORD_FUNCTION("sgl-kernel::fused_experts_cpu", std::vector({hidden_states, w1, w2, topk_weights, topk_ids})); @@ -1194,11 +1194,11 @@ at::Tensor shared_expert_cpu( bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, - std::optional& w1_scale, - std::optional& w2_scale, + const std::optional& w1_scale, + const std::optional& w2_scale, std::optional> block_size, - std::optional& a1_scale, - std::optional& a2_scale, + const std::optional& a1_scale, + const std::optional& a2_scale, bool is_vnni) { RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector({hidden_states, w1, w2})); diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index 6136cc41cf3e..e9e21d16365c 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -135,11 +135,11 @@ at::Tensor fused_experts_cpu( bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, - std::optional& w1_scale, - std::optional& w2_scale, + const std::optional& w1_scale, + const std::optional& w2_scale, std::optional> block_size, - std::optional& a1_scale, - std::optional& a2_scale, + const std::optional& a1_scale, + const std::optional& a2_scale, bool is_vnni); at::Tensor shared_expert_cpu( @@ -151,11 +151,11 @@ at::Tensor shared_expert_cpu( bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, - std::optional& w1_scale, - std::optional& w2_scale, + const std::optional& w1_scale, + const std::optional& w2_scale, std::optional> block_size, - std::optional& a1_scale, - std::optional& a2_scale, + const std::optional& a1_scale, + const std::optional& a2_scale, bool is_vnni); at::Tensor forward_moe_fused_cpu( @@ -310,4 +310,13 @@ TORCH_LIBRARY(sgl_kernel_cpu, m) { m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)"); IMPL_CPU(rotary_position_embedding_cpu); + + m.def("grouped_topk_cpu(Tensor hidden_states, Tensor gating_input, int topk, bool renormalize, int num_expert_group, int topk_group) -> (Tensor, Tensor)"); + IMPL_CPU(grouped_topk_cpu); + + m.def("fused_experts_cpu(Tensor(a!) hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor"); + IMPL_CPU(fused_experts_cpu); + + m.def("shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor"); + IMPL_CPU(shared_expert_cpu); } diff --git a/sgl-kernel/python/sgl_kernel/cpu.py b/sgl-kernel/python/sgl_kernel/cpu.py index 7dce142f7246..041357635479 100644 --- a/sgl-kernel/python/sgl_kernel/cpu.py +++ b/sgl-kernel/python/sgl_kernel/cpu.py @@ -18,7 +18,8 @@ def fused_experts( a2_scale=None, is_vnni=True, ): - return sgl_kernel.common_ops.fused_experts_cpu( + # return sgl_kernel.common_ops.fused_experts_cpu( + return torch.ops.sgl_kernel_cpu.fused_experts_cpu.default( x, w13_weight, w2_weight, @@ -36,6 +37,26 @@ def fused_experts( ) +@torch.library.register_fake("sgl_kernel_cpu::fused_experts_cpu") +def _( + x, + w13_weight, + w2_weight, + topk_weights, + topk_ids, + inplace, + use_int8_w8a8, + use_fp8_w8a16, + w1_scale, + w2_scale, + block_size, + a1_scale, + a2_scale, + is_vnni, +): + return torch.empty_like(x) + + def shared_expert( hidden_states, w1, @@ -52,7 +73,8 @@ def shared_expert( a2_scale=None, is_vnni=True, ): - return sgl_kernel.common_ops.shared_expert_cpu( + # return sgl_kernel.common_ops.shared_expert_cpu( + return torch.ops.sgl_kernel_cpu.shared_expert_cpu.default( hidden_states, w1, w2, @@ -70,6 +92,26 @@ def shared_expert( ) +@torch.library.register_fake("sgl_kernel_cpu::shared_expert_cpu") +def _( + hidden_states, + w1, + w2, + fused_experts_out, + routed_scaling_factor, + inplace, + use_int8_w8a8, + use_fp8_w8a16, + w1_scale, + w2_scale, + block_size, + a1_scale, + a2_scale, + is_vnni, +): + return torch.empty_like(hidden_states) + + def forward_moe_fused( hidden_states, MoEGate_weight, @@ -381,7 +423,8 @@ def grouped_topk( num_expert_group, topk_group, ): - return sgl_kernel.common_ops.grouped_topk_cpu( + # return sgl_kernel.common_ops.grouped_topk_cpu( + return torch.ops.sgl_kernel_cpu.grouped_topk_cpu.default( hidden_states, router_logits, top_k, @@ -391,6 +434,22 @@ def grouped_topk( ) +@torch.library.register_fake("sgl_kernel_cpu::grouped_topk_cpu") +def _( + hidden_states, + router_logits, + top_k, + renormalize, + num_expert_group, + topk_group, +): + shape = (hidden_states.shape[0], top_k) + device = hidden_states.device + topk_weights = torch.empty(shape, device=device, dtype=torch.float32) + topk_ids = torch.empty(shape, device=device, dtype=torch.int) + return topk_weights, topk_ids + + def biased_grouped_topk( hidden_states, router_logits, From 630db4d6198e71fcbb1fcb8d955a2bf6df434ba3 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 8 May 2025 11:58:08 +0000 Subject: [PATCH 6/6] support tensor-parallel --- .../sglang/srt/distributed/communication_op.py | 14 -------------- .../sglang/srt/distributed/parallel_state.py | 15 ++++++++++++--- .../srt/model_executor/cpu_compile_runner.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 5 +++++ sgl-kernel/csrc/cpu/interface.cpp | 11 +++++++---- sgl-kernel/csrc/cpu/models/deepseek.cpp | 14 +++++++------- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 18 ++++++++++++------ 7 files changed, 44 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py index 98e8e6f87e09..95600edfb410 100644 --- a/python/sglang/srt/distributed/communication_op.py +++ b/python/sglang/srt/distributed/communication_op.py @@ -10,15 +10,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group.""" - if input_.is_cpu: - from sglang.srt.distributed import get_tp_group - - shm_comm_op = get_tp_group().shm_comm_op - shm_comm_op.shm_allreduce( - input_, get_tp_group().device_group, torch.distributed.ReduceOp.SUM - ) - return input_ - return get_tp_group().all_reduce(input_) @@ -26,11 +17,6 @@ def tensor_model_parallel_all_gather( input_: torch.Tensor, dim: int = -1 ) -> torch.Tensor: """All-gather the input tensor across model parallel group.""" - if input_.is_cpu: - from sglang.srt.distributed import get_tp_group - - shm_comm_op = get_tp_group().shm_comm_op - return shm_comm_op.shm_allgather(input_, get_tp_group().device_group, dim) return get_tp_group().all_gather(input_, dim) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 5c0442ec38a7..8c4eb4e646b3 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -396,9 +396,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: return input_ if input_.is_cpu: - import intel_extension_for_pytorch as ipex - - ipex.distributed.all_reduce(input_, group=self.device_group) + torch.ops.sgl_kernel_cpu.shm_allreduce( + input_, + self.device_group.group_name, + "sum", + ) return input_ if not supports_custom_op(): @@ -464,6 +466,13 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: -input_.dim() <= dim < input_.dim() ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if input_.is_cpu: + return torch.ops.sgl_kernel_cpu.shm_allgather( + input_, + self.device_group.group_name, + dim, + ) + # For HPUs, use HPU communicator. hpu_comm = self.hpu_communicator if hpu_comm is not None and not hpu_comm.disabled: diff --git a/python/sglang/srt/model_executor/cpu_compile_runner.py b/python/sglang/srt/model_executor/cpu_compile_runner.py index 5d4183a69c0e..132b0c319111 100644 --- a/python/sglang/srt/model_executor/cpu_compile_runner.py +++ b/python/sglang/srt/model_executor/cpu_compile_runner.py @@ -386,7 +386,7 @@ def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = Fal index = bisect.bisect_left(self.compile_bs, raw_bs) bs = self.compile_bs[index] if bs != raw_bs: - raise + # raise self.seq_lens.fill_(1) self.out_cache_loc.zero_() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 97ad4c0c1139..7bab700607d7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -324,6 +324,11 @@ def init_torch_distributed(self): os.environ["LOCAL_SIZE"] = str(self.tp_size) shm_comm_op.initialize(self.tp_size, self.tp_rank) + # we have to register fake here since output shape depends on tp_size + @torch.library.register_fake("sgl_kernel_cpu::shm_allgather") + def _(data, group_name, dim): + return torch.cat([data] * self.tp_size, dim=dim) + # Only initialize the distributed environment on the target model worker. init_distributed_environment( backend=backend, diff --git a/sgl-kernel/csrc/cpu/interface.cpp b/sgl-kernel/csrc/cpu/interface.cpp index 3783dd3654c9..2f52243d4e0f 100644 --- a/sgl-kernel/csrc/cpu/interface.cpp +++ b/sgl-kernel/csrc/cpu/interface.cpp @@ -1,5 +1,6 @@ #include #include +#include #include "shm.h" @@ -49,12 +50,12 @@ void initialize(int size, int rank) { void shm_allreduce( torch::Tensor& data, - c10::intrusive_ptr process_group, - c10d::ReduceOp op) { + std::string group_name, + std::string op) { RECORD_FUNCTION( "sgl-kernel::shm_allreduce", std::vector({data})); - TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported"); + TORCH_CHECK(op == "sum", "Only torch.distributed.ReduceOp.SUM is supported"); auto numel = data.numel(); @@ -75,6 +76,7 @@ void shm_allreduce( if (data_type_fallback || !all_ranks_local_p) { // Fallback to torch distributed allreduce std::vector tensors = {data}; + auto process_group = c10d::resolve_process_group(group_name); process_group->allreduce(tensors)->wait(); } else { all_reduce_outer_loop(data, numel, data_size); @@ -83,7 +85,7 @@ void shm_allreduce( return; } -torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr process_group, int dim) { +torch::Tensor shm_allgather(torch::Tensor& data, std::string group_name, int64_t dim) { RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector({data})); auto numel = data.numel(); @@ -107,6 +109,7 @@ torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr> output_tensors(1); + auto process_group = c10d::resolve_process_group(group_name); auto world_size = process_group->getSize(); for (int i = 0; i < world_size; i++) { output_tensors[0].push_back(torch::empty_like(data)); diff --git a/sgl-kernel/csrc/cpu/models/deepseek.cpp b/sgl-kernel/csrc/cpu/models/deepseek.cpp index 51d8a102510e..8182b2d5ec80 100644 --- a/sgl-kernel/csrc/cpu/models/deepseek.cpp +++ b/sgl-kernel/csrc/cpu/models/deepseek.cpp @@ -43,7 +43,7 @@ extern at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tens std::vector block_size, const std::optional& bias, at::ScalarType out_dtype, bool is_vnni); -extern void shm_allreduce(at::Tensor& data, c10::intrusive_ptr process_group, c10d::ReduceOp op); +extern void shm_allreduce(at::Tensor& data, std::string group_name, std::string op); extern std::tuple grouped_topk_cpu( at::Tensor& hidden_states, @@ -100,8 +100,8 @@ at::Tensor row_parallel_linear_forward( std::optional& bias, int tp_size, int tp_rank, - std::optional> process_group, - std::optional op, + std::optional process_group, + std::optional op, bool use_int8_w8a8, bool use_fp8_w8a16, at::ScalarType out_dtype, @@ -180,8 +180,8 @@ at::Tensor forward_absorb_decode_fused_cpu( std::optional& kv_a_proj_scale, // qkv_proj_with_rope std::optional> block_size, // qkv_proj_with_rope std::optional& bmm_scale, // bmm - std::optional> process_group, // o_proj - std::optional op, // o_proj + std::optional process_group, // o_proj + std::optional op, // o_proj std::optional& o_proj_scale, // o_proj std::optional> o_proj_block_size, // o_proj bool is_vnni // qkv_proj_with_rope, bmm, o_proj @@ -325,8 +325,8 @@ at::Tensor forward_moe_fused_cpu( std::optional> shared_expert_block_size, // shared_expert std::optional& shared_expert_a1_scale, // shared_expert std::optional& shared_expert_a2_scale, // shared_expert - std::optional> process_group, // all_reduce - std::optional op, // all_reduce + std::optional process_group, // all_reduce + std::optional op, // all_reduce bool is_vnni // MoEGate, experts, shared_expert ) { RECORD_FUNCTION("sgl-kernel::forward_moe_fused_cpu", std::vector({ diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index e9e21d16365c..246857355652 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -90,8 +90,8 @@ at::Tensor forward_absorb_decode_fused_cpu( std::optional& kv_a_proj_scale, // qkv_proj_with_rope std::optional> block_size, // qkv_proj_with_rope std::optional& bmm_scale, // bmm - std::optional> process_group, // o_proj - std::optional op, // o_proj + std::optional process_group, // o_proj + std::optional op, // o_proj std::optional& o_proj_scale, // o_proj std::optional> o_proj_block_size, // o_proj bool is_vnni // qkv_proj_with_rope, bmm, o_proj @@ -190,8 +190,8 @@ at::Tensor forward_moe_fused_cpu( std::optional> shared_expert_block_size, // shared_expert std::optional& shared_expert_a1_scale, // shared_expert std::optional& shared_expert_a2_scale, // shared_expert - std::optional> process_group, // all_reduce - std::optional op, // all_reduce + std::optional process_group, // all_reduce + std::optional op, // all_reduce bool is_vnni // MoEGate, experts, shared_expert ); @@ -207,10 +207,10 @@ std::tuple qkv_proj_with_rope( at::Tensor& h void initialize(int size, int rank); // shared mmeory all_reduce -void shm_allreduce(at::Tensor& data, c10::intrusive_ptr process_group, c10d::ReduceOp op); +void shm_allreduce(at::Tensor& data, std::string group_name, std::string op); // shared memory all_gather -at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr process_group, int dim); +at::Tensor shm_allgather(at::Tensor& data, std::string group_name, int64_t dim); // rope std::tuple rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, @@ -319,4 +319,10 @@ TORCH_LIBRARY(sgl_kernel_cpu, m) { m.def("shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor"); IMPL_CPU(shared_expert_cpu); + + m.def("shm_allreduce(Tensor(a!) data, str group_name, str op) -> ()"); + IMPL_CPU(shm_allreduce); + + m.def("shm_allgather(Tensor data, str group_name, int dim) -> Tensor"); + IMPL_CPU(shm_allgather); }