Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,25 @@ def generate(self, domain: str, images=None, urls=None, **kwargs):
if images and urls:
raise ValueError("Only one of `images` or `urls` can be provided")

prediction = PredictionResponse(
id="prediction1",
status="completed",
created_at="2024-01-01T00:00:00+00:00",
completed_at="2024-01-01T00:00:01+00:00",
response={"invoice_number": "INV-001", "total_amount": 100.0},
usage=CreditUsage(credits_used=100),
)
batch = kwargs.get("batch", False)
if batch:
prediction = PredictionResponse(
id="prediction1",
status="pending",
created_at="2024-01-01T00:00:00+00:00",
completed_at=None,
response=None,
usage=CreditUsage(credits_used=0),
)
else:
prediction = PredictionResponse(
id="prediction1",
status="completed",
created_at="2024-01-01T00:00:00+00:00",
completed_at="2024-01-01T00:00:01+00:00",
response={"invoice_number": "INV-001", "total_amount": 100.0},
usage=CreditUsage(credits_used=100),
)

if kwargs.get("autocast", False):
self._cast_response_to_schema(
Expand Down
112 changes: 112 additions & 0 deletions tests/test_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,118 @@ def test_image_generate_with_url(mock_client):
assert response.id is not None


def test_image_generate_batch_with_file(mock_client, tmp_path):
"""Test batch image generation with a local file returns pending status."""
img_path = tmp_path / "test.jpg"
img = Image.new("RGB", (100, 100), color="blue")
img.save(img_path)

client = mock_client
response = client.image.generate(
domain="test-domain",
images=[img_path],
batch=True,
)
assert isinstance(response, PredictionResponse)
assert response.id is not None
assert response.status == "pending"
assert response.response is None
assert response.completed_at is None


def test_image_generate_batch_with_url(mock_client):
"""Test batch image generation with URLs returns pending status."""
client = mock_client
response = client.image.generate(
domain="test-domain",
urls=["https://example.com/image.jpg"],
batch=True,
)
assert isinstance(response, PredictionResponse)
assert response.id is not None
assert response.status == "pending"
assert response.response is None
assert response.completed_at is None


def test_image_generate_batch_with_multiple_images(mock_client, tmp_path):
"""Test batch image generation with multiple images."""
images = []
for i in range(3):
img_path = tmp_path / f"test_{i}.jpg"
img = Image.new("RGB", (100, 100), color="green")
img.save(img_path)
images.append(img_path)

client = mock_client
response = client.image.generate(
domain="test-domain",
images=images,
batch=True,
)
assert isinstance(response, PredictionResponse)
assert response.status == "pending"
assert response.response is None


def test_image_generate_batch_with_pil_images(mock_client):
"""Test batch image generation with PIL Image objects."""
images = [Image.new("RGB", (100, 100), color="red") for _ in range(2)]

client = mock_client
response = client.image.generate(
domain="test-domain",
images=images,
batch=True,
)
assert isinstance(response, PredictionResponse)
assert response.status == "pending"
assert response.response is None


def test_image_generate_non_batch_still_works(mock_client, tmp_path):
"""Test that non-batch image generation still returns completed status."""
img_path = tmp_path / "test.jpg"
img = Image.new("RGB", (100, 100), color="red")
img.save(img_path)

client = mock_client
response = client.image.generate(
domain="test-domain",
images=[img_path],
batch=False,
)
assert isinstance(response, PredictionResponse)
assert response.status == "completed"
assert response.response is not None
assert response.completed_at is not None


def test_image_generate_batch_then_wait(mock_client, tmp_path, monkeypatch):
"""Test batch image generation followed by polling until completion."""
img_path = tmp_path / "test.jpg"
img = Image.new("RGB", (100, 100), color="red")
img.save(img_path)

client = mock_client
response = client.image.generate(
domain="test-domain",
images=[img_path],
batch=True,
)
assert response.status == "pending"

def mock_sleep(seconds):
pass

monkeypatch.setattr("time.sleep", mock_sleep)

completed = client.predictions.wait(response.id, timeout=10)
assert isinstance(completed, PredictionResponse)
assert completed.status == "completed"
assert completed.response is not None


def test_image_generate_validation(mock_client):
"""Test validation of image generate parameters."""
client = mock_client
Expand Down
Loading