48 changes: 45 additions & 3 deletions llmsql/inference/inference_vllm.py
@@ -23,6 +23,7 @@
max_new_tokens=256,
temperature=0.7,
tensor_parallel_size=1,
lora_config={"lora_path": "path/to/lora"}
)

Notes
@@ -47,6 +48,7 @@
from dotenv import load_dotenv
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

from llmsql.config.config import (
DEFAULT_LLMSQL_VERSION,
@@ -74,6 +76,8 @@ def inference_vllm(
hf_token: str | None = None,
llm_kwargs: dict[str, Any] | None = None,
use_chat_template: bool = True,
# === LoRA Parameters ===
lora_config: dict[str, Any] | None = None, # new optional dict
# === Generation Parameters ===
max_new_tokens: int = 256,
temperature: float = 1.0,
@@ -105,6 +109,16 @@
'trust_remote_code' are handled separately and will
override values here.

lora_config: Optional dict with LoRA parameters:
- lora_path: Path to the pretrained LoRA adapter (required if enable_lora)
- lora_name: Logical name for the LoRA adapter
- lora_int_id: Integer adapter ID passed to vLLM (defaults to 1)
- max_lora_rank: Maximum LoRA rank supported by vLLM
LoRA usage rules:
- If `lora_config` is provided, `enable_lora` must be True in `llm_kwargs`.
- If `enable_lora` is True, a valid `lora_config` must be provided.
- In either mismatched case, a ValueError is raised to prevent an inconsistent configuration.

# Generation:
max_new_tokens: Maximum tokens to generate per sequence.
temperature: Sampling temperature (0.0 = greedy).
@@ -126,6 +140,8 @@
the first N samples. If a float between 0.0 and 1.0, evaluates the
first X*100% of samples. If None, evaluates all samples (default).

Returns:
List of dicts containing `question_id` and generated `completion`.
"""
@@ -167,19 +183,41 @@
)
questions = questions[:limit]

# --- Validate LoRA usage ---
enable_lora = llm_kwargs.get("enable_lora", False)
if lora_config is not None and not enable_lora:
raise ValueError(
"LoRA config provided but `enable_lora` is not True in llm_kwargs."
)
if enable_lora and lora_config is None:
raise ValueError("`enable_lora` is True but no `lora_config` was provided.")

# --- init model ---
llm_init_args = {
"model": model_name,
"tokenizer": model_name,
"tensor_parallel_size": tensor_parallel_size,
"trust_remote_code": trust_remote_code,
**llm_kwargs, # User kwargs come first, but explicit params above will override
**llm_kwargs, # user overrides
}

log.info(f"Loading vLLM model '{model_name}' (tp={tensor_parallel_size})...")

llm = LLM(**llm_init_args)

# --- LoRA request ---
lora_request = None
if enable_lora and lora_config is not None:
log.info(f"Loading LoRA adapter from {lora_config['lora_path']}")
# vLLM's LoRARequest needs an integer adapter ID; the scaling factor is
# read from the adapter's own config rather than passed per request.
lora_request = LoRARequest(
lora_name=lora_config.get("lora_name", "adapter"),
lora_int_id=lora_config.get("lora_int_id", 1),
lora_path=lora_config["lora_path"],
)

tokenizer = llm.get_tokenizer()
if use_chat_template:
use_chat_template = getattr(tokenizer, "chat_template", None) # type: ignore
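For reference, the same mechanism in a self-contained vLLM snippet, independent of this module; the base model, adapter path, and rank are placeholders and must match the adapter actually being served:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Standalone sketch of per-request LoRA in vLLM (placeholder model and adapter paths).
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_lora_rank=64)
lora_request = LoRARequest("sql_adapter", 1, "/path/to/adapter")  # name, integer id, path
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["SELECT name FROM users;"], params, lora_request=lora_request)
print(outputs[0].outputs[0].text)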
@@ -228,7 +266,11 @@

prompts.append(final_prompt)

outputs = llm.generate(prompts, sampling_params)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=lora_request,
)

batch_results: list[dict[str, str]] = []
for q, out in zip(batch, outputs, strict=False):
2 changes: 1 addition & 1 deletion tests/inference/test_limit_argument.py
@@ -46,7 +46,7 @@ def _patch_common_vllm(monkeypatch, tmp_path):
)

fake_llm = MagicMock()
fake_llm.generate.side_effect = lambda prompts, _params: [
fake_llm.generate.side_effect = lambda prompts, *a, **kw: [
MagicMock(outputs=[MagicMock(text=f"SELECT {i}")]) for i in range(len(prompts))
]
monkeypatch.setattr(vllm_mod, "LLM", lambda *a, **kw: fake_llm)
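The stub's signature is widened because the production code now calls `llm.generate(prompts, sampling_params, lora_request=...)`; a minimal sketch with plain `unittest.mock` (not code from this repository) of why the old two-argument lambda would break:

from unittest.mock import MagicMock

fake_llm = MagicMock()
# Old stub: exactly two positional parameters, so an extra keyword is rejected.
fake_llm.generate.side_effect = lambda prompts, _params: len(prompts)
try:
    fake_llm.generate(["q"], None, lora_request=None)
except TypeError as err:
    print(err)  # got an unexpected keyword argument 'lora_request'

# New stub: tolerant of extra positional and keyword arguments such as lora_request.
fake_llm.generate.side_effect = lambda prompts, *a, **kw: len(prompts)
print(fake_llm.generate(["q"], None, lora_request=None))  # 1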