42 changes: 35 additions & 7 deletions main.py
@@ -3,16 +3,34 @@

from src.constants import model_dir
from src.eval.eval_models import evaluate_model
from src.models.anthropic_inference import AnthropicInferenceEngine
from src.models.gemini_inference import GeminiInferenceEngine
from src.models.gemma_inference import GemmaInferenceEngine
from src.models.intern_vl_inference import InternVLInferenceEngine
from src.models.openai_inference import OpenAIInferenceEngine
from src.models.qwen_vl_inference import QwenVLInferenceEngine
from src.train.train_qwen import train
from src.utils.approach import get_approach_kwargs, get_approach_name
from src.utils.logger import get_logger
from src.utils.utils import get_resize_image_size, is_cuda
from src.utils.utils import get_resize_image_size

logger = get_logger(__name__)

if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--provider",
help="The inference provider/model to use.",
choices=[
"openai",
"anthropic",
"gemini",
"gemma",
"intern_vl",
"qwen",
],
default="qwen",
)
parser.add_argument(
"--train",
help="Set to finetune the current model",
@@ -51,7 +69,7 @@
parser.add_argument(
"--resize_factor",
help="Resize factor to apply to the images. Original size is (1600 x 900). Currently only applied if using image_grid approach.",
default="0.5",
default="1",
)
parser.add_argument(
"--batch_size",
@@ -94,14 +112,24 @@
grid="image_grid" in args.approach,
)
logger.debug(f"Using resize image size: {resize_image_size}")
if is_cuda():

engine = None
if args.provider == "openai":
engine = OpenAIInferenceEngine(model=args.model_path)
elif args.provider == "anthropic":
engine = AnthropicInferenceEngine(model=args.model_path)
elif args.provider == "gemini":
engine = GeminiInferenceEngine(model=args.model_path)
elif args.provider == "gemma":
engine = GemmaInferenceEngine(model_path=model_path)
elif args.provider == "intern_vl":
engine = InternVLInferenceEngine(model_path=model_path)
elif args.provider == "qwen":
engine = QwenVLInferenceEngine(
model_path=model_path,
use_4bit=True,
resize_image_size=resize_image_size,
model_path=model_path, resize_image_size=resize_image_size
)
else:
engine = QwenVLInferenceEngine(resize_image_size=resize_image_size)
raise ValueError(f"Unknown provider: {args.provider}")

evaluate_model(
engine=engine,
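The new dispatch in `main.py` maps each `--provider` choice to an engine class and raises on unknown values, replacing the old `is_cuda()` branch. A minimal, self-contained sketch of the same registry idea; `StubEngine` and `build_engine` are hypothetical stand-ins for illustration, not the repo's real classes:

```python
# Hypothetical stand-in for the engines in src/models/; shown only to
# illustrate the provider -> constructor dispatch main.py now performs.
from typing import Callable, Dict


class StubEngine:
    def __init__(self, label: str) -> None:
        self.label = label


def build_engine(provider: str, model_path: str) -> StubEngine:
    registry: Dict[str, Callable[[], StubEngine]] = {
        "openai": lambda: StubEngine(f"openai:{model_path}"),
        "anthropic": lambda: StubEngine(f"anthropic:{model_path}"),
        "gemini": lambda: StubEngine(f"gemini:{model_path}"),
        "qwen": lambda: StubEngine(f"qwen:{model_path}"),
    }
    if provider not in registry:
        # Mirrors the ValueError raised in main.py for unknown providers.
        raise ValueError(f"Unknown provider: {provider}")
    return registry[provider]()


print(build_engine("qwen", "Qwen2-VL-7B").label)  # qwen:Qwen2-VL-7B
```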
4 changes: 4 additions & 0 deletions requirements.txt
@@ -13,5 +13,9 @@ gdown~=5.2.0
pre-commit~=4.2.0
peft~=0.15.2
trl~=0.18.1
anthropic~=0.58.2
openai~=1.97.1
google-genai~=1.24.0
python-dotenv~=1.1.1
polars==1.31.0
ultralytics==8.3.168
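The new dependencies (`anthropic`, `openai`, `google-genai`, `python-dotenv`) back the three hosted providers. A minimal sketch of how `python-dotenv` is typically used to load provider keys; the environment-variable names here are conventional defaults, not confirmed by this PR:

```python
# Sketch only: assumes a .env file containing these conventional key names.
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory into os.environ
openai_key = os.getenv("OPENAI_API_KEY")
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
gemini_key = os.getenv("GOOGLE_API_KEY")
```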
8 changes: 6 additions & 2 deletions src/data/basic_dataset.py
@@ -186,9 +186,13 @@ def __init__(
)

logger.info(f"Removed {removed} scenes due to missing image files.")
logger.info(f"Loaded {len(qa_list)} QAs from the DriveLM dataset.")
logger.info(
f"Loaded {len(qa_list)} QAs from the DriveLM dataset for split {self.split}."
)

if use_reasoning:
qa_list.sort(key=lambda qa: qa["qa_type"])

qa_list.sort(key=lambda qa: qa["qa_type"])
self.qas = qa_list

def __len__(self):
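The sort by `qa_type` is now applied only when `use_reasoning` is set. Python's `list.sort` is stable, so QAs sharing a `qa_type` keep their original relative order, as this tiny check with illustrative data shows:

```python
# Stable-sort behaviour the reasoning path relies on (illustrative data).
qa_list = [
    {"qa_type": "planning", "question": "Q1"},
    {"qa_type": "perception", "question": "Q2"},
    {"qa_type": "perception", "question": "Q3"},
]
qa_list.sort(key=lambda qa: qa["qa_type"])
print([qa["question"] for qa in qa_list])  # ['Q2', 'Q3', 'Q1']
```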
198 changes: 167 additions & 31 deletions src/data/message_formats.py
@@ -1,4 +1,6 @@
import base64
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple


@@ -12,7 +14,7 @@ def format(
answer: Optional[str] = None,
key_object_info: Optional[dict] = None,
context: Optional[List[Tuple[str, str]]] = None,
) -> Dict[str, Any]:
) -> List[Dict[str, Any]]:
pass


@@ -25,7 +27,7 @@ def format(
answer: Optional[str] = None,
key_object_info: Optional[dict] = None,
context: Optional[List[Tuple[str, str]]] = None,
) -> Dict[str, Any]:
) -> List[Dict[str, Any]]:
content = []
if system_prompt:
content.append({"type": "text", "text": system_prompt})
@@ -54,10 +56,12 @@
"image": f"file://{image_path}",
}
)
return {
"role": "user",
"content": content,
}
return [
{
"role": "user",
"content": content,
}
]


class QwenTrainingMessageFormat(MessageFormat):
@@ -73,7 +77,7 @@ def format(
answer: Optional[str] = None,
key_object_info: Optional[dict] = None,
context: Optional[List[Tuple[str, str]]] = None,
) -> Dict[str, Any]:
) -> List[Dict[str, Any]]:
user_content = []
if system_prompt:
user_content.append({"type": "text", "text": system_prompt})
@@ -101,20 +105,18 @@
{"type": "text", "text": f"Context Answer: {context_a}"}
)

return {
"messages": [
{
"role": "user",
"content": user_content,
},
{
"role": "assistant",
"content": [
{"type": "text", "text": answer},
],
},
]
}
return [
{
"role": "user",
"content": user_content,
},
{
"role": "assistant",
"content": [
{"type": "text", "text": answer},
],
},
]


class InternVLMessageFormat(MessageFormat):
@@ -126,7 +128,7 @@ def format(
answer: Optional[str] = None,
key_object_info: Optional[dict] = None,
context: Optional[List[Tuple[str, str]]] = None,
) -> Dict[str, Any]:
) -> List[Dict[str, Any]]:
full_prompt = ""
if system_prompt:
full_prompt += system_prompt + "\n\n"
@@ -141,10 +143,12 @@
f"\n\nContext Question: {context_q}\nContext Answer: {context_a}"
)

return {
"text": full_prompt,
"image_path": image_path,
}
return [
{
"text": full_prompt,
"image_path": image_path,
}
]


class GemmaMessageFormat(MessageFormat):
@@ -156,7 +160,7 @@ def format(
answer: Optional[str] = None,
key_object_info: Optional[dict] = None,
context: Optional[List[Tuple[str, str]]] = None,
) -> Dict[str, Any]:
) -> List[Dict[str, Any]]:
content = [
{
"type": "image",
@@ -181,7 +185,139 @@
)
content.append({"type": "text", "text": f"Context Answer: {context_a}"})

return {
"role": "user",
"content": content,
}
return [
{
"role": "user",
"content": content,
}
]


class OpenAIMessageFormat(MessageFormat):
def format(
self,
question: str,
image_path: str,
system_prompt: str = None,
answer: Optional[str] = None,
key_object_info: Optional[dict] = None,
context: Optional[List[Tuple[str, str]]] = None,
) -> List[Dict[str, Any]]:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})

user_content = []
if image_path:
user_content.append(
{
"type": "image_url",
"image_url": {"url": f"file://{image_path}"},
}
)

user_content.append({"type": "text", "text": "Question: " + question})

if key_object_info:
user_content.append(
{
"type": "text",
"text": "Key object infos:\n" + str(key_object_info),
}
)

if context:
for context_q, context_a in context:
user_content.append(
{"type": "text", "text": f"Context Question: {context_q}"}
)
user_content.append(
{"type": "text", "text": f"Context Answer: {context_a}"}
)

messages.append({"role": "user", "content": user_content})

return messages


class GeminiMessageFormat(MessageFormat):
def format(
self,
question: str,
image_path: str,
system_prompt: str = None,
answer: Optional[str] = None,
key_object_info: Optional[dict] = None,
context: Optional[List[Tuple[str, str]]] = None,
) -> List[Dict[str, Any]]:
import PIL.Image

parts = []

if system_prompt:
parts.append(system_prompt)

if image_path:
image = PIL.Image.open(image_path)
parts.append(image)

prompt_parts = ["Question: " + question]

if key_object_info:
prompt_parts.append("Key object infos:\n" + str(key_object_info))

if context:
for context_q, context_a in context:
prompt_parts.append(f"Context Question: {context_q}")
prompt_parts.append(f"Context Answer: {context_a}")

parts.append("\n".join(prompt_parts))

return parts


class AnthropicMessageFormat(MessageFormat):
def format(
self,
question: str,
image_path: str,
system_prompt: str = None,
answer: Optional[str] = None,
key_object_info: Optional[dict] = None,
context: Optional[List[Tuple[str, str]]] = None,
) -> List[Dict[str, Any]]:
user_content = []
if image_path:
image_bytes = Path(image_path).read_bytes()
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
user_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": image_b64,
},
"cache_control": {"type": "ephemeral"},
}
)

messages = []

if system_prompt:
messages.append({"role": "system", "content": system_prompt})

text_parts = ["Question: " + question]

if key_object_info:
text_parts.append("Key object infos:\n" + str(key_object_info))
if context:
for context_q, context_a in context:
text_parts.append(f"Context Question: {context_q}")
text_parts.append(f"Context Answer: {context_a}")

user_content.append({"type": "text", "text": "\n".join(text_parts)})

messages.append({"role": "user", "content": user_content})

return messages
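Every formatter now returns a `List[Dict]` of messages rather than a single dict, so downstream code can hand the result straight to a chat-style API. A hedged usage sketch, assuming the package root is on `PYTHONPATH`; with `image_path=None` the base64 image block is skipped, so no file needs to exist:

```python
from src.data.message_formats import AnthropicMessageFormat

fmt = AnthropicMessageFormat()
messages = fmt.format(
    question="Is the ego vehicle moving?",
    image_path=None,  # skip the base64 image block for this demo
    system_prompt="Answer concisely.",
    context=[("What is ahead of the ego vehicle?", "A truck.")],
)
# messages -> [{'role': 'system', ...}, {'role': 'user', 'content': [...]}]
print(messages[-1]["content"][0]["text"])
```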
4 changes: 2 additions & 2 deletions src/data/query_item.py
@@ -17,11 +17,11 @@ class QueryItem:
system_prompt: str = None
ground_truth_answer: Optional[str] = None

formatted_message: Optional[Dict[str, Any]] = None
formatted_message: Optional[List[Dict[str, Any]]] = None

context_pairs: List[Tuple[str, str]] = field(default_factory=list)

def format_message(self, formatter: MessageFormat) -> Dict[str, Any]:
def format_message(self, formatter: MessageFormat) -> List[Dict[str, Any]]:
self.formatted_message = formatter.format(
question=self.question,
image_path=self.image_path,
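`QueryItem.formatted_message` follows the same type change: `format_message` now stores the formatter's list on the item. An illustrative call; the exact field set of `QueryItem` beyond what the diff shows is an assumption:

```python
# Assumes QueryItem accepts these keyword fields; other fields are elided.
from src.data.message_formats import OpenAIMessageFormat
from src.data.query_item import QueryItem

item = QueryItem(question="How many lanes are visible?", image_path=None)
item.format_message(OpenAIMessageFormat())
print(item.formatted_message[0]["role"])  # 'user' (no system prompt was set)
```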
2 changes: 1 addition & 1 deletion src/eval/eval_models.py
@@ -75,7 +75,7 @@ def evaluate_model(
if use_reasoning:
batch = reasoning_engine.process_batch(batch)

formatted_messages = [[item.formatted_message] for item in batch]
formatted_messages = [item.formatted_message for item in batch]

batch_results = engine.predict_batch(formatted_messages)

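Because each `formatted_message` is already a list of messages, `eval_models` drops the extra wrapping list. A quick shape illustration with placeholder messages:

```python
# Placeholder messages; each item's formatted_message is itself a list.
formatted_message = [{"role": "user", "content": "Question: ..."}]
batch = [formatted_message]

old_shape = [[msgs] for msgs in batch]  # old: triple-nested
new_shape = [msgs for msgs in batch]    # new: a plain list of message lists
print(len(old_shape[0]), len(new_shape[0]))  # 1 1, but one level shallower
```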