From f1ac1a6d492fa3460a32c13a54e5c45fcc9ab0bb Mon Sep 17 00:00:00 2001
From: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com>
Date: Wed, 4 Feb 2026 15:14:43 -0800
Subject: [PATCH] Add apply_chat_template to HF vllm Ray deployment (#581)

Signed-off-by: Abhishree
Signed-off-by: NeMo Bot
---
 nemo_export/vllm_exporter.py                   |  46 ++++
 .../deploy/test_hf_ray_oai_format.py           |   4 +-
 .../unit_tests/export/test_tensorrt_llm_hf.py  |   6 +-
 .../export/test_vllm_exporter_ray_infer.py     | 221 ++++++++++++++++++
 4 files changed, 272 insertions(+), 5 deletions(-)

diff --git a/nemo_export/vllm_exporter.py b/nemo_export/vllm_exporter.py
index ccfd490d24..0921662388 100644
--- a/nemo_export/vllm_exporter.py
+++ b/nemo_export/vllm_exporter.py
@@ -548,6 +548,11 @@ def ray_infer_fn(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         compute_logprob = inputs.pop("compute_logprob", False)
         n_top_logprobs = inputs.pop("n_top_logprobs", 0)
         echo = inputs.pop("echo", False)
+        apply_chat_template = inputs.pop("apply_chat_template", False)
+
+        # Apply chat template if requested (for Ray inference only)
+        if apply_chat_template:
+            prompts = [self.apply_chat_template(prompt) for prompt in prompts]
 
         # Map HF-style parameters to vLLM parameters
         if compute_logprob and n_top_logprobs > 0:
@@ -570,6 +575,47 @@ def ray_infer_fn(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
 
         return output_dict
 
+    def apply_chat_template(self, messages, add_generation_prompt=True):
+        """Apply the chat template to messages using the tokenizer.
+
+        This method uses the vLLM tokenizer's built-in apply_chat_template method
+        to format messages according to the model's expected chat format.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content' keys,
+                or a JSON string representation of messages.
+            add_generation_prompt (bool): Whether to add the generation prompt. Defaults to True.
+
+        Returns:
+            str: The formatted prompt string.
+
+        Raises:
+            ValueError: If the tokenizer does not have a chat template.
+        """
+        import json
+
+        # Handle JSON string input
+        if isinstance(messages, str):
+            messages = json.loads(messages)
+
+        tokenizer = self.model.get_tokenizer()
+
+        # Check if tokenizer has chat_template
+        if not hasattr(tokenizer, "chat_template") or tokenizer.chat_template is None:
+            raise ValueError(
+                "The tokenizer does not have a chat template defined. "
+                "If you would like to evaluate a chat model, ensure your model's tokenizer has a chat template."
+            )
+
+        # Use tokenizer's apply_chat_template method
+        formatted_prompt = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=add_generation_prompt,
+        )
+
+        return formatted_prompt
+
     def _dict_to_str(self, messages):
         """Serializes dict to str."""
         import json
diff --git a/tests/unit_tests/deploy/test_hf_ray_oai_format.py b/tests/unit_tests/deploy/test_hf_ray_oai_format.py
index 3976acd826..bbd52ff37b 100644
--- a/tests/unit_tests/deploy/test_hf_ray_oai_format.py
+++ b/tests/unit_tests/deploy/test_hf_ray_oai_format.py
@@ -580,8 +580,8 @@ def mock_hf_deployable_for_logprobs(self):
             "input_ids": torch.tensor([[1, 2, 3, 4]]),
             "attention_mask": torch.tensor([[1, 1, 1, 1]]),
         }
-        mock_tokenizer.decode.side_effect = (
-            lambda ids: f"token_{ids[0] if isinstance(ids, list) and len(ids) > 0 else 'unknown'}"
+        mock_tokenizer.decode.side_effect = lambda ids: (
+            f"token_{ids[0] if isinstance(ids, list) and len(ids) > 0 else 'unknown'}"
         )
         mock_tokenizer.eos_token = ""
         mock_tokenizer.pad_token = ""
diff --git a/tests/unit_tests/export/test_tensorrt_llm_hf.py b/tests/unit_tests/export/test_tensorrt_llm_hf.py
index d78b820169..4d0408aa2c 100644
--- a/tests/unit_tests/export/test_tensorrt_llm_hf.py
+++ b/tests/unit_tests/export/test_tensorrt_llm_hf.py
@@ -584,9 +584,9 @@ def test_tensorrt_llm_hf_export_copies_tokenizer_files():
         patch("nemo_export.tensorrt_llm_hf.LLaMAForCausalLM.from_hugging_face", return_value=mock_model),
         patch(
             "glob.glob",
-            side_effect=lambda x: ["/tmp/hf_model/tokenizer.json"]
-            if "*.json" in x
-            else ["/tmp/hf_model/tokenizer.model"],
+            side_effect=lambda x: (
+                ["/tmp/hf_model/tokenizer.json"] if "*.json" in x else ["/tmp/hf_model/tokenizer.model"]
+            ),
         ),
         patch("shutil.copy"),
         patch.object(trt_llm_hf, "_load"),
diff --git a/tests/unit_tests/export/test_vllm_exporter_ray_infer.py b/tests/unit_tests/export/test_vllm_exporter_ray_infer.py
index 0f50c55453..7d8f477b7a 100644
--- a/tests/unit_tests/export/test_vllm_exporter_ray_infer.py
+++ b/tests/unit_tests/export/test_vllm_exporter_ray_infer.py
@@ -772,3 +772,224 @@ def test_post_process_logprobs_to_OAI_multiple_samples(exporter, mock_llm):
     assert len(result["log_probs"]) == 2
     assert result["log_probs"][0] == [-0.1, -0.2]
     assert result["log_probs"][1] == [-0.3, -0.4]
+
+
+class TestVLLMExporterApplyChatTemplate:
+    """Tests for apply_chat_template functionality added to vLLMExporter."""
+
+    @pytest.fixture
+    def mock_tokenizer_with_chat_template(self):
+        """Create a mock tokenizer with chat template support."""
+        tokenizer = MagicMock()
+        tokenizer.chat_template = "{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}"
+        tokenizer.apply_chat_template = MagicMock(return_value="<|begin_of_text|>user: Hello\nassistant:")
+        return tokenizer
+
+    @pytest.fixture
+    def mock_tokenizer_without_chat_template(self):
+        """Create a mock tokenizer without chat template support."""
+        tokenizer = MagicMock()
+        tokenizer.chat_template = None
+        return tokenizer
+
+    @pytest.fixture
+    def exporter_with_chat_template(self, mock_tokenizer_with_chat_template):
+        """Create vLLMExporter instance with chat template support."""
+        from nemo_export.vllm_exporter import vLLMExporter
+
+        exporter = vLLMExporter()
+        exporter.model = MagicMock()
+        exporter.model.get_tokenizer.return_value = mock_tokenizer_with_chat_template
+        return exporter
+
+    @pytest.fixture
+    def exporter_without_chat_template(self, mock_tokenizer_without_chat_template):
+        """Create vLLMExporter instance without chat template support."""
+        from nemo_export.vllm_exporter import vLLMExporter
+
+        exporter = vLLMExporter()
+        exporter.model = MagicMock()
+        exporter.model.get_tokenizer.return_value = mock_tokenizer_without_chat_template
+        return exporter
+
+    @pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
+    @pytest.mark.run_only_on("GPU")
+    def test_apply_chat_template_with_messages_list(
+        self, exporter_with_chat_template, mock_tokenizer_with_chat_template
+    ):
+        """Test apply_chat_template with a list of message dictionaries."""
+        messages = [{"role": "user", "content": "Hello, how are you?"}]
+
+        result = exporter_with_chat_template.apply_chat_template(messages)
+
+        assert result == "<|begin_of_text|>user: Hello\nassistant:"
+        mock_tokenizer_with_chat_template.apply_chat_template.assert_called_once_with(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+    @pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
+    @pytest.mark.run_only_on("GPU")
+    def test_apply_chat_template_with_json_string(self, exporter_with_chat_template, mock_tokenizer_with_chat_template):
+        """Test apply_chat_template with JSON string input."""
+        import json
+
+        messages = [{"role": "user", "content": "Hello"}]
+        messages_json = json.dumps(messages)
+
+        result = exporter_with_chat_template.apply_chat_template(messages_json)
+
+        assert result == "<|begin_of_text|>user: Hello\nassistant:"
+        # Verify it was called with the parsed list (not the JSON string)
+        mock_tokenizer_with_chat_template.apply_chat_template.assert_called_once()
+        call_args = mock_tokenizer_with_chat_template.apply_chat_template.call_args
+        assert call_args[0][0] == messages  # First positional arg should be the parsed messages list
+
+    @pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
+    @pytest.mark.run_only_on("GPU")
+    def test_apply_chat_template_without_generation_prompt(
+        self, exporter_with_chat_template, mock_tokenizer_with_chat_template
+    ):
+        """Test apply_chat_template with add_generation_prompt=False."""
+        messages = [{"role": "user", "content": "Hello"}]
+
+        exporter_with_chat_template.apply_chat_template(messages, add_generation_prompt=False)
+
+        mock_tokenizer_with_chat_template.apply_chat_template.assert_called_once_with(
+            messages,
+            tokenize=False,
+            add_generation_prompt=False,
+        )
+
+    @pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
+    @pytest.mark.run_only_on("GPU")
+    def test_apply_chat_template_raises_error_when_no_template(self, exporter_without_chat_template):
+        """Test apply_chat_template raises ValueError when tokenizer has no chat template."""
+        messages = [{"role": "user", "content": "Hello"}]
+
+        with pytest.raises(ValueError, match="The tokenizer does not have a chat template defined"):
+            exporter_without_chat_template.apply_chat_template(messages)
+
+    @pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
+    @pytest.mark.run_only_on("GPU")
+    def test_apply_chat_template_with_multi_turn_conversation(
+        self, exporter_with_chat_template, mock_tokenizer_with_chat_template
+    ):
+        """Test apply_chat_template with multi-turn conversation."""
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+            {"role": "user", "content": "How are you?"},
+        ]
+
+        exporter_with_chat_template.apply_chat_template(messages)
+
+        mock_tokenizer_with_chat_template.apply_chat_template.assert_called_once_with(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+    @pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
+    @pytest.mark.run_only_on("GPU")
+    def test_ray_infer_fn_with_apply_chat_template(
+        self, exporter_with_chat_template, mock_tokenizer_with_chat_template
+    ):
+        """Test ray_infer_fn correctly applies chat template when requested."""
+        exporter_with_chat_template.forward = MagicMock(return_value={"sentences": ["I'm doing well, thank you!"]})
+
+        messages = [{"role": "user", "content": "Hello"}]
+        inputs = {
+            "prompts": [messages],
+            "max_tokens": 100,
+            "apply_chat_template": True,
+        }
+
+        result = exporter_with_chat_template.ray_infer_fn(inputs)
+
+        assert "sentences" in result
+        assert result["sentences"] == ["I'm doing well, thank you!"]
+        # Verify apply_chat_template was called
+        mock_tokenizer_with_chat_template.apply_chat_template.assert_called_once()
+        # Verify forward was called with the formatted prompt string
+        exporter_with_chat_template.forward.assert_called_once()
+        call_kwargs = exporter_with_chat_template.forward.call_args[1]
+        assert call_kwargs["input_texts"] == ["<|begin_of_text|>user: Hello\nassistant:"]
+
+    @pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
+    @pytest.mark.run_only_on("GPU")
+    def test_ray_infer_fn_without_apply_chat_template(
+        self, exporter_with_chat_template, mock_tokenizer_with_chat_template
+    ):
+        """Test ray_infer_fn does not apply chat template when not requested."""
+        exporter_with_chat_template.forward = MagicMock(return_value={"sentences": ["Generated text"]})
+
+        inputs = {
+            "prompts": ["plain text prompt"],
+            "max_tokens": 100,
+            "apply_chat_template": False,
+        }
+
+        result = exporter_with_chat_template.ray_infer_fn(inputs)
+
+        assert "sentences" in result
+        # Verify apply_chat_template was NOT called
+        mock_tokenizer_with_chat_template.apply_chat_template.assert_not_called()
+        # Verify forward was called with the original prompt
+        exporter_with_chat_template.forward.assert_called_once()
+        call_kwargs = exporter_with_chat_template.forward.call_args[1]
+        assert call_kwargs["input_texts"] == ["plain text prompt"]
+
+    @pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
+    @pytest.mark.run_only_on("GPU")
+    def test_ray_infer_fn_with_apply_chat_template_default_false(
+        self, exporter_with_chat_template, mock_tokenizer_with_chat_template
+    ):
+        """Test ray_infer_fn defaults to not applying chat template."""
+        exporter_with_chat_template.forward = MagicMock(return_value={"sentences": ["Generated text"]})
+
+        inputs = {
+            "prompts": ["plain text prompt"],
+            "max_tokens": 100,
+            # apply_chat_template not specified - should default to False
+        }
+
+        result = exporter_with_chat_template.ray_infer_fn(inputs)
+
+        assert "sentences" in result
+        # Verify apply_chat_template was NOT called (default is False)
+        mock_tokenizer_with_chat_template.apply_chat_template.assert_not_called()
+
+    @pytest.mark.skipif(not HAVE_VLLM, reason="Need to enable virtual environment for vLLM")
+    @pytest.mark.run_only_on("GPU")
+    def test_ray_infer_fn_with_multiple_chat_prompts(
+        self, exporter_with_chat_template, mock_tokenizer_with_chat_template
+    ):
+        """Test ray_infer_fn with multiple chat prompts applies template to each."""
+        mock_tokenizer_with_chat_template.apply_chat_template.side_effect = [
+            "<|begin_of_text|>user: Hello\nassistant:",
+            "<|begin_of_text|>user: Goodbye\nassistant:",
+        ]
+        exporter_with_chat_template.forward = MagicMock(return_value={"sentences": ["Hi there!", "See you later!"]})
+
+        messages1 = [{"role": "user", "content": "Hello"}]
+        messages2 = [{"role": "user", "content": "Goodbye"}]
+        inputs = {
+            "prompts": [messages1, messages2],
+            "max_tokens": 100,
+            "apply_chat_template": True,
+        }
+
+        result = exporter_with_chat_template.ray_infer_fn(inputs)
+
+        assert "sentences" in result
+        assert len(result["sentences"]) == 2
+        # Verify apply_chat_template was called twice (once for each prompt)
+        assert mock_tokenizer_with_chat_template.apply_chat_template.call_count == 2
+        # Verify forward was called with both formatted prompts
+        call_kwargs = exporter_with_chat_template.forward.call_args[1]
+        assert call_kwargs["input_texts"] == [
+            "<|begin_of_text|>user: Hello\nassistant:",
+            "<|begin_of_text|>user: Goodbye\nassistant:",
+        ]
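
Usage note (illustrative, not part of the patch): a minimal sketch of how a caller might exercise the new apply_chat_template flag through ray_infer_fn, assuming a vLLMExporter whose model has already been exported/loaded; the export step and Ray serving wiring are omitted and depend on the deployment.

from nemo_export.vllm_exporter import vLLMExporter

# Assumes the exporter's model is already loaded; export arguments are deployment-specific.
exporter = vLLMExporter()
# exporter.export(...)  # omitted here

# Chat-style messages are passed per prompt; apply_chat_template=True makes
# ray_infer_fn format them with the model tokenizer's chat template before generation.
messages = [{"role": "user", "content": "Hello, how are you?"}]
outputs = exporter.ray_infer_fn(
    {
        "prompts": [messages],
        "max_tokens": 100,
        "apply_chat_template": True,
    }
)
print(outputs["sentences"][0])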