
RFC: Use torch.compile to reduce Python overhead #73

Draft
gau-nernst wants to merge 6 commits into mingfeima:cpu_opt_ww11 from gau-nernst:cpu_compile

Conversation


@gau-nernst gau-nernst commented May 9, 2025

Motivation

torch.compile() is usually thought of as a way to fuse ops and generate efficient kernels. In this PR, however, I propose using torch.compile() in a slightly unusual way: to reduce Python overhead.

How does it work?

torch.compile will "flatten" / "inline" kernel calls, eliminating the Python-level indirection in the SGLang runtime.

Without TP (single NUMA machine)

(Profile trace screenshots, before and after, omitted.)

| Qwen2.5-7B-Instruct | Prefill tok/s | Decode tok/s |
| --- | --- | --- |
| Before | 106.12 | 9.73 |
| After | 106.38 | 10.65 |

We can see the generated code with `TORCH_LOGS=output_code`:

(Screenshot of the generated code omitted.)

With TP=4 (multi-NUMA machine; I couldn't get a profile trace from this system, so only benchmark numbers here)

| Qwen2.5-7B-Instruct | Prefill tok/s | Decode tok/s |
| --- | --- | --- |
| Before | 1194.46 | 47.47 |
| After | 1192.55 | 53.65 |

Note:

  • Currently I only apply torch.compile() for decode, so prefill numbers should stay the same. I think it can be applied to prefill as well, though handling dynamic shapes might be tricky

Why is this useful?

  • I saw that the Intel team rewrote DeepSeek model code in C++ (deepseek.cpp) to eliminate the Python overhead.
  • To me, this seems like a perfect use case for a compiler like torch.compile() - translate Python code to C++ (note that I still use dedicated VNNI/INT8/FP8... kernels here. I don't expect torch.compile() to replace those kernels).
  • This would be more scalable to support other model architectures x quantization combinations.

Can we do even better?

As you can see in the trace above, there are still some sources of Python overhead:

  • Whenever we enter the torch.compile() code, there is quite a bit of overhead (torch/_dynamo/***)
  • Kernel calls are still done in Python

I played around with it a bit, and it seems like we can use AOTI for this: https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html. The basic idea:

  • Use torch.export to export the PyTorch program (captured by Dynamo)
  • Run this exported program in C++ via torch::inductor::AOTIModelPackageLoader -> we only need a thin Python wrapper around this
    • I tested torch._inductor.aoti_load_package(), and this still calls each kernel from Python, so it's still slower than using C++ API.

Side note: I think this can be an interesting idea for the CPU Inductor team - generate "glue" code that calls kernels in C++ instead of Python 😄

Update: I just found out about this: https://docs.pytorch.org/tutorials/prototype/inductor_cpp_wrapper_tutorial.html. Will re-run the benchmark with this option.

What's next?

I'm opening this PR to gather feedback first, since getting this feature merged would require quite a lot of changes and more extensive testing (this PR is just a proof of concept; it won't be merged as-is). If you are keen on this feature, we can proceed with the following gradual changes:

  1. Change all PyBind11 bindings to TORCH_LIBRARY registrations (+ register fakes). This is required for torch.compile to work.
  2. Add cpu_compile_runner.py. This follows closely the existing cuda_graph_runner.py
    • This will need more extensive testing, especially to have a strong guarantee to not recompile during runtime
  3. (As needed) Modify some model code to make sure it works with torch.compile fullgraph=True
    • We need fullgraph=True since there is quite a bit of overhead when entering torch.compile() code -> if there are graph breaks, this feature won't speed things up
  4. (Optional) See if this feature can replace deepseek.cpp (and other C++ functions that are there simply to reduce Python overhead)
    • It might be possible to remove the various tricks in the deepseek_v2.py file that avoid Python attribute lookups -> simpler code, easier to upstream.

@mingfeima mingfeima requested a review from chunyuan-w May 12, 2025 01:19
@mingfeima
Owner

Wow, thanks for this contribution! Really nice work! We will evaluate internally on our machines to see the results!


zou3519 commented May 12, 2025

Btw, regular torch.compile can use the C++ codegen via torch._inductor.config.cpp_wrapper = True, there's no need to use torch.export+AOTI if you don't want to: https://github.com/pytorch/pytorch/blob/daca611465c93ac6b8147e6b7070ce2b4254cfc5/torch/_inductor/config.py#L153-L155
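For reference, the toggle looks like this (a config fragment; the flag name comes from the linked config.py):

```python
import torch._inductor.config

# Emit C++ glue code around kernel calls instead of the default Python wrapper
torch._inductor.config.cpp_wrapper = True
```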

In general with export+AOTI you'd need to manage your compilations yourself vs torch.compile manages the compilations for you.

@gau-nernst
Author

@zou3519 Nice seeing you here! Yes I only found out about the C++ inductor wrapper a few days ago, after I opened this PR. I think it's the perfect use case for this 😄
https://docs.pytorch.org/tutorials/prototype/inductor_cpp_wrapper_tutorial.html

@mingfeima
Owner

Haha, small world

Right now DeepSeek R1 671B on our machine has a latency of roughly 60 ms, and about 10% of that comes from Python overhead (easily seen from the trace and logs). Previously we tried torch.compile but had no luck, so we ended up using a very large C++ function that covers almost the entire model.

We are evaluating this PR to see if final results are good enough, will come back with measured data soon.

@zou3519 thanks very much for the reminder!

@mingfeima
Owner

@chunyuan-w could you please help review this one? Also ask Mingxu to measure performance on our machines, both DDR5 and MRDIMM with All BS configs. Int8 would be enough, don't have to test fp8.

@mingfeima
Owner

@gau-nernst this would be enabled with --enable-torch-compile, right ?

@gau-nernst
Author

this would be enabled with --enable-torch-compile, right ?

Yes, correct.

I want to note that I only tested Qwen2.5-7B-Instruct on this PR. Other models might not work out-of-the-box, either because torch.compile() can't trace the model without further modifications, or because of missing custom op registrations (I have only registered some ops).

I didn't test DeepSeek because right now it calls into deepseek.cpp. I can remove the calls to deepseek.cpp and see how the torch.compile() solution compares with it.

I will re-do the benchmarks with torch._inductor.config.cpp_wrapper = True tomorrow.

@gau-nernst
Author

Inductor C++ wrapper requires this PR pytorch/pytorch#150143 (not included in 2.7 release) since some ops return None (e.g. decode attention). A workaround is perhaps to return a zero-sized tensor (non-ideal)

Using Inductor C++ wrapper has the same speed as eager. Inspecting TORCH_LOGS=output_code, I see that it's calling custom ops in Python? Using PyTorch 2.6. Is this expected @zou3519?

    {
        py::gil_scoped_acquire acquire;

        RAIIPyObject py_args_12(PyTuple_New(5));
        if (py_args_12.get() == NULL) {
            throw std::runtime_error("PyTuple_New py_args_12 failed");
        }
        PyTuple_SetItem(py_args_12, 0, PyUnicode_FromString("torch.ops.sgl_kernel_cpu.fused_add_rmsnorm_cpu.default"));
        PyTuple_SetItem(py_args_12, 1, PyCapsule_New(reinterpret_cast<void*>(buf47.get()), NULL, NULL));
        PyTuple_SetItem(py_args_12, 2, PyCapsule_New(reinterpret_cast<void*>(buf19.get()), NULL, NULL));
        PyTuple_SetItem(py_args_12, 3, PyCapsule_New(reinterpret_cast<void*>(arg24_1.get()), NULL, NULL));
        PyTuple_SetItem(py_args_12, 4, PyFloat_FromDouble(1e-06));

        // Call the custom op in Python
        RAIIPyObject py_buf48(PyObject_CallObject(custom_op_wrapper, py_args_12));
        if (py_buf48.get() == NULL) {
            if (PyErr_Occurred()) {
                return;
            }
            throw std::runtime_error("PyObject_CallObject torch.ops.sgl_kernel_cpu.fused_add_rmsnorm_cpu.default failed");
        }
    }

@chunyuan-w
Collaborator

@chunyuan-w could you please help review this one? Also ask Mingxu to measure performance on our machines, both DDR5 and MRDIMM with All BS configs. Int8 would be enough, don't have to test fp8.

Sure, I’ll check it out today.


zou3519 commented May 13, 2025

Inductor C++ wrapper requires this PR pytorch/pytorch#150143 (not included in 2.7 release) since some ops return None (e.g. decode attention). A workaround is perhaps to return a zero-sized tensor (non-ideal)

Using Inductor C++ wrapper has the same speed as eager. Inspecting TORCH_LOGS=output_code, I see that it's calling custom ops in Python? Using PyTorch 2.6. Is this expected @zou3519?

I was not aware that torch.compile with C++ wrapper just decides to call back into Python for the custom operators. Thoughts @desertfire @eellison ?

@eellison

If there is a C++-bound custom op, we should be able to avoid that, a la https://docs.pytorch.org/tutorials/advanced/torch_script_custom_ops.html. If the custom op is only bound in Python, then we will need to call back into Python. @zou3519 you would know better than I would the best, most recent way to bind an op in C++.


zou3519 commented May 13, 2025

cpp_wrapper and AOTI should be able to codegen a call to Dispatcher().get_operator("fused_add_rmsnorm_cpu").call(...) in C++ code even if the implementation is defined in Python. The implementation will call back into Python anyways, but at least this saves a roundtrip to Python (the code snippet above calls into Python to invoke the operator, which will go into C++, and then call back into the python implementation)

@gau-nernst
Author

@zou3519 @eellison I have done the custom op registration in C++ in sgl-kernel/csrc/cpu/torch_extension_cpu.cpp, and the fake registration in Python in sgl-kernel/python/sgl_kernel/cpu.py (you can see it in this PR)

#define IMPL_CPU(op) m.impl(#op, at::kCPU, &op);

TORCH_LIBRARY(sgl_kernel_cpu, m) {
  m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor(b!) residual, Tensor weight, float eps) -> ()");
  IMPL_CPU(fused_add_rmsnorm_cpu);
  // ...
}

@eellison

I filed this issue: pytorch/pytorch#153478.

@mingfeima
Owner

@CaoE do we have any update on this topic?


CaoE commented May 19, 2025

@mingfeima Yes, for DeepSeek, the current PR needs to register some ops with PyTorch to compile, and it also lacks handling of the SGLang custom types for int8/fp8. I have added these parts locally and am testing now. For ease of debugging, I have not enabled the TP setting yet; I will test whether it works with TP turned on later.

@mingfeima mingfeima self-requested a review May 19, 2025 02:56
@gau-nernst
Author

For DeepSeek, I don't expect torch.compile() to help, since right now it's going through deepseek.cpp. Therefore, I compared Qwen2.5 models in my PR description.


CaoE commented May 20, 2025

I found that, on a single process, compiled DeepSeek int8 runs successfully after handling the SGLang custom types. The overall performance improves by at most about 5%, and further testing is needed. I also found that different inputs trigger recompilation, and in multi-process mode compile seems to hit a deadlock, which may be related to the recompilation. The recompile and compile-hang problems need further work.

@benjaminglass1

@gau-nernst I just got pytorch/pytorch#154142 merged today, which should avoid calling custom ops through Python in 90% of cases. As you have time, feel free to test with that commit and see if your measured overhead gets resolved.

Specifically, the 10% of remaining uncovered situations include mostly things like input lists and strings (objects with length that are not tensors). Based on a quick scroll through your diff, some of your custom ops will be caught in that 10%, but hopefully a good number will be dispatched properly now.
