From 400a0113c2b9fc5afd84fdd9dac4db00bd013a5e Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Wed, 23 Jul 2025 11:07:58 +0200 Subject: [PATCH 01/12] add gemini and openai remote inference engine --- src/data/message_formats.py | 149 +++++++++++++++++++++++------- src/data/query_item.py | 4 +- src/eval/eval_models.py | 2 +- src/models/gemini_inference.py | 44 +++++++++ src/models/openai_inference.py | 41 ++++++++ src/models/remote_inference.py | 13 +++ src/reasoning/reasoning_engine.py | 2 +- src/train/train_qwen.py | 11 +-- tests/test_intern_vl_inference.py | 2 +- tests/test_message_format.py | 79 +++++++++------- tests/test_models_eval.py | 79 ++++++++++------ tests/test_remote_inference.py | 36 ++++++++ 12 files changed, 354 insertions(+), 108 deletions(-) create mode 100644 src/models/gemini_inference.py create mode 100644 src/models/openai_inference.py create mode 100644 src/models/remote_inference.py create mode 100644 tests/test_remote_inference.py diff --git a/src/data/message_formats.py b/src/data/message_formats.py index f4e7231..3fea42f 100644 --- a/src/data/message_formats.py +++ b/src/data/message_formats.py @@ -12,7 +12,7 @@ def format( answer: Optional[str] = None, key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: pass @@ -25,7 +25,7 @@ def format( answer: Optional[str] = None, key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: content = [] if system_prompt: content.append({"type": "text", "text": system_prompt}) @@ -52,10 +52,12 @@ def format( ) content.append({"type": "text", "text": f"Context Answer: {context_a}"}) - return { - "role": "user", - "content": content, - } + return [ + { + "role": "user", + "content": content, + } + ] class QwenTrainingMessageFormat(MessageFormat): @@ -71,7 +73,7 @@ def format( answer: Optional[str] = None, key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: user_content = [] if system_prompt: user_content.append({"type": "text", "text": system_prompt}) @@ -99,20 +101,18 @@ def format( {"type": "text", "text": f"Context Answer: {context_a}"} ) - return { - "messages": [ - { - "role": "user", - "content": user_content, - }, - { - "role": "assistant", - "content": [ - {"type": "text", "text": answer}, - ], - }, - ] - } + return [ + { + "role": "user", + "content": user_content, + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": answer}, + ], + }, + ] class InternVLMessageFormat(MessageFormat): @@ -124,7 +124,7 @@ def format( answer: Optional[str] = None, key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: full_prompt = "" if system_prompt: full_prompt += system_prompt + "\n\n" @@ -139,10 +139,12 @@ def format( f"\n\nContext Question: {context_q}\nContext Answer: {context_a}" ) - return { - "text": full_prompt, - "image_path": image_path, - } + return [ + { + "text": full_prompt, + "image_path": image_path, + } + ] class GemmaMessageFormat(MessageFormat): @@ -154,7 +156,7 @@ def format( answer: Optional[str] = None, key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: content = [ { "type": "image", @@ -179,7 +181,92 @@ def format( ) content.append({"type": "text", "text": 
f"Context Answer: {context_a}"}) - return { - "role": "user", - "content": content, - } + return [ + { + "role": "user", + "content": content, + } + ] + + +class OpenAIMessageFormat(MessageFormat): + def format( + self, + question: str, + image_path: str, + system_prompt: str = None, + answer: Optional[str] = None, + key_object_info: Optional[dict] = None, + context: Optional[List[Tuple[str, str]]] = None, + ) -> List[Dict[str, Any]]: + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + user_content = [] + if image_path: + user_content.append( + { + "type": "image_url", + "image_url": {"url": f"file://{image_path}"}, + } + ) + + user_content.append({"type": "text", "text": "Question: " + question}) + + if key_object_info: + user_content.append( + { + "type": "text", + "text": "Key object infos:\n" + str(key_object_info), + } + ) + + if context: + for context_q, context_a in context: + user_content.append( + {"type": "text", "text": f"Context Question: {context_q}"} + ) + user_content.append( + {"type": "text", "text": f"Context Answer: {context_a}"} + ) + + messages.append({"role": "user", "content": user_content}) + + return messages + + +class GeminiMessageFormat(MessageFormat): + def format( + self, + question: str, + image_path: str, + system_prompt: str = None, + answer: Optional[str] = None, + key_object_info: Optional[dict] = None, + context: Optional[List[Tuple[str, str]]] = None, + ) -> List[Dict[str, Any]]: + import PIL.Image + + parts = [] + + if system_prompt: + parts.append(system_prompt) + + if image_path: + image = PIL.Image.open(image_path) + parts.append(image) + + prompt_parts = ["Question: " + question] + + if key_object_info: + prompt_parts.append("Key object infos:\n" + str(key_object_info)) + + if context: + for context_q, context_a in context: + prompt_parts.append(f"Context Question: {context_q}") + prompt_parts.append(f"Context Answer: {context_a}") + + parts.append("\n".join(prompt_parts)) + + return parts diff --git a/src/data/query_item.py b/src/data/query_item.py index a138034..bd1d1cc 100644 --- a/src/data/query_item.py +++ b/src/data/query_item.py @@ -15,11 +15,11 @@ class QueryItem: system_prompt: str = None ground_truth_answer: Optional[str] = None - formatted_message: Optional[Dict[str, Any]] = None + formatted_message: Optional[List[Dict[str, Any]]] = None context_pairs: List[Tuple[str, str]] = field(default_factory=list) - def format_message(self, formatter: MessageFormat) -> Dict[str, Any]: + def format_message(self, formatter: MessageFormat) -> List[Dict[str, Any]]: self.formatted_message = formatter.format( question=self.question, image_path=self.image_path, diff --git a/src/eval/eval_models.py b/src/eval/eval_models.py index 9c203f6..279e4ff 100644 --- a/src/eval/eval_models.py +++ b/src/eval/eval_models.py @@ -54,7 +54,7 @@ def evaluate_model( if use_reasoning: batch = reasoning_engine.process_batch(batch) - formatted_messages = [[item.formatted_message] for item in batch] + formatted_messages = [item.formatted_message for item in batch] batch_results = engine.predict_batch(formatted_messages) diff --git a/src/models/gemini_inference.py b/src/models/gemini_inference.py new file mode 100644 index 0000000..7f60f78 --- /dev/null +++ b/src/models/gemini_inference.py @@ -0,0 +1,44 @@ +from typing import Dict, List, Optional + +import dotenv +from google import genai +from google.genai import types + +from src.data.message_formats import GeminiMessageFormat +from src.models.remote_inference import 
RemoteInferenceEngine +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +class GeminiInferenceEngine(RemoteInferenceEngine): + def __init__(self, model: Optional[str] = "gemini-2.0-flash"): + super().__init__(model_path=model) + dotenv.load_dotenv() + self.client = genai.Client() + self.model = model + self.message_formatter = GeminiMessageFormat() + + def predict_batch(self, messages: List[List[Dict]]): + responses = [] + total_tokens = 0 + + for msg in messages: + response = self.client.models.generate_content( + model=self.model, + contents=msg, + config=types.GenerateContentConfig( + temperature=0.6, + max_output_tokens=128, + ), + ) + + responses.append(response.text.strip()) + total_tokens += response.usage_metadata.total_token_count + + logger.debug( + f"Generated {len(responses)} responses for batch of size {len(messages)}" + ) + logger.debug(f"Total tokens used: {total_tokens}") + + return responses diff --git a/src/models/openai_inference.py b/src/models/openai_inference.py new file mode 100644 index 0000000..8bc984d --- /dev/null +++ b/src/models/openai_inference.py @@ -0,0 +1,41 @@ +import os +from typing import Dict, List, Optional + +import dotenv +from openai import OpenAI + +from src.data.message_formats import OpenAIMessageFormat +from src.models.remote_inference import RemoteInferenceEngine +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +class OpenAIInferenceEngine(RemoteInferenceEngine): + def __init__(self, model: Optional[str] = "gpt-4.1"): + super().__init__(model_path=model) + dotenv.load_dotenv() + api_key = os.getenv("OPENAI_API_KEY") + self.client = OpenAI(api_key=api_key) + self.model = model + self.message_formatter = OpenAIMessageFormat() + + def predict_batch(self, messages: List[List[Dict]]): + responses = [] + total_tokens = 0 + for msg in messages: + response = self.client.chat.completions.create( + model=self.model, + messages=msg, + temperature=0.6, + max_tokens=128, + ) + responses.append(response.output_text) + total_tokens += response.usage.total_tokens + + logger.debug( + f"Generated {len(responses)} responses for batch of size {len(messages)}" + ) + logger.debug(f"Total tokens used: {total_tokens}") + + return responses diff --git a/src/models/remote_inference.py b/src/models/remote_inference.py new file mode 100644 index 0000000..ec2073b --- /dev/null +++ b/src/models/remote_inference.py @@ -0,0 +1,13 @@ +from abc import abstractmethod +from typing import Dict, List + +from src.models.base_inference import BaseInferenceEngine + + +class RemoteInferenceEngine(BaseInferenceEngine): + def load_model(self) -> None: + pass + + @abstractmethod + def predict_batch(self, messages: List[List[Dict]]) -> list[str]: + pass diff --git a/src/reasoning/reasoning_engine.py b/src/reasoning/reasoning_engine.py index 816548a..f848a70 100644 --- a/src/reasoning/reasoning_engine.py +++ b/src/reasoning/reasoning_engine.py @@ -47,7 +47,7 @@ def process_batch(self, batch_items: List[QueryItem]) -> List[QueryItem]: if descriptor_items: descriptor_messages = [ - [desc_item.formatted_message] for _, desc_item in descriptor_items + desc_item.formatted_message for _, desc_item in descriptor_items ] descriptor_answers = self.engine.predict_batch(descriptor_messages) diff --git a/src/train/train_qwen.py b/src/train/train_qwen.py index f97c6be..97282b9 100644 --- a/src/train/train_qwen.py +++ b/src/train/train_qwen.py @@ -265,12 +265,9 @@ def train( def collator(batch: Any): texts = [ - 
engine.processor.apply_chat_template(data["messages"], tokenize=False) - for data in batch + engine.processor.apply_chat_template(data, tokenize=False) for data in batch ] - image_inputs, video_inputs = process_vision_info( - [data["messages"][0] for data in batch] - ) + image_inputs, video_inputs = process_vision_info([data[0] for data in batch]) processed_batch = engine.processor( text=texts, @@ -287,12 +284,12 @@ def collator(batch: Any): assistant_idx = next( j for data in batch - for j, m in enumerate(data["messages"]) + for j, m in enumerate(data) if m["role"] == "assistant" ) pre_text = engine.processor.apply_chat_template( - data["messages"][:assistant_idx], tokenize=False + data[:assistant_idx], tokenize=False ) pre_tokens = engine.processor.tokenizer(pre_text, return_tensors="pt")[ "input_ids" diff --git a/tests/test_intern_vl_inference.py b/tests/test_intern_vl_inference.py index b39a9b9..7370264 100644 --- a/tests/test_intern_vl_inference.py +++ b/tests/test_intern_vl_inference.py @@ -57,7 +57,7 @@ def test_internvl_model_predict(self): results = [] for batch in dataloader: query_items = batch - formatted_messages = [[item.formatted_message] for item in query_items] + formatted_messages = [item.formatted_message for item in query_items] predictions = engine.predict_batch(formatted_messages) results.extend(predictions) diff --git a/tests/test_message_format.py b/tests/test_message_format.py index f9fa58c..54de524 100644 --- a/tests/test_message_format.py +++ b/tests/test_message_format.py @@ -24,20 +24,25 @@ def test_format_of_qwen_message(self): context=context, ) - expected_message = { - "role": "user", - "content": [ - {"type": "text", "text": system_prompt}, - {"type": "text", "text": "Question: " + question}, - {"type": "image", "image": "file:///path/to/your/image.jpg"}, - { - "type": "text", - "text": "Key object infos:\n{'object': 'car', 'color': 'red'}", - }, - {"type": "text", "text": "Context Question: What is this?"}, - {"type": "text", "text": "Context Answer: This is a car."}, - ], - } + expected_message = [ + { + "role": "user", + "content": [ + {"type": "text", "text": system_prompt}, + {"type": "text", "text": "Question: " + question}, + { + "type": "image", + "image": "file:///path/to/your/image.jpg", + }, + { + "type": "text", + "text": "Key object infos:\n{'object': 'car', 'color': 'red'}", + }, + {"type": "text", "text": "Context Question: What is this?"}, + {"type": "text", "text": "Context Answer: This is a car."}, + ], + } + ] self.assertEqual( formatted_message, expected_message, @@ -60,10 +65,12 @@ def test_format_of_internvl_message(self): context=context, ) - expected_message = { - "text": "This is the system prompt\n\nQuestion: What is the color of the car?\n\nKey object infos:\n{'object': 'car', 'color': 'blue'}\n\nContext Question: What is this?\nContext Answer: This is a car.", - "image_path": image_path, - } + expected_message = [ + { + "text": "This is the system prompt\n\nQuestion: What is the color of the car?\n\nKey object infos:\n{'object': 'car', 'color': 'blue'}\n\nContext Question: What is this?\nContext Answer: This is a car.", + "image_path": image_path, + } + ] self.assertEqual( formatted_message, expected_message, @@ -86,23 +93,25 @@ def test_format_of_gemma_message(self): context=context, ) - expected_message = { - "role": "user", - "content": [ - {"type": "image", "image": image_path}, - {"type": "text", "text": system_prompt}, - {"type": "text", "text": "Question: " + question}, - { - "type": "text", - "text": "Key object 
infos:\n{'country': 'France', 'capital': 'Paris'}", - }, - {"type": "text", "text": "Context Question: What is this?"}, - { - "type": "text", - "text": "Context Answer: This is a map of France.", - }, - ], - } + expected_message = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": system_prompt}, + {"type": "text", "text": "Question: " + question}, + { + "type": "text", + "text": "Key object infos:\n{'country': 'France', 'capital': 'Paris'}", + }, + {"type": "text", "text": "Context Question: What is this?"}, + { + "type": "text", + "text": "Context Answer: This is a map of France.", + }, + ], + } + ] self.assertEqual( formatted_message, diff --git a/tests/test_models_eval.py b/tests/test_models_eval.py index 05e0844..beb7873 100644 --- a/tests/test_models_eval.py +++ b/tests/test_models_eval.py @@ -7,11 +7,28 @@ from src.constants import data_dir from src.eval.eval_models import evaluate_model +from src.models.gemini_inference import GeminiInferenceEngine from src.models.intern_vl_inference import InternVLInferenceEngine from src.models.qwen_vl_inference import QwenVLInferenceEngine from src.utils.utils import is_cuda, sanitize_model_name +def check_eval_files(output_dir, output_file_name, submission_file_name): + output_file = os.path.join(output_dir, output_file_name) + submission_file = os.path.join(output_dir, submission_file_name) + + assert os.path.exists(output_file), "Output file should be created." + assert os.path.exists(submission_file), "Submission file should be created." + + with open(output_file, "r") as f: + results = json.load(f) + assert len(results) > 0, "Results should not be empty." + for result in results: + assert "id" in result + assert "question" in result + assert "answer" in result + + @pytest.mark.eval class TestModelEvaluation(unittest.TestCase): def test_intern_vl_eval(self): @@ -32,28 +49,18 @@ def test_intern_vl_eval(self): test_set_size=1, use_reasoning=True, use_grid=True, - approach_name="test_models_eval", + approach_name="test_intern_vl_eval", ) model_dir = sanitize_model_name(engine.model_path) output_dir = os.path.join(data_dir, "output", model_dir) - output_file = os.path.join(output_dir, "test_models_eval_output.json") - submission_file = os.path.join(output_dir, "test_models_eval_submission.json") - - self.assertTrue(os.path.exists(output_file), "Output file should be created.") - self.assertTrue( - os.path.exists(submission_file), - "Submission file should be created.", + check_eval_files( + output_dir, + "test_intern_vl_eval_output.json", + "test_intern_vl_eval_submission.json", ) - with open(output_file, "r") as f: - results = json.load(f) - self.assertGreater(len(results), 0, "Results should not be empty.") - for result in results: - self.assertIn("id", result) - self.assertIn("question", result) - self.assertIn("answer", result) - + @unittest.skip("Skipping Qwen evaluation test") def test_qwen_eval(self): if is_cuda(): engine = QwenVLInferenceEngine( @@ -66,24 +73,36 @@ def test_qwen_eval(self): engine = QwenVLInferenceEngine("Qwen/Qwen2.5-VL-3B-Instruct") evaluate_model( - engine=engine, dataset_split="val", batch_size=1, test_set_size=1 + engine=engine, + dataset_split="val", + batch_size=1, + test_set_size=1, + approach_name="test_qwen_eval", ) model_dir = sanitize_model_name(engine.model_path) output_dir = os.path.join(data_dir, "output", model_dir) - output_file = os.path.join(output_dir, "test_test_test_output.json") - submission_file = os.path.join(output_dir, 
"test_test_test_submission.json") + check_eval_files( + output_dir, + "test_qwen_eval_output.json", + "test_qwen_eval_submission.json", + ) - self.assertTrue(os.path.exists(output_file), "Output file should be created.") - self.assertTrue( - os.path.exists(submission_file), - "Submission file should be created.", + def test_gemini_eval(self): + engine = GeminiInferenceEngine(model="gemini-2.0-flash") + evaluate_model( + engine=engine, + dataset_split="val", + batch_size=1, + test_set_size=1, + use_grid=True, + approach_name="test_gemini_eval", ) - with open(output_file, "r") as f: - results = json.load(f) - self.assertGreater(len(results), 0, "Results should not be empty.") - for result in results: - self.assertIn("id", result) - self.assertIn("question", result) - self.assertIn("answer", result) + model_dir = engine.model + output_dir = os.path.join(data_dir, "output", model_dir) + check_eval_files( + output_dir, + "test_gemini_eval_output.json", + "test_gemini_eval_submission.json", + ) diff --git a/tests/test_remote_inference.py b/tests/test_remote_inference.py new file mode 100644 index 0000000..6c13f99 --- /dev/null +++ b/tests/test_remote_inference.py @@ -0,0 +1,36 @@ +import unittest + +import src.data.message_formats as message_formats +from src.models.gemini_inference import GeminiInferenceEngine +from src.models.openai_inference import OpenAIInferenceEngine + + +class TestRemoteInferenceEngine(unittest.TestCase): + @unittest.skip("Skipping OpenAI inference test") + def test_openai_predict_batch(self): + engine = OpenAIInferenceEngine() + # Prepare a batch of messages in OpenAI format + messages = [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ] + ] + results = engine.predict_batch(messages) + print(results) + + def test_gemini_predict_batch(self): + engine = GeminiInferenceEngine(model="gemini-2.0-flash") + # Prepare a batch of messages in Gemini format + message = message_formats.GeminiMessageFormat().format( + question="What is the capital of France?", + image_path=None, + system_prompt="You are a helpful assistant.", + ) + + results = engine.predict_batch([message]) + print(results) + + +if __name__ == "__main__": + unittest.main() From 64c32bd2e2bcfa6151f735514a2b3573a59e231b Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Wed, 23 Jul 2025 11:44:12 +0200 Subject: [PATCH 02/12] add basic anthropic --- src/data/message_formats.py | 49 +++++++++++++++++++++++++++++ src/models/anthropic_inference.py | 52 +++++++++++++++++++++++++++++++ tests/test_models_eval.py | 22 +++++++++++++ tests/test_remote_inference.py | 12 +++++++ 4 files changed, 135 insertions(+) create mode 100644 src/models/anthropic_inference.py diff --git a/src/data/message_formats.py b/src/data/message_formats.py index 3fea42f..2fb2c04 100644 --- a/src/data/message_formats.py +++ b/src/data/message_formats.py @@ -270,3 +270,52 @@ def format( parts.append("\n".join(prompt_parts)) return parts + + +class AnthropicMessageFormat(MessageFormat): + def format( + self, + question: str, + image_path: str, + system_prompt: str = None, + answer: Optional[str] = None, + key_object_info: Optional[dict] = None, + context: Optional[List[Tuple[str, str]]] = None, + ) -> List[Dict[str, Any]]: + import base64 + from pathlib import Path + + user_content = [] + if image_path: + image_bytes = Path(image_path).read_bytes() + image_b64 = base64.b64encode(image_bytes).decode("utf-8") + user_content.append( + { + "type": "image", + "source": { 
+ "type": "base64", + "media_type": "image/jpeg", + "data": image_b64, + }, + } + ) + + messages = [] + + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + text_parts = ["Question: " + question] + + if key_object_info: + text_parts.append("Key object infos:\n" + str(key_object_info)) + if context: + for context_q, context_a in context: + text_parts.append(f"Context Question: {context_q}") + text_parts.append(f"Context Answer: {context_a}") + + user_content.append({"type": "text", "text": "\n".join(text_parts)}) + + messages.append({"role": "user", "content": user_content}) + + return messages diff --git a/src/models/anthropic_inference.py b/src/models/anthropic_inference.py new file mode 100644 index 0000000..3508d2f --- /dev/null +++ b/src/models/anthropic_inference.py @@ -0,0 +1,52 @@ +import os +from typing import Dict, List, Optional + +import anthropic +import dotenv + +from src.data.message_formats import AnthropicMessageFormat +from src.models.remote_inference import RemoteInferenceEngine +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +class AnthropicInferenceEngine(RemoteInferenceEngine): + def __init__(self, model: Optional[str] = "claude-3-5-haiku-20241022"): + super().__init__(model_path=model) + dotenv.load_dotenv() + api_key = os.getenv("ANTHROPIC_API_KEY") + self.client = anthropic.Anthropic(api_key=api_key) + self.model = model + self.message_formatter = AnthropicMessageFormat() + + def predict_batch(self, messages: List[List[Dict]]): + responses = [] + total_tokens = 0 + for msg in messages: + system_prompt = None + filtered_msg = [] + for m in msg: + if m["role"] == "system": + if system_prompt is None: + system_prompt = m["content"] + else: + system_prompt += "\n" + m["content"] + else: + filtered_msg.append(m) + + response = self.client.messages.create( + model=self.model, + messages=filtered_msg, + temperature=0.6, + max_tokens=128, + ) + responses.append(response.content[0].text.strip()) + total_tokens += response.usage.input_tokens + response.usage.output_tokens + + logger.debug( + f"Generated {len(responses)} responses for batch of size {len(messages)}" + ) + logger.debug(f"Total tokens used: {total_tokens}") + + return responses diff --git a/tests/test_models_eval.py b/tests/test_models_eval.py index beb7873..c308f17 100644 --- a/tests/test_models_eval.py +++ b/tests/test_models_eval.py @@ -7,6 +7,7 @@ from src.constants import data_dir from src.eval.eval_models import evaluate_model +from src.models.anthropic_inference import AnthropicInferenceEngine from src.models.gemini_inference import GeminiInferenceEngine from src.models.intern_vl_inference import InternVLInferenceEngine from src.models.qwen_vl_inference import QwenVLInferenceEngine @@ -96,6 +97,7 @@ def test_gemini_eval(self): batch_size=1, test_set_size=1, use_grid=True, + use_system_prompt=True, approach_name="test_gemini_eval", ) @@ -106,3 +108,23 @@ def test_gemini_eval(self): "test_gemini_eval_output.json", "test_gemini_eval_submission.json", ) + + def test_anthropic_eval(self): + engine = AnthropicInferenceEngine(model="claude-3-5-haiku-20241022") + evaluate_model( + engine=engine, + dataset_split="val", + batch_size=1, + test_set_size=1, + use_grid=True, + use_system_prompt=True, + approach_name="test_anthropic_eval", + ) + + model_dir = sanitize_model_name(engine.model_path) + output_dir = os.path.join(data_dir, "output", model_dir) + check_eval_files( + output_dir, + "test_anthropic_eval_output.json", + 
"test_anthropic_eval_submission.json", + ) diff --git a/tests/test_remote_inference.py b/tests/test_remote_inference.py index 6c13f99..0df0f4f 100644 --- a/tests/test_remote_inference.py +++ b/tests/test_remote_inference.py @@ -1,6 +1,7 @@ import unittest import src.data.message_formats as message_formats +from src.models.anthropic_inference import AnthropicInferenceEngine from src.models.gemini_inference import GeminiInferenceEngine from src.models.openai_inference import OpenAIInferenceEngine @@ -31,6 +32,17 @@ def test_gemini_predict_batch(self): results = engine.predict_batch([message]) print(results) + def test_anthropic_predict_batch(self): + engine = AnthropicInferenceEngine(model="claude-3-5-haiku-20241022") + messages = [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ] + ] + results = engine.predict_batch(messages) + print(results) + if __name__ == "__main__": unittest.main() From b55302d1aa2db5d5ece1eda3d04bf6174f5a1abf Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Wed, 23 Jul 2025 11:47:52 +0200 Subject: [PATCH 03/12] ignore remote inference tests --- tests/test_remote_inference.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_remote_inference.py b/tests/test_remote_inference.py index 0df0f4f..1a21c12 100644 --- a/tests/test_remote_inference.py +++ b/tests/test_remote_inference.py @@ -1,11 +1,14 @@ import unittest +import pytest + import src.data.message_formats as message_formats from src.models.anthropic_inference import AnthropicInferenceEngine from src.models.gemini_inference import GeminiInferenceEngine from src.models.openai_inference import OpenAIInferenceEngine +@pytest.mark.inference class TestRemoteInferenceEngine(unittest.TestCase): @unittest.skip("Skipping OpenAI inference test") def test_openai_predict_batch(self): From a143fef0d942d3cb9985f31931e0b048846e25ea Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Sun, 10 Aug 2025 21:10:43 +0200 Subject: [PATCH 04/12] add llm sdks to requirements --- requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements.txt b/requirements.txt index 8e0bfb3..2e5d04d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,6 @@ gdown~=5.2.0 pre-commit~=4.2.0 peft~=0.15.2 trl~=0.18.1 +anthropic +openai +google-genai From bec1578b499910175d297a1706dd3cf4dfb8aa4c Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Sun, 10 Aug 2025 21:35:15 +0200 Subject: [PATCH 05/12] add dotenv --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 0c1526b..6e7cb09 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,5 +16,6 @@ trl~=0.18.1 anthropic~=0.58.2 openai~=1.97.1 google-genai~=1.24.0 +dotenv~=1.1.1 polars==1.31.0 ultralytics==8.3.168 From dcd5fa954bc25e2348601aa4429e48558acf1723 Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Sun, 10 Aug 2025 21:36:53 +0200 Subject: [PATCH 06/12] cleanup remote tests --- tests/test_remote_inference.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_remote_inference.py b/tests/test_remote_inference.py index 1a21c12..ddb015d 100644 --- a/tests/test_remote_inference.py +++ b/tests/test_remote_inference.py @@ -13,7 +13,6 @@ class TestRemoteInferenceEngine(unittest.TestCase): @unittest.skip("Skipping OpenAI inference test") def test_openai_predict_batch(self): engine = OpenAIInferenceEngine() - # Prepare a batch of messages in OpenAI format messages = [ [ 
{"role": "system", "content": "You are a helpful assistant."}, @@ -21,11 +20,11 @@ def test_openai_predict_batch(self): ] ] results = engine.predict_batch(messages) - print(results) + + self.assertTrue(len(results) > 0, "Results should not be empty") def test_gemini_predict_batch(self): engine = GeminiInferenceEngine(model="gemini-2.0-flash") - # Prepare a batch of messages in Gemini format message = message_formats.GeminiMessageFormat().format( question="What is the capital of France?", image_path=None, @@ -33,7 +32,8 @@ def test_gemini_predict_batch(self): ) results = engine.predict_batch([message]) - print(results) + + self.assertTrue(len(results) > 0, "Results should not be empty") def test_anthropic_predict_batch(self): engine = AnthropicInferenceEngine(model="claude-3-5-haiku-20241022") @@ -44,7 +44,8 @@ def test_anthropic_predict_batch(self): ] ] results = engine.predict_batch(messages) - print(results) + + self.assertTrue(len(results) > 0, "Results should not be empty") if __name__ == "__main__": From 4464968d5a64fc1560f1b0301ae014985549e294 Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Sun, 10 Aug 2025 21:37:32 +0200 Subject: [PATCH 07/12] add provider argument to main --- main.py | 42 +++++++++++++++++++++++++------ src/models/intern_vl_inference.py | 2 +- src/models/qwen_vl_inference.py | 2 +- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 5cef7cb..414f7e7 100644 --- a/main.py +++ b/main.py @@ -3,16 +3,34 @@ from src.constants import model_dir from src.eval.eval_models import evaluate_model +from src.models.anthropic_inference import AnthropicInferenceEngine +from src.models.gemini_inference import GeminiInferenceEngine +from src.models.gemma_inference import GemmaInferenceEngine +from src.models.intern_vl_inference import InternVLInferenceEngine +from src.models.openai_inference import OpenAIInferenceEngine from src.models.qwen_vl_inference import QwenVLInferenceEngine from src.train.train_qwen import train from src.utils.approach import get_approach_kwargs, get_approach_name from src.utils.logger import get_logger -from src.utils.utils import get_resize_image_size, is_cuda +from src.utils.utils import get_resize_image_size logger = get_logger(__name__) if __name__ == "__main__": parser = ArgumentParser() + parser.add_argument( + "--provider", + help="The inference provider/model to use.", + choices=[ + "openai", + "anthropic", + "gemini", + "gemma", + "intern_vl", + "qwen", + ], + default="qwen", + ) parser.add_argument( "--train", help="Set to finetune the current model", @@ -51,7 +69,7 @@ parser.add_argument( "--resize_factor", help="Resize factor to apply to the images. Original size is (1600 x 900). 
Currently only applied if using image_grid approach.", - default="0.5", + default="1", ) parser.add_argument( "--batch_size", @@ -94,14 +112,24 @@ grid="image_grid" in args.approach, ) logger.debug(f"Using resize image size: {resize_image_size}") - if is_cuda(): + + engine = None + if args.provider == "openai": + engine = OpenAIInferenceEngine(model=args.model_path) + elif args.provider == "anthropic": + engine = AnthropicInferenceEngine(model=args.model_path) + elif args.provider == "gemini": + engine = GeminiInferenceEngine(model=args.model_path) + elif args.provider == "gemma": + engine = GemmaInferenceEngine(model_path=model_path) + elif args.provider == "intern_vl": + engine = InternVLInferenceEngine(model_path=model_path) + elif args.provider == "qwen": engine = QwenVLInferenceEngine( - model_path=model_path, - use_4bit=True, - resize_image_size=resize_image_size, + model_path=model_path, resize_image_size=resize_image_size ) else: - engine = QwenVLInferenceEngine(resize_image_size=resize_image_size) + raise ValueError(f"Unknown provider: {args.provider}") evaluate_model( engine=engine, diff --git a/src/models/intern_vl_inference.py b/src/models/intern_vl_inference.py index cd15e14..75ca8a3 100644 --- a/src/models/intern_vl_inference.py +++ b/src/models/intern_vl_inference.py @@ -16,7 +16,7 @@ class InternVLInferenceEngine(BaseInferenceEngine): def __init__( self, model_path: str = "OpenGVLab/InternVL3-2B", - use_4bit: bool = False, + use_4bit: bool = True, torch_dtype: Optional[torch.dtype] = None, revision: Optional[str] = None, device: Optional[str] = None, diff --git a/src/models/qwen_vl_inference.py b/src/models/qwen_vl_inference.py index 342c497..8aa3553 100644 --- a/src/models/qwen_vl_inference.py +++ b/src/models/qwen_vl_inference.py @@ -19,7 +19,7 @@ def __init__( self, processor_path: str = "Qwen/Qwen2.5-VL-7B-Instruct", model_path: Optional[str] = None, - use_4bit: bool = False, + use_4bit: bool = True, torch_dtype: Optional[torch.dtype] = None, revision: Optional[str] = None, device: Optional[str] = None, From 533683c67a615cf8b04664db8240263685340a1c Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Sun, 10 Aug 2025 21:39:26 +0200 Subject: [PATCH 08/12] fix requirements --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6e7cb09..fa35a85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,6 @@ trl~=0.18.1 anthropic~=0.58.2 openai~=1.97.1 google-genai~=1.24.0 -dotenv~=1.1.1 +python-dotenv~=1.1.1 polars==1.31.0 ultralytics==8.3.168 From 2e816ac2b7b7d71f8d6d0be4984e7faf10ed165a Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Sun, 10 Aug 2025 22:42:34 +0200 Subject: [PATCH 09/12] add rate limiting and system prompt to anthropic --- src/models/anthropic_inference.py | 49 +++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/models/anthropic_inference.py b/src/models/anthropic_inference.py index 3508d2f..72fa889 100644 --- a/src/models/anthropic_inference.py +++ b/src/models/anthropic_inference.py @@ -1,4 +1,6 @@ import os +import time +from collections import deque from typing import Dict, List, Optional import anthropic @@ -12,13 +14,20 @@ class AnthropicInferenceEngine(RemoteInferenceEngine): - def __init__(self, model: Optional[str] = "claude-3-5-haiku-20241022"): + def __init__( + self, + model: Optional[str] = "claude-3-5-haiku-20241022", + max_requests_per_minute: int = 5, + ): super().__init__(model_path=model) 
dotenv.load_dotenv() api_key = os.getenv("ANTHROPIC_API_KEY") self.client = anthropic.Anthropic(api_key=api_key) self.model = model self.message_formatter = AnthropicMessageFormat() + self._request_records = deque() + self._max_requests = max_requests_per_minute + self._window_seconds = 60 def predict_batch(self, messages: List[List[Dict]]): responses = [] @@ -35,14 +44,18 @@ def predict_batch(self, messages: List[List[Dict]]): else: filtered_msg.append(m) - response = self.client.messages.create( - model=self.model, - messages=filtered_msg, - temperature=0.6, - max_tokens=128, - ) - responses.append(response.content[0].text.strip()) - total_tokens += response.usage.input_tokens + response.usage.output_tokens + self._rate_limit() + + response = self.client.messages.create( + model=self.model, + messages=filtered_msg, + system=system_prompt, + temperature=0.6, + max_tokens=128, + ) + + responses.append(response.content[0].text.strip()) + total_tokens += response.usage.input_tokens + response.usage.output_tokens logger.debug( f"Generated {len(responses)} responses for batch of size {len(messages)}" @@ -50,3 +63,21 @@ def predict_batch(self, messages: List[List[Dict]]): logger.debug(f"Total tokens used: {total_tokens}") return responses + + def _rate_limit(self): + now = time.time() + + while self._request_records and ( + now - self._request_records[0] > self._window_seconds + ): + self._request_records.popleft() + + if len(self._request_records) >= self._max_requests: + oldest_request_time = self._request_records[0] + sleep_time = self._window_seconds - (now - oldest_request_time) + logger.warning( + f"Request rate limit reached. Sleeping for {sleep_time:.2f} seconds." + ) + time.sleep(max(sleep_time, 0)) + + self._request_records.append(now) From bbd1181cb5b7ae54e5efdefe10e5edbe04769cb7 Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Sun, 10 Aug 2025 22:42:47 +0200 Subject: [PATCH 10/12] add images to remote llm test --- tests/test_remote_inference.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/test_remote_inference.py b/tests/test_remote_inference.py index ddb015d..a62498c 100644 --- a/tests/test_remote_inference.py +++ b/tests/test_remote_inference.py @@ -13,37 +13,38 @@ class TestRemoteInferenceEngine(unittest.TestCase): @unittest.skip("Skipping OpenAI inference test") def test_openai_predict_batch(self): engine = OpenAIInferenceEngine() - messages = [ - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - ] - ] + + messages = message_formats.OpenAIMessageFormat().format( + question="What is the capital of France?", + system_prompt="You are a helpful assistant.", + image_path="./tests/test_data/images/CAM_BACK.jpg", + ) + results = engine.predict_batch(messages) self.assertTrue(len(results) > 0, "Results should not be empty") def test_gemini_predict_batch(self): engine = GeminiInferenceEngine(model="gemini-2.0-flash") - message = message_formats.GeminiMessageFormat().format( + messages = message_formats.GeminiMessageFormat().format( question="What is the capital of France?", - image_path=None, + image_path="./tests/test_data/images/CAM_BACK.jpg", system_prompt="You are a helpful assistant.", ) - results = engine.predict_batch([message]) + results = engine.predict_batch([messages]) self.assertTrue(len(results) > 0, "Results should not be empty") def test_anthropic_predict_batch(self): engine = 
AnthropicInferenceEngine(model="claude-3-5-haiku-20241022") - messages = [ - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - ] - ] - results = engine.predict_batch(messages) + + messages = message_formats.AnthropicMessageFormat().format( + question="What is the capital of France?", + image_path="./tests/test_data/images/CAM_BACK.jpg", + system_prompt="You are a helpful assistant.", + ) + results = engine.predict_batch([messages]) self.assertTrue(len(results) > 0, "Results should not be empty") From 3d440bbe9581865b201c34dcd29cd4b8db60c474 Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Mon, 11 Aug 2025 12:10:34 +0200 Subject: [PATCH 11/12] sorting of dataset by qa_type only for reasoning batch processing --- src/data/basic_dataset.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/data/basic_dataset.py b/src/data/basic_dataset.py index f1aafdf..0558858 100644 --- a/src/data/basic_dataset.py +++ b/src/data/basic_dataset.py @@ -186,9 +186,13 @@ def __init__( ) logger.info(f"Removed {removed} scenes due to missing image files.") - logger.info(f"Loaded {len(qa_list)} QAs from the DriveLM dataset.") + logger.info( + f"Loaded {len(qa_list)} QAs from the DriveLM dataset for split {self.split}." + ) + + if use_reasoning: + qa_list.sort(key=lambda qa: qa["qa_type"]) - qa_list.sort(key=lambda qa: qa["qa_type"]) self.qas = qa_list def __len__(self): From bd3c4a00682e0ddaa8cfb9531c8115f3ff5a90e7 Mon Sep 17 00:00:00 2001 From: Caspar Siemssen Date: Tue, 12 Aug 2025 19:12:49 +0200 Subject: [PATCH 12/12] fix qwen message --- src/data/message_formats.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/data/message_formats.py b/src/data/message_formats.py index 8aa8b7b..5e0cdfa 100644 --- a/src/data/message_formats.py +++ b/src/data/message_formats.py @@ -1,4 +1,6 @@ +import base64 from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -54,10 +56,12 @@ def format( "image": f"file://{image_path}", } ) - return { - "role": "user", - "content": content, - } + return [ + { + "role": "user", + "content": content, + } + ] class QwenTrainingMessageFormat(MessageFormat): @@ -282,9 +286,6 @@ def format( key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, ) -> List[Dict[str, Any]]: - import base64 - from pathlib import Path - user_content = [] if image_path: image_bytes = Path(image_path).read_bytes() @@ -297,6 +298,7 @@ def format( "media_type": "image/jpeg", "data": image_b64, }, + "cache_control": {"type": "ephemeral"}, } )