diff --git a/main.py b/main.py index 5cef7cb..414f7e7 100644 --- a/main.py +++ b/main.py @@ -3,16 +3,34 @@ from src.constants import model_dir from src.eval.eval_models import evaluate_model +from src.models.anthropic_inference import AnthropicInferenceEngine +from src.models.gemini_inference import GeminiInferenceEngine +from src.models.gemma_inference import GemmaInferenceEngine +from src.models.intern_vl_inference import InternVLInferenceEngine +from src.models.openai_inference import OpenAIInferenceEngine from src.models.qwen_vl_inference import QwenVLInferenceEngine from src.train.train_qwen import train from src.utils.approach import get_approach_kwargs, get_approach_name from src.utils.logger import get_logger -from src.utils.utils import get_resize_image_size, is_cuda +from src.utils.utils import get_resize_image_size logger = get_logger(__name__) if __name__ == "__main__": parser = ArgumentParser() + parser.add_argument( + "--provider", + help="The inference provider/model to use.", + choices=[ + "openai", + "anthropic", + "gemini", + "gemma", + "intern_vl", + "qwen", + ], + default="qwen", + ) parser.add_argument( "--train", help="Set to finetune the current model", @@ -51,7 +69,7 @@ parser.add_argument( "--resize_factor", help="Resize factor to apply to the images. Original size is (1600 x 900). Currently only applied if using image_grid approach.", - default="0.5", + default="1", ) parser.add_argument( "--batch_size", @@ -94,14 +112,24 @@ grid="image_grid" in args.approach, ) logger.debug(f"Using resize image size: {resize_image_size}") - if is_cuda(): + + engine = None + if args.provider == "openai": + engine = OpenAIInferenceEngine(model=args.model_path) + elif args.provider == "anthropic": + engine = AnthropicInferenceEngine(model=args.model_path) + elif args.provider == "gemini": + engine = GeminiInferenceEngine(model=args.model_path) + elif args.provider == "gemma": + engine = GemmaInferenceEngine(model_path=model_path) + elif args.provider == "intern_vl": + engine = InternVLInferenceEngine(model_path=model_path) + elif args.provider == "qwen": engine = QwenVLInferenceEngine( - model_path=model_path, - use_4bit=True, - resize_image_size=resize_image_size, + model_path=model_path, resize_image_size=resize_image_size ) else: - engine = QwenVLInferenceEngine(resize_image_size=resize_image_size) + raise ValueError(f"Unknown provider: {args.provider}") evaluate_model( engine=engine, diff --git a/requirements.txt b/requirements.txt index 200f3dc..fa35a85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,5 +13,9 @@ gdown~=5.2.0 pre-commit~=4.2.0 peft~=0.15.2 trl~=0.18.1 +anthropic~=0.58.2 +openai~=1.97.1 +google-genai~=1.24.0 +python-dotenv~=1.1.1 polars==1.31.0 ultralytics==8.3.168 diff --git a/src/data/basic_dataset.py b/src/data/basic_dataset.py index f1aafdf..0558858 100644 --- a/src/data/basic_dataset.py +++ b/src/data/basic_dataset.py @@ -186,9 +186,13 @@ def __init__( ) logger.info(f"Removed {removed} scenes due to missing image files.") - logger.info(f"Loaded {len(qa_list)} QAs from the DriveLM dataset.") + logger.info( + f"Loaded {len(qa_list)} QAs from the DriveLM dataset for split {self.split}." 
+ ) + + if use_reasoning: + qa_list.sort(key=lambda qa: qa["qa_type"]) - qa_list.sort(key=lambda qa: qa["qa_type"]) self.qas = qa_list def __len__(self): diff --git a/src/data/message_formats.py b/src/data/message_formats.py index cd13bd9..5e0cdfa 100644 --- a/src/data/message_formats.py +++ b/src/data/message_formats.py @@ -1,4 +1,6 @@ +import base64 from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -12,7 +14,7 @@ def format( answer: Optional[str] = None, key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: pass @@ -25,7 +27,7 @@ def format( answer: Optional[str] = None, key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: content = [] if system_prompt: content.append({"type": "text", "text": system_prompt}) @@ -54,10 +56,12 @@ def format( "image": f"file://{image_path}", } ) - return { - "role": "user", - "content": content, - } + return [ + { + "role": "user", + "content": content, + } + ] class QwenTrainingMessageFormat(MessageFormat): @@ -73,7 +77,7 @@ def format( answer: Optional[str] = None, key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: user_content = [] if system_prompt: user_content.append({"type": "text", "text": system_prompt}) @@ -101,20 +105,18 @@ def format( {"type": "text", "text": f"Context Answer: {context_a}"} ) - return { - "messages": [ - { - "role": "user", - "content": user_content, - }, - { - "role": "assistant", - "content": [ - {"type": "text", "text": answer}, - ], - }, - ] - } + return [ + { + "role": "user", + "content": user_content, + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": answer}, + ], + }, + ] class InternVLMessageFormat(MessageFormat): @@ -126,7 +128,7 @@ def format( answer: Optional[str] = None, key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: full_prompt = "" if system_prompt: full_prompt += system_prompt + "\n\n" @@ -141,10 +143,12 @@ def format( f"\n\nContext Question: {context_q}\nContext Answer: {context_a}" ) - return { - "text": full_prompt, - "image_path": image_path, - } + return [ + { + "text": full_prompt, + "image_path": image_path, + } + ] class GemmaMessageFormat(MessageFormat): @@ -156,7 +160,7 @@ def format( answer: Optional[str] = None, key_object_info: Optional[dict] = None, context: Optional[List[Tuple[str, str]]] = None, - ) -> Dict[str, Any]: + ) -> List[Dict[str, Any]]: content = [ { "type": "image", @@ -181,7 +185,139 @@ def format( ) content.append({"type": "text", "text": f"Context Answer: {context_a}"}) - return { - "role": "user", - "content": content, - } + return [ + { + "role": "user", + "content": content, + } + ] + + +class OpenAIMessageFormat(MessageFormat): + def format( + self, + question: str, + image_path: str, + system_prompt: str = None, + answer: Optional[str] = None, + key_object_info: Optional[dict] = None, + context: Optional[List[Tuple[str, str]]] = None, + ) -> List[Dict[str, Any]]: + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + user_content = [] + if image_path: + user_content.append( + { + "type": "image_url", + "image_url": {"url": f"file://{image_path}"}, + } + ) + + user_content.append({"type": 
"text", "text": "Question: " + question}) + + if key_object_info: + user_content.append( + { + "type": "text", + "text": "Key object infos:\n" + str(key_object_info), + } + ) + + if context: + for context_q, context_a in context: + user_content.append( + {"type": "text", "text": f"Context Question: {context_q}"} + ) + user_content.append( + {"type": "text", "text": f"Context Answer: {context_a}"} + ) + + messages.append({"role": "user", "content": user_content}) + + return messages + + +class GeminiMessageFormat(MessageFormat): + def format( + self, + question: str, + image_path: str, + system_prompt: str = None, + answer: Optional[str] = None, + key_object_info: Optional[dict] = None, + context: Optional[List[Tuple[str, str]]] = None, + ) -> List[Dict[str, Any]]: + import PIL.Image + + parts = [] + + if system_prompt: + parts.append(system_prompt) + + if image_path: + image = PIL.Image.open(image_path) + parts.append(image) + + prompt_parts = ["Question: " + question] + + if key_object_info: + prompt_parts.append("Key object infos:\n" + str(key_object_info)) + + if context: + for context_q, context_a in context: + prompt_parts.append(f"Context Question: {context_q}") + prompt_parts.append(f"Context Answer: {context_a}") + + parts.append("\n".join(prompt_parts)) + + return parts + + +class AnthropicMessageFormat(MessageFormat): + def format( + self, + question: str, + image_path: str, + system_prompt: str = None, + answer: Optional[str] = None, + key_object_info: Optional[dict] = None, + context: Optional[List[Tuple[str, str]]] = None, + ) -> List[Dict[str, Any]]: + user_content = [] + if image_path: + image_bytes = Path(image_path).read_bytes() + image_b64 = base64.b64encode(image_bytes).decode("utf-8") + user_content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_b64, + }, + "cache_control": {"type": "ephemeral"}, + } + ) + + messages = [] + + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + text_parts = ["Question: " + question] + + if key_object_info: + text_parts.append("Key object infos:\n" + str(key_object_info)) + if context: + for context_q, context_a in context: + text_parts.append(f"Context Question: {context_q}") + text_parts.append(f"Context Answer: {context_a}") + + user_content.append({"type": "text", "text": "\n".join(text_parts)}) + + messages.append({"role": "user", "content": user_content}) + + return messages diff --git a/src/data/query_item.py b/src/data/query_item.py index a2a644e..3115b6c 100644 --- a/src/data/query_item.py +++ b/src/data/query_item.py @@ -17,11 +17,11 @@ class QueryItem: system_prompt: str = None ground_truth_answer: Optional[str] = None - formatted_message: Optional[Dict[str, Any]] = None + formatted_message: Optional[List[Dict[str, Any]]] = None context_pairs: List[Tuple[str, str]] = field(default_factory=list) - def format_message(self, formatter: MessageFormat) -> Dict[str, Any]: + def format_message(self, formatter: MessageFormat) -> List[Dict[str, Any]]: self.formatted_message = formatter.format( question=self.question, image_path=self.image_path, diff --git a/src/eval/eval_models.py b/src/eval/eval_models.py index 6e73d82..b83c74f 100644 --- a/src/eval/eval_models.py +++ b/src/eval/eval_models.py @@ -75,7 +75,7 @@ def evaluate_model( if use_reasoning: batch = reasoning_engine.process_batch(batch) - formatted_messages = [[item.formatted_message] for item in batch] + formatted_messages = [item.formatted_message for item in batch] batch_results = 
engine.predict_batch(formatted_messages) diff --git a/src/models/anthropic_inference.py b/src/models/anthropic_inference.py new file mode 100644 index 0000000..72fa889 --- /dev/null +++ b/src/models/anthropic_inference.py @@ -0,0 +1,83 @@ +import os +import time +from collections import deque +from typing import Dict, List, Optional + +import anthropic +import dotenv + +from src.data.message_formats import AnthropicMessageFormat +from src.models.remote_inference import RemoteInferenceEngine +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +class AnthropicInferenceEngine(RemoteInferenceEngine): + def __init__( + self, + model: Optional[str] = "claude-3-5-haiku-20241022", + max_requests_per_minute: int = 5, + ): + super().__init__(model_path=model) + dotenv.load_dotenv() + api_key = os.getenv("ANTHROPIC_API_KEY") + self.client = anthropic.Anthropic(api_key=api_key) + self.model = model + self.message_formatter = AnthropicMessageFormat() + self._request_records = deque() + self._max_requests = max_requests_per_minute + self._window_seconds = 60 + + def predict_batch(self, messages: List[List[Dict]]): + responses = [] + total_tokens = 0 + for msg in messages: + system_prompt = None + filtered_msg = [] + for m in msg: + if m["role"] == "system": + if system_prompt is None: + system_prompt = m["content"] + else: + system_prompt += "\n" + m["content"] + else: + filtered_msg.append(m) + + self._rate_limit() + + response = self.client.messages.create( + model=self.model, + messages=filtered_msg, + system=system_prompt, + temperature=0.6, + max_tokens=128, + ) + + responses.append(response.content[0].text.strip()) + total_tokens += response.usage.input_tokens + response.usage.output_tokens + + logger.debug( + f"Generated {len(responses)} responses for batch of size {len(messages)}" + ) + logger.debug(f"Total tokens used: {total_tokens}") + + return responses + + def _rate_limit(self): + now = time.time() + + while self._request_records and ( + now - self._request_records[0] > self._window_seconds + ): + self._request_records.popleft() + + if len(self._request_records) >= self._max_requests: + oldest_request_time = self._request_records[0] + sleep_time = self._window_seconds - (now - oldest_request_time) + logger.warning( + f"Request rate limit reached. Sleeping for {sleep_time:.2f} seconds." 
+ ) + time.sleep(max(sleep_time, 0)) + + self._request_records.append(now) diff --git a/src/models/gemini_inference.py b/src/models/gemini_inference.py new file mode 100644 index 0000000..7f60f78 --- /dev/null +++ b/src/models/gemini_inference.py @@ -0,0 +1,44 @@ +from typing import Dict, List, Optional + +import dotenv +from google import genai +from google.genai import types + +from src.data.message_formats import GeminiMessageFormat +from src.models.remote_inference import RemoteInferenceEngine +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +class GeminiInferenceEngine(RemoteInferenceEngine): + def __init__(self, model: Optional[str] = "gemini-2.0-flash"): + super().__init__(model_path=model) + dotenv.load_dotenv() + self.client = genai.Client() + self.model = model + self.message_formatter = GeminiMessageFormat() + + def predict_batch(self, messages: List[List[Dict]]): + responses = [] + total_tokens = 0 + + for msg in messages: + response = self.client.models.generate_content( + model=self.model, + contents=msg, + config=types.GenerateContentConfig( + temperature=0.6, + max_output_tokens=128, + ), + ) + + responses.append(response.text.strip()) + total_tokens += response.usage_metadata.total_token_count + + logger.debug( + f"Generated {len(responses)} responses for batch of size {len(messages)}" + ) + logger.debug(f"Total tokens used: {total_tokens}") + + return responses diff --git a/src/models/intern_vl_inference.py b/src/models/intern_vl_inference.py index cd15e14..75ca8a3 100644 --- a/src/models/intern_vl_inference.py +++ b/src/models/intern_vl_inference.py @@ -16,7 +16,7 @@ class InternVLInferenceEngine(BaseInferenceEngine): def __init__( self, model_path: str = "OpenGVLab/InternVL3-2B", - use_4bit: bool = False, + use_4bit: bool = True, torch_dtype: Optional[torch.dtype] = None, revision: Optional[str] = None, device: Optional[str] = None, diff --git a/src/models/openai_inference.py b/src/models/openai_inference.py new file mode 100644 index 0000000..8bc984d --- /dev/null +++ b/src/models/openai_inference.py @@ -0,0 +1,41 @@ +import os +from typing import Dict, List, Optional + +import dotenv +from openai import OpenAI + +from src.data.message_formats import OpenAIMessageFormat +from src.models.remote_inference import RemoteInferenceEngine +from src.utils.logger import get_logger + +logger = get_logger(__name__) + + +class OpenAIInferenceEngine(RemoteInferenceEngine): + def __init__(self, model: Optional[str] = "gpt-4.1"): + super().__init__(model_path=model) + dotenv.load_dotenv() + api_key = os.getenv("OPENAI_API_KEY") + self.client = OpenAI(api_key=api_key) + self.model = model + self.message_formatter = OpenAIMessageFormat() + + def predict_batch(self, messages: List[List[Dict]]): + responses = [] + total_tokens = 0 + for msg in messages: + response = self.client.chat.completions.create( + model=self.model, + messages=msg, + temperature=0.6, + max_tokens=128, + ) + responses.append(response.output_text) + total_tokens += response.usage.total_tokens + + logger.debug( + f"Generated {len(responses)} responses for batch of size {len(messages)}" + ) + logger.debug(f"Total tokens used: {total_tokens}") + + return responses diff --git a/src/models/qwen_vl_inference.py b/src/models/qwen_vl_inference.py index 342c497..8aa3553 100644 --- a/src/models/qwen_vl_inference.py +++ b/src/models/qwen_vl_inference.py @@ -19,7 +19,7 @@ def __init__( self, processor_path: str = "Qwen/Qwen2.5-VL-7B-Instruct", model_path: Optional[str] = None, - use_4bit: bool = 
False, + use_4bit: bool = True, torch_dtype: Optional[torch.dtype] = None, revision: Optional[str] = None, device: Optional[str] = None, diff --git a/src/models/remote_inference.py b/src/models/remote_inference.py new file mode 100644 index 0000000..ec2073b --- /dev/null +++ b/src/models/remote_inference.py @@ -0,0 +1,13 @@ +from abc import abstractmethod +from typing import Dict, List + +from src.models.base_inference import BaseInferenceEngine + + +class RemoteInferenceEngine(BaseInferenceEngine): + def load_model(self) -> None: + pass + + @abstractmethod + def predict_batch(self, messages: List[List[Dict]]) -> list[str]: + pass diff --git a/src/reasoning/reasoning_engine.py b/src/reasoning/reasoning_engine.py index 9589d66..a967945 100644 --- a/src/reasoning/reasoning_engine.py +++ b/src/reasoning/reasoning_engine.py @@ -49,7 +49,7 @@ def process_batch(self, batch_items: List[QueryItem]) -> List[QueryItem]: if descriptor_items: descriptor_messages = [ - [desc_item.formatted_message] for _, desc_item in descriptor_items + desc_item.formatted_message for _, desc_item in descriptor_items ] descriptor_answers = self.engine.predict_batch(descriptor_messages) diff --git a/src/train/train_qwen.py b/src/train/train_qwen.py index 6f9655a..1903a29 100644 --- a/src/train/train_qwen.py +++ b/src/train/train_qwen.py @@ -268,12 +268,9 @@ def train( def collator(batch: Any): texts = [ - engine.processor.apply_chat_template(data["messages"], tokenize=False) - for data in batch + engine.processor.apply_chat_template(data, tokenize=False) for data in batch ] - image_inputs, video_inputs = process_vision_info( - [data["messages"][0] for data in batch] - ) + image_inputs, video_inputs = process_vision_info([data[0] for data in batch]) processed_batch = engine.processor( text=texts, @@ -290,12 +287,12 @@ def collator(batch: Any): assistant_idx = next( j for data in batch - for j, m in enumerate(data["messages"]) + for j, m in enumerate(data) if m["role"] == "assistant" ) pre_text = engine.processor.apply_chat_template( - data["messages"][:assistant_idx], tokenize=False + data[:assistant_idx], tokenize=False ) pre_tokens = engine.processor.tokenizer(pre_text, return_tensors="pt")[ "input_ids" diff --git a/tests/test_intern_vl_inference.py b/tests/test_intern_vl_inference.py index b39a9b9..7370264 100644 --- a/tests/test_intern_vl_inference.py +++ b/tests/test_intern_vl_inference.py @@ -57,7 +57,7 @@ def test_internvl_model_predict(self): results = [] for batch in dataloader: query_items = batch - formatted_messages = [[item.formatted_message] for item in query_items] + formatted_messages = [item.formatted_message for item in query_items] predictions = engine.predict_batch(formatted_messages) results.extend(predictions) diff --git a/tests/test_message_format.py b/tests/test_message_format.py index e9debe8..5b479f0 100644 --- a/tests/test_message_format.py +++ b/tests/test_message_format.py @@ -60,10 +60,12 @@ def test_format_of_internvl_message(self): context=context, ) - expected_message = { - "text": "This is the system prompt\n\nQuestion: What is the color of the car?\n\nKey object infos:\n{'object': 'car', 'color': 'blue'}\n\nContext Question: What is this?\nContext Answer: This is a car.", - "image_path": image_path, - } + expected_message = [ + { + "text": "This is the system prompt\n\nQuestion: What is the color of the car?\n\nKey object infos:\n{'object': 'car', 'color': 'blue'}\n\nContext Question: What is this?\nContext Answer: This is a car.", + "image_path": image_path, + } + ] self.assertEqual( 
formatted_message, expected_message, @@ -86,23 +88,25 @@ def test_format_of_gemma_message(self): context=context, ) - expected_message = { - "role": "user", - "content": [ - {"type": "image", "image": image_path}, - {"type": "text", "text": system_prompt}, - {"type": "text", "text": "Question: " + question}, - { - "type": "text", - "text": "Key object infos:\n{'country': 'France', 'capital': 'Paris'}", - }, - {"type": "text", "text": "Context Question: What is this?"}, - { - "type": "text", - "text": "Context Answer: This is a map of France.", - }, - ], - } + expected_message = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": system_prompt}, + {"type": "text", "text": "Question: " + question}, + { + "type": "text", + "text": "Key object infos:\n{'country': 'France', 'capital': 'Paris'}", + }, + {"type": "text", "text": "Context Question: What is this?"}, + { + "type": "text", + "text": "Context Answer: This is a map of France.", + }, + ], + } + ] self.assertEqual( formatted_message, diff --git a/tests/test_models_eval.py b/tests/test_models_eval.py index 05e0844..c308f17 100644 --- a/tests/test_models_eval.py +++ b/tests/test_models_eval.py @@ -7,11 +7,29 @@ from src.constants import data_dir from src.eval.eval_models import evaluate_model +from src.models.anthropic_inference import AnthropicInferenceEngine +from src.models.gemini_inference import GeminiInferenceEngine from src.models.intern_vl_inference import InternVLInferenceEngine from src.models.qwen_vl_inference import QwenVLInferenceEngine from src.utils.utils import is_cuda, sanitize_model_name +def check_eval_files(output_dir, output_file_name, submission_file_name): + output_file = os.path.join(output_dir, output_file_name) + submission_file = os.path.join(output_dir, submission_file_name) + + assert os.path.exists(output_file), "Output file should be created." + assert os.path.exists(submission_file), "Submission file should be created." + + with open(output_file, "r") as f: + results = json.load(f) + assert len(results) > 0, "Results should not be empty." 
+ for result in results: + assert "id" in result + assert "question" in result + assert "answer" in result + + @pytest.mark.eval class TestModelEvaluation(unittest.TestCase): def test_intern_vl_eval(self): @@ -32,28 +50,18 @@ def test_intern_vl_eval(self): test_set_size=1, use_reasoning=True, use_grid=True, - approach_name="test_models_eval", + approach_name="test_intern_vl_eval", ) model_dir = sanitize_model_name(engine.model_path) output_dir = os.path.join(data_dir, "output", model_dir) - output_file = os.path.join(output_dir, "test_models_eval_output.json") - submission_file = os.path.join(output_dir, "test_models_eval_submission.json") - - self.assertTrue(os.path.exists(output_file), "Output file should be created.") - self.assertTrue( - os.path.exists(submission_file), - "Submission file should be created.", + check_eval_files( + output_dir, + "test_intern_vl_eval_output.json", + "test_intern_vl_eval_submission.json", ) - with open(output_file, "r") as f: - results = json.load(f) - self.assertGreater(len(results), 0, "Results should not be empty.") - for result in results: - self.assertIn("id", result) - self.assertIn("question", result) - self.assertIn("answer", result) - + @unittest.skip("Skipping Qwen evaluation test") def test_qwen_eval(self): if is_cuda(): engine = QwenVLInferenceEngine( @@ -66,24 +74,57 @@ def test_qwen_eval(self): engine = QwenVLInferenceEngine("Qwen/Qwen2.5-VL-3B-Instruct") evaluate_model( - engine=engine, dataset_split="val", batch_size=1, test_set_size=1 + engine=engine, + dataset_split="val", + batch_size=1, + test_set_size=1, + approach_name="test_qwen_eval", ) model_dir = sanitize_model_name(engine.model_path) output_dir = os.path.join(data_dir, "output", model_dir) - output_file = os.path.join(output_dir, "test_test_test_output.json") - submission_file = os.path.join(output_dir, "test_test_test_submission.json") + check_eval_files( + output_dir, + "test_qwen_eval_output.json", + "test_qwen_eval_submission.json", + ) - self.assertTrue(os.path.exists(output_file), "Output file should be created.") - self.assertTrue( - os.path.exists(submission_file), - "Submission file should be created.", + def test_gemini_eval(self): + engine = GeminiInferenceEngine(model="gemini-2.0-flash") + evaluate_model( + engine=engine, + dataset_split="val", + batch_size=1, + test_set_size=1, + use_grid=True, + use_system_prompt=True, + approach_name="test_gemini_eval", + ) + + model_dir = engine.model + output_dir = os.path.join(data_dir, "output", model_dir) + check_eval_files( + output_dir, + "test_gemini_eval_output.json", + "test_gemini_eval_submission.json", ) - with open(output_file, "r") as f: - results = json.load(f) - self.assertGreater(len(results), 0, "Results should not be empty.") - for result in results: - self.assertIn("id", result) - self.assertIn("question", result) - self.assertIn("answer", result) + def test_anthropic_eval(self): + engine = AnthropicInferenceEngine(model="claude-3-5-haiku-20241022") + evaluate_model( + engine=engine, + dataset_split="val", + batch_size=1, + test_set_size=1, + use_grid=True, + use_system_prompt=True, + approach_name="test_anthropic_eval", + ) + + model_dir = sanitize_model_name(engine.model_path) + output_dir = os.path.join(data_dir, "output", model_dir) + check_eval_files( + output_dir, + "test_anthropic_eval_output.json", + "test_anthropic_eval_submission.json", + ) diff --git a/tests/test_remote_inference.py b/tests/test_remote_inference.py new file mode 100644 index 0000000..a62498c --- /dev/null +++ 
b/tests/test_remote_inference.py @@ -0,0 +1,53 @@ +import unittest + +import pytest + +import src.data.message_formats as message_formats +from src.models.anthropic_inference import AnthropicInferenceEngine +from src.models.gemini_inference import GeminiInferenceEngine +from src.models.openai_inference import OpenAIInferenceEngine + + +@pytest.mark.inference +class TestRemoteInferenceEngine(unittest.TestCase): + @unittest.skip("Skipping OpenAI inference test") + def test_openai_predict_batch(self): + engine = OpenAIInferenceEngine() + + messages = message_formats.OpenAIMessageFormat().format( + question="What is the capital of France?", + system_prompt="You are a helpful assistant.", + image_path="./tests/test_data/images/CAM_BACK.jpg", + ) + + results = engine.predict_batch(messages) + + self.assertTrue(len(results) > 0, "Results should not be empty") + + def test_gemini_predict_batch(self): + engine = GeminiInferenceEngine(model="gemini-2.0-flash") + messages = message_formats.GeminiMessageFormat().format( + question="What is the capital of France?", + image_path="./tests/test_data/images/CAM_BACK.jpg", + system_prompt="You are a helpful assistant.", + ) + + results = engine.predict_batch([messages]) + + self.assertTrue(len(results) > 0, "Results should not be empty") + + def test_anthropic_predict_batch(self): + engine = AnthropicInferenceEngine(model="claude-3-5-haiku-20241022") + + messages = message_formats.AnthropicMessageFormat().format( + question="What is the capital of France?", + image_path="./tests/test_data/images/CAM_BACK.jpg", + system_prompt="You are a helpful assistant.", + ) + results = engine.predict_batch([messages]) + + self.assertTrue(len(results) > 0, "Results should not be empty") + + +if __name__ == "__main__": + unittest.main()
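
Note: below is a minimal standalone sketch of the sliding-window rate limiting pattern that AnthropicInferenceEngine._rate_limit implements inline in the diff above. The helper class name SlidingWindowRateLimiter and the 5-requests-per-60-seconds defaults are illustrative assumptions (they mirror max_requests_per_minute in the engine's constructor); the engine itself keeps this logic as private attributes and a _rate_limit method rather than a separate class.

    # Sketch only: same sliding-window idea as AnthropicInferenceEngine._rate_limit,
    # factored into a reusable helper. Defaults are assumptions, not project constants.
    import time
    from collections import deque


    class SlidingWindowRateLimiter:
        def __init__(self, max_requests: int = 5, window_seconds: float = 60.0):
            self._records = deque()  # timestamps of recent requests
            self._max_requests = max_requests
            self._window_seconds = window_seconds

        def wait(self) -> None:
            """Block until another request is allowed, then record it."""
            now = time.time()

            # Drop timestamps that have fallen outside the window.
            while self._records and now - self._records[0] > self._window_seconds:
                self._records.popleft()

            # If the window is full, sleep until the oldest request expires.
            if len(self._records) >= self._max_requests:
                sleep_time = self._window_seconds - (now - self._records[0])
                time.sleep(max(sleep_time, 0))

            self._records.append(time.time())


    # Usage: call wait() before each remote API request.
    limiter = SlidingWindowRateLimiter(max_requests=5, window_seconds=60)
    for _ in range(3):
        limiter.wait()
        # client.messages.create(...) would be issued here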