@contextlib.contextmanager
def image_path_ctx(
    image: Image.Image | None = None,
    url: str | None = None,
) -> Generator[Path, None, None]:
    """Materialize an image as a temporary JPEG file for the `with` body.

    Exactly one of `image` or `url` must be supplied. When `url` is given,
    the image is first fetched with `remote_image`. The temporary file is
    removed when the context exits, including when the body raises.

    Args:
        image: PIL Image object to write out.
        url: URL of the image to download and write out.

    Yields:
        Path: Location of the temporary ``.jpg`` file.

    Raises:
        ValueError: If neither or both of `image` and `url` are provided.
    """
    has_image = image is not None
    has_url = bool(url)
    if not has_url and not has_image:
        raise ValueError("Either `image` or `url` must be provided")
    if has_url and has_image:
        raise ValueError("Cannot provide both `image` and `url`")

    # Download the image from the URL if provided
    if has_url:
        image = remote_image(url)

    # The JPEG writer only accepts these modes; anything else (RGBA, P, LA,
    # ...) would make `save` raise, so normalize to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")

    # Reserve a unique name, then close the handle *before* writing: on
    # Windows a NamedTemporaryFile cannot be reopened by name while open.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
        temp_path = Path(temp_file.name)
    try:
        image.save(temp_path, format="JPEG", quality=98)
        yield temp_path
    finally:
        # delete=False above means cleanup is our responsibility.
        with contextlib.suppress(FileNotFoundError):
            temp_path.unlink()
with image_path_ctx(image=images[0]) as image_path: + return self._client.document.generate( + file=image_path, + model=model, + domain=domain, + batch=batch, + config=config, + metadata=metadata, + callback_url=callback_url, + autocast=autocast, + ) if images: # Check if all images are of the same type diff --git a/vlmrun/version.py b/vlmrun/version.py index 04d1c7c..198d6db 100644 --- a/vlmrun/version.py +++ b/vlmrun/version.py @@ -1 +1 @@ -__version__ = "0.2.19" +__version__ = "0.2.20"