4 changes: 4 additions & 0 deletions core/config.py
@@ -74,6 +74,10 @@ class Settings(BaseSettings):
DB_MAX_RETRIES: int = 3
DB_RETRY_DELAY: float = 1.0

# ColPali model configuration
COLPALI_MODEL_NAME: str = "vidore/colpali-v1.2-merged"


# Embedding configuration
EMBEDDING_PROVIDER: Literal["litellm"] = "litellm"
EMBEDDING_MODEL: str
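The new `COLPALI_MODEL_NAME` field above is a pydantic `BaseSettings` attribute, so it should be overridable the same way as the surrounding DB and embedding settings. A minimal sketch, assuming standard `BaseSettings` environment resolution and that `get_settings()` is first called after the override (neither assumption is confirmed by this diff):

```python
# Sketch only: reads the new setting; the env-var override path is an assumption
# based on typical pydantic BaseSettings behavior, not shown in this PR.
import os

os.environ["COLPALI_MODEL_NAME"] = "vidore/colSmol-256M"  # hypothetical override

from core.config import get_settings

settings = get_settings()
print(settings.COLPALI_MODEL_NAME)  # expected: "vidore/colSmol-256M" if env overrides apply
```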
108 changes: 71 additions & 37 deletions core/embedding/colpali_embedding_model.py
@@ -7,60 +7,104 @@

import numpy as np
import torch
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from PIL.Image import Image
from PIL.Image import open as open_image

# Import all three supported model architectures
from colpali_engine.models import (
ColQwen2_5,
ColQwen2_5_Processor,
ColIdefics3,
ColIdefics3Processor,
ColPali,
ColPaliProcessor
)

from core.config import get_settings
from core.embedding.base_embedding_model import BaseEmbeddingModel
from core.models.chunk import Chunk
from core.utils.fast_ops import data_uri_to_bytes

logger = logging.getLogger(__name__)


_INGEST_METRICS: ContextVar[Dict[str, Any]] = ContextVar("_colpali_ingest_metrics", default={})


class ColpaliEmbeddingModel(BaseEmbeddingModel):
def __init__(self):
self.settings = get_settings()
self.mode = self.settings.MODE

# 1. Determine Device
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Initializing ColpaliEmbeddingModel with device: {device}")
start_time = time.time()

# Enable TF32 for faster matmuls on Ampere+ GPUs (A10, A100, etc.)

# 2. Configure Attention (Flash Attn 2 Check)
attn_implementation = "eager"
if device == "cuda":
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
logger.info("Enabled TF32 for CUDA matmul operations")

attn_implementation = "eager"
if device == "cuda":
if importlib.util.find_spec("flash_attn") is not None:
attn_implementation = "flash_attention_2"
else:
logger.warning(
"flash_attn package not found; falling back to 'eager' attention. "
"Install flash-attn to enable FlashAttention2 on GPU."
)

# 3. Model Selector Logic
# Get model name from morphik.toml via settings, default to the standard ColPali v1.2
model_name = getattr(self.settings, 'COLPALI_MODEL_NAME', "vidore/colpali-v1.2-merged")

logger.info(f"Loading ColPali Model: {model_name}")

# Dynamic Loading for Smol, Qwen, AND PaliGemma

# CASE 1: SMOLVLM (Idefics3 Architecture)
# See: https://huggingface.co/vidore/colSmol-256M
if "smol" in model_name.lower() or "idefics" in model_name.lower():
logger.info("Detected SmolVLM/Idefics3 architecture.")
self.model = ColIdefics3.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
self.processor = ColIdefics3Processor.from_pretrained(model_name, use_fast=True)
# Smol is tiny (256M/500M), boost batch size!
self.batch_size = 32 if self.mode == "cloud" else 4

# CASE 2: QWEN (Qwen2/2.5-VL Architecture)
elif "qwen" in model_name.lower():
logger.info("Detected Qwen2/2.5-VL architecture.")
self.model = ColQwen2_5.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
self.processor = ColQwen2_5_Processor.from_pretrained(model_name, use_fast=True)
# Qwen is heavy (3B+), keep batch size conservative
self.batch_size = 8 if self.mode == "cloud" else 1

# CASE 3: COLPALI (PaliGemma Architecture - The Default)
else:
logger.info("Detected Standard ColPali (PaliGemma) architecture.")
self.model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
self.processor = ColPaliProcessor.from_pretrained(model_name, use_fast=True)
# PaliGemma is ~3B, keep batch size conservative
self.batch_size = 8 if self.mode == "cloud" else 1

self.model = ColQwen2_5.from_pretrained(
"tsystems/colqwen2.5-3b-multilingual-v1.0",
dtype=torch.bfloat16, # preferred kwarg per upstream deprecation notice
device_map=device, # Automatically detect and use available device
attn_implementation=attn_implementation,
).eval()
self.processor: ColQwen2_5_Processor = ColQwen2_5_Processor.from_pretrained(
"tsystems/colqwen2.5-3b-multilingual-v1.0",
use_fast=True,
)
self.settings = get_settings()
self.mode = self.settings.MODE
self.device = device
# Set batch size based on mode
self.batch_size = 8 if self.mode == "cloud" else 1
logger.info(f"Colpali running in mode: {self.mode} with batch size: {self.batch_size}")
total_init_time = time.time() - start_time
logger.info(f"Colpali running in mode: {self.mode} with batch size: {self.batch_size}")
logger.info(f"Colpali initialization time: {total_init_time:.2f} seconds")

async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[np.ndarray]:
@@ -95,7 +139,7 @@ async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[np.ndarray]:
image_items.append((index, image))
except Exception as e:
logger.error(f"Error processing image chunk {index}: {str(e)}. Falling back to text.")
text_items.append((index, chunk.content)) # Fallback: treat content as text
text_items.append((index, chunk.content))
else:
text_items.append((index, chunk.content))

@@ -127,7 +171,6 @@ async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[np.ndarray]:
image_model += batch_metrics["model"]
image_convert += batch_metrics["convert"]
image_total += batch_metrics["total"]
# Place embeddings in the correct position in results
for original_index, embedding in zip(batch_indices, batch_embeddings):
results[original_index] = embedding
batch_time = time.time() - batch_start
@@ -160,7 +203,6 @@ async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[np.ndarray]:
text_model += batch_metrics["model"]
text_convert += batch_metrics["convert"]
text_total += batch_metrics["total"]
# Place embeddings in the correct position in results
for original_index, embedding in zip(batch_indices, batch_embeddings):
results[original_index] = embedding
batch_time = time.time() - batch_start
@@ -174,27 +216,25 @@ async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[np.ndarray]:
text_process = text_model = text_convert = text_total = 0.0
text_time = 0.0

# Ensure all chunks were processed (handle potential None entries if errors occurred,
# though unlikely with fallback)
# Ensure all chunks were processed
final_results = [res for res in results if res is not None]
if len(final_results) != len(chunks):
logger.warning(
f"Number of embeddings ({len(final_results)}) does not match number of chunks "
f"({len(chunks)}). Some chunks might have failed."
)
# Fill potential gaps if necessary, though the current logic should cover all chunks
# For safety, let's reconstruct based on successfully processed indices, though it shouldn't be needed
processed_indices = {idx for idx, _ in image_items} | {idx for idx, _ in text_items}
if len(processed_indices) != len(chunks):
logger.error("Mismatch in processed indices vs original chunks count. This indicates a logic error.")
# Assuming results contains embeddings at correct original indices, filter out Nones
final_results = [results[i] for i in range(len(chunks)) if results[i] is not None]

total_time = time.time() - job_start_time
logger.info(
f"Total Colpali embed_for_ingestion took {total_time:.2f}s for {len(chunks)} chunks "
f"({total_time/len(chunks) if chunks else 0:.2f}s per chunk)"
)

# Collect and store metrics
metrics = {
"sorting": sorting_time,
"image_process": image_process,
@@ -214,7 +254,7 @@ async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[np.ndarray]:
"chunk_count": len(chunks),
}
_INGEST_METRICS.set(metrics)
# Cast is safe because we filter out Nones, though Nones shouldn't occur with the fallback logic

return final_results # type: ignore

def latest_ingest_metrics(self) -> Dict[str, Any]:
@@ -242,8 +282,6 @@ async def generate_embeddings(self, content: Union[str, Image]) -> np.ndarray:

model_start = time.time()

# inference_mode is faster than no_grad (disables version tracking)
# autocast ensures consistent bf16 inference on CUDA
with torch.inference_mode():
if self.device == "cuda":
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
Expand All @@ -254,9 +292,7 @@ async def generate_embeddings(self, content: Union[str, Image]) -> np.ndarray:
model_time = time.time() - model_start

convert_start = time.time()

result = embeddings.to(torch.float32).numpy(force=True)[0]

convert_time = time.time() - convert_start

total_time = time.time() - start_time
@@ -266,8 +302,6 @@ async def generate_embeddings(self, content: Union[str, Image]) -> np.ndarray:
)
return result

# ---- Batch processing methods (only used in 'cloud' mode) ----

async def generate_embeddings_batch_images(self, images: List[Image]) -> Tuple[List[np.ndarray], Dict[str, float]]:
batch_start_time = time.time()
process_start = time.time()
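The `__init__` above dispatches on substrings of the configured model name. A condensed sketch of that mapping, using the same `colpali_engine` classes the file imports; the helper name `select_colpali_classes` is illustrative and not part of the codebase, while the batch sizes are the values used in this diff:

```python
from colpali_engine.models import (
    ColIdefics3,
    ColIdefics3Processor,
    ColPali,
    ColPaliProcessor,
    ColQwen2_5,
    ColQwen2_5_Processor,
)


def select_colpali_classes(model_name: str, mode: str):
    """Illustrative helper mirroring the architecture dispatch in ColpaliEmbeddingModel.__init__."""
    name = model_name.lower()
    if "smol" in name or "idefics" in name:
        # SmolVLM / Idefics3: 256M-500M params, so the batch size is boosted
        return ColIdefics3, ColIdefics3Processor, (32 if mode == "cloud" else 4)
    if "qwen" in name:
        # Qwen2/2.5-VL backbone: ~3B params, conservative batch size
        return ColQwen2_5, ColQwen2_5_Processor, (8 if mode == "cloud" else 1)
    # Default: standard ColPali on PaliGemma, also ~3B
    return ColPali, ColPaliProcessor, (8 if mode == "cloud" else 1)


model_cls, processor_cls, batch_size = select_colpali_classes("vidore/colSmol-256M", "cloud")
# -> (ColIdefics3, ColIdefics3Processor, 32)
```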
15 changes: 14 additions & 1 deletion morphik.toml
@@ -140,7 +140,8 @@ arq_max_jobs = 1 # Maximum concurrent jobs for ARQ worker
colpali_store_batch_size = 16 # Batch size for ColPali vector storage

[pdf]
colpali_pdf_dpi = 150 # DPI for PDF to image conversion in ColPali processing
# colpali_pdf_dpi = 150 # DPI for PDF to image conversion in ColPali processing
colpali_pdf_dpi = 300 # Higher DPI keeps pages legible when using the smaller ColPali models

[morphik]
enable_colpali = true
@@ -151,6 +152,18 @@ api_domain = "api.morphik.ai" # API domain for cloud URIs
morphik_embedding_api_domain = ["http://localhost:6000"]
colpali_mode = "local" # "off", "local", or "api"

# --- ColPali Model Selection ---
# 1. SmolVLM (Lightweight & Fast - Uses ColIdefics3)
# - vidore/colSmol-256M
# - vidore/colSmol-500M
# 2. Qwen 2.5-VL (High Performance - Uses ColQwen2_5)
# - vidore/colqwen2.5-v0.1
# 3. Standard ColPali (PaliGemma - Uses ColPali)
# - vidore/colpali-v1.2-merged
# - vidore/colpali-v1.3
# Default if not set: "vidore/colpali-v1.2-merged"
colpali_model_name = "vidore/colpali-v1.2-merged"

[pdf_viewer]
frontend_url = "http://localhost:3000/api/pdf" # or "https://morphik.ai/api/pdf"

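Finally, a hedged end-to-end usage sketch of the query-side path once `colpali_model_name` is configured. The prompt string is illustrative; the multi-vector output shape follows from `generate_embeddings` returning the first batch element as a float32 array:

```python
# Usage sketch: exercises whichever backbone COLPALI_MODEL_NAME selects.
import asyncio

from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel


async def main() -> None:
    model = ColpaliEmbeddingModel()  # architecture chosen from COLPALI_MODEL_NAME
    query_embedding = await model.generate_embeddings("What is the invoice total?")
    # ColPali-style late-interaction output: one vector per query token
    print(query_embedding.shape)


asyncio.run(main())
```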