46 changes: 46 additions & 0 deletions nemo_export/vllm_exporter.py
@@ -548,6 +548,11 @@ def ray_infer_fn(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
compute_logprob = inputs.pop("compute_logprob", False)
n_top_logprobs = inputs.pop("n_top_logprobs", 0)
echo = inputs.pop("echo", False)
apply_chat_template = inputs.pop("apply_chat_template", False)

# Apply chat template if requested (for Ray inference only)
if apply_chat_template:
prompts = [self.apply_chat_template(prompt) for prompt in prompts]

# Map HF-style parameters to vLLM parameters
if compute_logprob and n_top_logprobs > 0:
@@ -570,6 +575,47 @@ def ray_infer_fn(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

return output_dict

def apply_chat_template(self, messages, add_generation_prompt=True):
"""Apply the chat template to messages using the tokenizer.

This method uses the vLLM tokenizer's built-in apply_chat_template method
to format messages according to the model's expected chat format.

Args:
messages: List of message dictionaries with 'role' and 'content' keys,
or a JSON string representation of messages.
add_generation_prompt (bool): Whether to add the generation prompt. Defaults to True.

Returns:
str: The formatted prompt string.

Raises:
ValueError: If the tokenizer does not have a chat template.
"""
import json

# Handle JSON string input
if isinstance(messages, str):
messages = json.loads(messages)

tokenizer = self.model.get_tokenizer()

# Check if tokenizer has chat_template
if not hasattr(tokenizer, "chat_template") or tokenizer.chat_template is None:
raise ValueError(
"The tokenizer does not have a chat template defined. "
"If you would like to evaluate a chat model, ensure your model's tokenizer has a chat template."
)

# Use tokenizer's apply_chat_template method
formatted_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)

return formatted_prompt

def _dict_to_str(self, messages):
"""Serializes dict to str."""
import json
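Taken together, the new apply_chat_template method and the apply_chat_template flag in ray_infer_fn let Ray clients send OpenAI-style message lists instead of pre-formatted prompt strings. A minimal usage sketch, assuming exporter is a vLLMExporter whose model has already been exported and loaded (setup is omitted here):

    # Sketch only: `exporter` setup is assumed to have happened elsewhere.
    messages = [{"role": "user", "content": "Hello, how are you?"}]

    # Direct call: returns the chat-formatted prompt string.
    prompt = exporter.apply_chat_template(messages)

    # Or let ray_infer_fn format each prompt when the flag is set.
    result = exporter.ray_infer_fn(
        {
            "prompts": [messages],
            "max_tokens": 100,
            "apply_chat_template": True,  # defaults to False
        }
    )
    print(result["sentences"][0])

Plain string prompts pass through unchanged when the flag is left at its default of False.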
4 changes: 2 additions & 2 deletions tests/unit_tests/deploy/test_hf_ray_oai_format.py
@@ -580,8 +580,8 @@ def mock_hf_deployable_for_logprobs(self):
"input_ids": torch.tensor([[1, 2, 3, 4]]),
"attention_mask": torch.tensor([[1, 1, 1, 1]]),
}
-        mock_tokenizer.decode.side_effect = (
-            lambda ids: f"token_{ids[0] if isinstance(ids, list) and len(ids) > 0 else 'unknown'}"
+        mock_tokenizer.decode.side_effect = lambda ids: (
+            f"token_{ids[0] if isinstance(ids, list) and len(ids) > 0 else 'unknown'}"
         )
mock_tokenizer.eos_token = "</s>"
mock_tokenizer.pad_token = "</s>"
6 changes: 3 additions & 3 deletions tests/unit_tests/export/test_tensorrt_llm_hf.py
@@ -584,9 +584,9 @@ def test_tensorrt_llm_hf_export_copies_tokenizer_files():
patch("nemo_export.tensorrt_llm_hf.LLaMAForCausalLM.from_hugging_face", return_value=mock_model),
patch(
"glob.glob",
-                side_effect=lambda x: ["/tmp/hf_model/tokenizer.json"]
-                if "*.json" in x
-                else ["/tmp/hf_model/tokenizer.model"],
+                side_effect=lambda x: (
+                    ["/tmp/hf_model/tokenizer.json"] if "*.json" in x else ["/tmp/hf_model/tokenizer.model"]
+                ),
),
patch("shutil.copy"),
patch.object(trt_llm_hf, "_load"),
221 changes: 221 additions & 0 deletions tests/unit_tests/export/test_vllm_exporter_ray_infer.py
@@ -772,3 +772,224 @@ def test_post_process_logprobs_to_OAI_multiple_samples(exporter, mock_llm):
assert len(result["log_probs"]) == 2
assert result["log_probs"][0] == [-0.1, -0.2]
assert result["log_probs"][1] == [-0.3, -0.4]


class TestVLLMExporterApplyChatTemplate:
"""Tests for apply_chat_template functionality added to vLLMExporter."""

@pytest.fixture
def mock_tokenizer_with_chat_template(self):
"""Create a mock tokenizer with chat template support."""
tokenizer = MagicMock()
tokenizer.chat_template = "{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}"
tokenizer.apply_chat_template = MagicMock(return_value="<|begin_of_text|>user: Hello\nassistant:")
return tokenizer

@pytest.fixture
def mock_tokenizer_without_chat_template(self):
"""Create a mock tokenizer without chat template support."""
tokenizer = MagicMock()
tokenizer.chat_template = None
return tokenizer

@pytest.fixture
def exporter_with_chat_template(self, mock_tokenizer_with_chat_template):
"""Create vLLMExporter instance with chat template support."""
from nemo_export.vllm_exporter import vLLMExporter

exporter = vLLMExporter()
exporter.model = MagicMock()
exporter.model.get_tokenizer.return_value = mock_tokenizer_with_chat_template
return exporter

@pytest.fixture
def exporter_without_chat_template(self, mock_tokenizer_without_chat_template):
"""Create vLLMExporter instance without chat template support."""
from nemo_export.vllm_exporter import vLLMExporter

exporter = vLLMExporter()
exporter.model = MagicMock()
exporter.model.get_tokenizer.return_value = mock_tokenizer_without_chat_template
return exporter

@pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
@pytest.mark.run_only_on("GPU")
def test_apply_chat_template_with_messages_list(
self, exporter_with_chat_template, mock_tokenizer_with_chat_template
):
"""Test apply_chat_template with a list of message dictionaries."""
messages = [{"role": "user", "content": "Hello, how are you?"}]

result = exporter_with_chat_template.apply_chat_template(messages)

assert result == "<|begin_of_text|>user: Hello\nassistant:"
mock_tokenizer_with_chat_template.apply_chat_template.assert_called_once_with(
messages,
tokenize=False,
add_generation_prompt=True,
)

@pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
@pytest.mark.run_only_on("GPU")
def test_apply_chat_template_with_json_string(self, exporter_with_chat_template, mock_tokenizer_with_chat_template):
"""Test apply_chat_template with JSON string input."""
import json

messages = [{"role": "user", "content": "Hello"}]
messages_json = json.dumps(messages)

result = exporter_with_chat_template.apply_chat_template(messages_json)

assert result == "<|begin_of_text|>user: Hello\nassistant:"
# Verify it was called with the parsed list (not the JSON string)
mock_tokenizer_with_chat_template.apply_chat_template.assert_called_once()
call_args = mock_tokenizer_with_chat_template.apply_chat_template.call_args
assert call_args[0][0] == messages # First positional arg should be the parsed messages list

@pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
@pytest.mark.run_only_on("GPU")
def test_apply_chat_template_without_generation_prompt(
self, exporter_with_chat_template, mock_tokenizer_with_chat_template
):
"""Test apply_chat_template with add_generation_prompt=False."""
messages = [{"role": "user", "content": "Hello"}]

exporter_with_chat_template.apply_chat_template(messages, add_generation_prompt=False)

mock_tokenizer_with_chat_template.apply_chat_template.assert_called_once_with(
messages,
tokenize=False,
add_generation_prompt=False,
)

@pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
@pytest.mark.run_only_on("GPU")
def test_apply_chat_template_raises_error_when_no_template(self, exporter_without_chat_template):
"""Test apply_chat_template raises ValueError when tokenizer has no chat template."""
messages = [{"role": "user", "content": "Hello"}]

with pytest.raises(ValueError, match="The tokenizer does not have a chat template defined"):
exporter_without_chat_template.apply_chat_template(messages)

@pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
@pytest.mark.run_only_on("GPU")
def test_apply_chat_template_with_multi_turn_conversation(
self, exporter_with_chat_template, mock_tokenizer_with_chat_template
):
"""Test apply_chat_template with multi-turn conversation."""
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
]

exporter_with_chat_template.apply_chat_template(messages)

mock_tokenizer_with_chat_template.apply_chat_template.assert_called_once_with(
messages,
tokenize=False,
add_generation_prompt=True,
)

@pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
@pytest.mark.run_only_on("GPU")
def test_ray_infer_fn_with_apply_chat_template(
self, exporter_with_chat_template, mock_tokenizer_with_chat_template
):
"""Test ray_infer_fn correctly applies chat template when requested."""
exporter_with_chat_template.forward = MagicMock(return_value={"sentences": ["I'm doing well, thank you!"]})

messages = [{"role": "user", "content": "Hello"}]
inputs = {
"prompts": [messages],
"max_tokens": 100,
"apply_chat_template": True,
}

result = exporter_with_chat_template.ray_infer_fn(inputs)

assert "sentences" in result
assert result["sentences"] == ["I'm doing well, thank you!"]
# Verify apply_chat_template was called
mock_tokenizer_with_chat_template.apply_chat_template.assert_called_once()
# Verify forward was called with the formatted prompt string
exporter_with_chat_template.forward.assert_called_once()
call_kwargs = exporter_with_chat_template.forward.call_args[1]
assert call_kwargs["input_texts"] == ["<|begin_of_text|>user: Hello\nassistant:"]

@pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
@pytest.mark.run_only_on("GPU")
def test_ray_infer_fn_without_apply_chat_template(
self, exporter_with_chat_template, mock_tokenizer_with_chat_template
):
"""Test ray_infer_fn does not apply chat template when not requested."""
exporter_with_chat_template.forward = MagicMock(return_value={"sentences": ["Generated text"]})

inputs = {
"prompts": ["plain text prompt"],
"max_tokens": 100,
"apply_chat_template": False,
}

result = exporter_with_chat_template.ray_infer_fn(inputs)

assert "sentences" in result
# Verify apply_chat_template was NOT called
mock_tokenizer_with_chat_template.apply_chat_template.assert_not_called()
# Verify forward was called with the original prompt
exporter_with_chat_template.forward.assert_called_once()
call_kwargs = exporter_with_chat_template.forward.call_args[1]
assert call_kwargs["input_texts"] == ["plain text prompt"]

@pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
@pytest.mark.run_only_on("GPU")
def test_ray_infer_fn_with_apply_chat_template_default_false(
self, exporter_with_chat_template, mock_tokenizer_with_chat_template
):
"""Test ray_infer_fn defaults to not applying chat template."""
exporter_with_chat_template.forward = MagicMock(return_value={"sentences": ["Generated text"]})

inputs = {
"prompts": ["plain text prompt"],
"max_tokens": 100,
# apply_chat_template not specified - should default to False
}

result = exporter_with_chat_template.ray_infer_fn(inputs)

assert "sentences" in result
# Verify apply_chat_template was NOT called (default is False)
mock_tokenizer_with_chat_template.apply_chat_template.assert_not_called()

@pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
@pytest.mark.run_only_on("GPU")
def test_ray_infer_fn_with_multiple_chat_prompts(
self, exporter_with_chat_template, mock_tokenizer_with_chat_template
):
"""Test ray_infer_fn with multiple chat prompts applies template to each."""
mock_tokenizer_with_chat_template.apply_chat_template.side_effect = [
"<|begin_of_text|>user: Hello\nassistant:",
"<|begin_of_text|>user: Goodbye\nassistant:",
]
exporter_with_chat_template.forward = MagicMock(return_value={"sentences": ["Hi there!", "See you later!"]})

messages1 = [{"role": "user", "content": "Hello"}]
messages2 = [{"role": "user", "content": "Goodbye"}]
inputs = {
"prompts": [messages1, messages2],
"max_tokens": 100,
"apply_chat_template": True,
}

result = exporter_with_chat_template.ray_infer_fn(inputs)

assert "sentences" in result
assert len(result["sentences"]) == 2
# Verify apply_chat_template was called twice (once for each prompt)
assert mock_tokenizer_with_chat_template.apply_chat_template.call_count == 2
# Verify forward was called with both formatted prompts
call_kwargs = exporter_with_chat_template.forward.call_args[1]
assert call_kwargs["input_texts"] == [
"<|begin_of_text|>user: Hello\nassistant:",
"<|begin_of_text|>user: Goodbye\nassistant:",
]
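Because apply_chat_template also accepts a JSON string, callers that serialize messages for transport can pass them through as-is. A short sketch under the same assumptions as above (exporter already set up):

    import json

    messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]

    # JSON strings are parsed back into a message list before formatting.
    prompt = exporter.apply_chat_template(json.dumps(messages))

    # add_generation_prompt=False omits the trailing assistant-turn marker,
    # e.g. when formatting a completed conversation rather than a live request.
    prompt_no_gen = exporter.apply_chat_template(messages, add_generation_prompt=False)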