Using kernelize isn't yielding any speed-ups #201

@pramodith

Hey team, I'm trying to use kernelize on a Qwen3 model and compare the runtime of the kernelized model against the raw model implementation.

However, I'm not seeing any speedup when I compare the two. I also don't see any logs indicating that the kernelized model is using a kernel from the Hub. Am I doing something wrong here? I'm running this on Colab.

# Cell 1: setup and load model
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.profiler import profile, ProfilerActivity

model_id = "Qwen/Qwen3-0.6B"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    attn_implementation="sdpa"
).to(device).eval()


batch_size, seq_len = 4, 2048
input_ids = torch.randint(
    0, tokenizer.vocab_size, (batch_size, seq_len), device=device
)

# Cell 2: profiling helper (torch, profile, ProfilerActivity already imported in Cell 1)
import time
from torch.profiler import schedule

def profile_model(model, label, iters=20, wait=2, warmup=3, active=10):
    # Uses the global input_ids from Cell 1.
    # The profiler schedule needs at least wait + warmup + active steps.
    total_steps = wait + warmup + active
    if iters < total_steps:
        iters = total_steps

    sched = schedule(wait=wait, warmup=warmup, active=active, repeat=1)
    use_cuda = torch.cuda.is_available()

    # Per-iteration forward-pass latencies, in milliseconds.
    latencies = []

    with torch.no_grad():
        with profile(
            activities=[ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if use_cuda else []),
            schedule=sched,
            profile_memory=True,
        ) as prof:
            for _ in range(iters):
                if use_cuda:
                    # CUDA events time the GPU stream itself; a host-side timer
                    # would mostly measure the asynchronous kernel launches.
                    start_event = torch.cuda.Event(enable_timing=True)
                    end_event = torch.cuda.Event(enable_timing=True)

                    start_event.record()
                    model(input_ids)
                    end_event.record()

                    torch.cuda.synchronize()
                    latencies.append(start_event.elapsed_time(end_event))  # ms
                else:
                    # CPU fallback: torch.cuda.Event would raise without a GPU.
                    start = time.perf_counter()
                    model(input_ids)
                    latencies.append((time.perf_counter() - start) * 1e3)

                prof.step()

    # Average only the active window; the wait and warmup steps are discarded.
    active_latencies = latencies[wait + warmup : wait + warmup + active]
    avg_wall_time_ms = sum(active_latencies) / len(active_latencies)

    total_avg = prof.key_averages().total_average()

    return {
        "mode": label,
        "avg_wall_time_ms": avg_wall_time_ms,
        "cpu_time_total_ms": total_avg.cpu_time_total / 1e3,
        "cuda_time_total_ms": total_avg.device_time_total / 1e3 if use_cuda else None,
        "cuda_memory_mb": total_avg.device_memory_usage / (1024**2) if use_cuda else None,
    }
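
The benchmarking cell for the first two rows of the table below isn't shown above; a minimal reconstruction, assuming a plain torch.compile call with default options:

# Reconstructed driver for the eager and torch_compile rows
# (plain torch.compile with default options is an assumption).
results = []
results.append(profile_model(model, "eager"))

compiled_model = torch.compile(model)
results.append(profile_model(compiled_model, "torch_compile"))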

# Cell 3: kernelize the model
import logging
from kernels import kernelize, Mode

# force=True so this isn't a no-op in notebooks where the root logger
# already has handlers, a common reason for missing log output on Colab.
logging.basicConfig(level=logging.INFO, force=True)

kernelized_model = kernelize(model, mode=Mode.INFERENCE, use_fallback=False)
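
The third row and the printed table come from a final cell along these lines (a sketch; pandas was imported in Cell 1, and the index column in the table is just the DataFrame's default index):

# Cell 4: profile the kernelized model and print the comparison
results.append(profile_model(kernelized_model, "huggingface_kernels"))

print(pd.DataFrame(results)[["mode", "avg_wall_time_ms"]])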

Results:

|index|mode|avg_wall_time_ms|
|---|---|---|
|0|eager|2917.5516276041667|
|1|torch_compile|963.7430623372396|
|2|huggingface_kernels|2914.344921875|
