Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,7 @@ build-backend = "pdm.backend"

[tool.pylint."MESSAGES CONTROL"]
disable = ["R0801"]

[tool.ruff]
[tool.ruff.lint.flake8-tidy-imports]
banned-api = { "unittest" = { msg = "use pytest instead of unittest" }, "unittest.mock" = { msg = "use pytest-mock instead of unittest.mock" } }
15 changes: 7 additions & 8 deletions tests/unit/app/endpoints/test_conversations_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

"""Unit tests for the /conversations REST API endpoints."""

from unittest.mock import Mock
import pytest
from pytest_mock import MockerFixture
import pytest
from fastapi import HTTPException, status

from app.endpoints.conversations_v2 import (
Expand Down Expand Up @@ -123,10 +122,10 @@ def test_transform_message_with_empty_referenced_documents(self) -> None:


@pytest.fixture
def mock_configuration():
def mock_configuration(mocker: MockerFixture):
"""Mock configuration with conversation cache."""
mock_config = Mock()
mock_cache = Mock()
mock_config = mocker.Mock()
mock_cache = mocker.Mock()
mock_config.conversation_cache = mock_cache
return mock_config

Expand Down Expand Up @@ -157,7 +156,7 @@ class TestCheckConversationExistence:
def test_conversation_exists(self, mocker, mock_configuration):
"""Test when conversation exists."""
mock_configuration.conversation_cache.list.return_value = [
Mock(conversation_id=VALID_CONVERSATION_ID)
mocker.Mock(conversation_id=VALID_CONVERSATION_ID)
]
mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration)

Expand Down Expand Up @@ -221,7 +220,7 @@ async def test_invalid_conversation_id_format(
async def test_conversation_cache_not_configured(self, mocker: MockerFixture):
"""Test the endpoint when conversation cache is not configured."""
mock_authorization_resolvers(mocker)
mock_config = Mock()
mock_config = mocker.Mock()
mock_config.conversation_cache = None
mocker.patch("app.endpoints.conversations_v2.configuration", mock_config)
mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True)
Expand Down Expand Up @@ -269,7 +268,7 @@ async def test_successful_update(self, mocker: MockerFixture, mock_configuration
mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration)
mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True)
mock_configuration.conversation_cache.list.return_value = [
Mock(conversation_id=VALID_CONVERSATION_ID)
mocker.Mock(conversation_id=VALID_CONVERSATION_ID)
]

update_request = ConversationUpdateRequest(topic_summary="New topic summary")
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/app/endpoints/test_health.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Unit tests for the /health REST API endpoint."""

from unittest.mock import Mock
from pytest_mock import MockerFixture

import pytest
from pytest_mock import MockerFixture
from llama_stack.providers.datatypes import HealthStatus
from app.endpoints.health import (
readiness_probe_get_method,
Expand Down Expand Up @@ -32,7 +31,7 @@ async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker: MockerFi
]

# Mock the Response object and auth
mock_response = Mock()
mock_response = mocker.Mock()
auth = ("test_user", "token", {})

response = await readiness_probe_get_method(auth=auth, response=mock_response)
Expand Down Expand Up @@ -68,7 +67,7 @@ async def test_readiness_probe_success_when_all_providers_healthy(
]

# Mock the Response object and auth
mock_response = Mock()
mock_response = mocker.Mock()
auth = ("test_user", "token", {})

response = await readiness_probe_get_method(auth=auth, response=mock_response)
Expand Down
14 changes: 6 additions & 8 deletions tests/unit/app/endpoints/test_providers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Unit tests for the /providers REST API endpoints."""

from unittest.mock import AsyncMock

import pytest
from fastapi import HTTPException, Request, status
from llama_stack_client import APIConnectionError
Expand All @@ -27,7 +25,7 @@ async def test_providers_endpoint_configuration_not_loaded(mocker):
@pytest.mark.asyncio
async def test_providers_endpoint_connection_error(mocker):
"""Test that /providers endpoint raises HTTP 500 if Llama Stack connection fails."""
mock_client = AsyncMock()
mock_client = mocker.AsyncMock()
mock_client.providers.list.side_effect = APIConnectionError(request=None)
mocker.patch(
"app.endpoints.providers.AsyncLlamaStackClientHolder"
Expand Down Expand Up @@ -62,7 +60,7 @@ async def test_providers_endpoint_success(mocker):
"provider_type": "remote::huggingface",
},
]
mock_client = AsyncMock()
mock_client = mocker.AsyncMock()
mock_client.providers.list.return_value = provider_list
mocker.patch(
"app.endpoints.providers.AsyncLlamaStackClientHolder"
Expand All @@ -80,7 +78,7 @@ async def test_providers_endpoint_success(mocker):
@pytest.mark.asyncio
async def test_get_provider_not_found(mocker):
"""Test that /providers/{provider_id} endpoint raises HTTP 404 if the provider is not found."""
mock_client = AsyncMock()
mock_client = mocker.AsyncMock()
mock_client.providers.list.return_value = []
mocker.patch(
"app.endpoints.providers.AsyncLlamaStackClientHolder"
Expand All @@ -107,7 +105,7 @@ async def test_get_provider_success(mocker):
"config": {"api_key": "*****"},
"health": {"status": "OK", "message": "Healthy"},
}
mock_client = AsyncMock()
mock_client = mocker.AsyncMock()
mock_client.providers.list.return_value = [provider]
mocker.patch(
"app.endpoints.providers.AsyncLlamaStackClientHolder"
Expand All @@ -126,7 +124,7 @@ async def test_get_provider_success(mocker):
@pytest.mark.asyncio
async def test_get_provider_connection_error(mocker):
"""Test that /providers/{provider_id} raises HTTP 500 if Llama Stack connection fails."""
mock_client = AsyncMock()
mock_client = mocker.AsyncMock()
mock_client.providers.list.side_effect = APIConnectionError(request=None)
mocker.patch(
"app.endpoints.providers.AsyncLlamaStackClientHolder"
Expand All @@ -146,7 +144,7 @@ async def test_get_provider_connection_error(mocker):
@pytest.mark.asyncio
async def test_get_provider_unexpected_exception(mocker):
"""Test that /providers/{provider_id} endpoint raises HTTP 500 for unexpected exceptions."""
mock_client = AsyncMock()
mock_client = mocker.AsyncMock()
mock_client.providers.list.side_effect = Exception("boom")
mocker.patch(
"app.endpoints.providers.AsyncLlamaStackClientHolder"
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/models/requests/test_query_request.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Unit tests for QueryRequest model."""

from logging import Logger
from unittest.mock import Mock
from pytest_mock import MockerFixture

import pytest

Expand Down Expand Up @@ -134,11 +134,11 @@ def test_validate_provider_and_model(self) -> None:
):
QueryRequest(query="Tell me about Kubernetes", provider="OpenAI")

def test_validate_media_type(self, mocker) -> None:
def test_validate_media_type(self, mocker: MockerFixture) -> None:
"""Test the validate_media_type method."""

# Mock the logger
mock_logger = Mock(spec=Logger)
mock_logger = mocker.Mock(spec=Logger)
mocker.patch("models.requests.logger", mock_logger)

qr = QueryRequest(
Expand Down
122 changes: 61 additions & 61 deletions tests/unit/runners/test_uvicorn_runner.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,80 @@
"""Unit tests for runners."""

from pathlib import Path
from unittest.mock import patch
from pytest_mock import MockerFixture


from models.config import ServiceConfiguration, TLSConfiguration
from runners.uvicorn import start_uvicorn


def test_start_uvicorn() -> None:
def test_start_uvicorn(mocker: MockerFixture) -> None:
"""Test the function to start Uvicorn server using de-facto default configuration."""
configuration = ServiceConfiguration(host="localhost", port=8080, workers=1)

# don't start real Uvicorn server
with patch("uvicorn.run") as mocked_run:
start_uvicorn(configuration)
mocked_run.assert_called_once_with(
"app.main:app",
host="localhost",
port=8080,
workers=1,
log_level=20,
ssl_certfile=None,
ssl_keyfile=None,
ssl_keyfile_password="",
use_colors=True,
access_log=True,
)
mocked_run = mocker.patch("uvicorn.run")
start_uvicorn(configuration)
mocked_run.assert_called_once_with(
"app.main:app",
host="localhost",
port=8080,
workers=1,
log_level=20,
ssl_certfile=None,
ssl_keyfile=None,
ssl_keyfile_password="",
use_colors=True,
access_log=True,
)


def test_start_uvicorn_different_host_port() -> None:
def test_start_uvicorn_different_host_port(mocker: MockerFixture) -> None:
"""Test the function to start Uvicorn server using custom configuration."""
configuration = ServiceConfiguration(host="x.y.com", port=1234, workers=10)

# don't start real Uvicorn server
with patch("uvicorn.run") as mocked_run:
start_uvicorn(configuration)
mocked_run.assert_called_once_with(
"app.main:app",
host="x.y.com",
port=1234,
workers=10,
log_level=20,
ssl_certfile=None,
ssl_keyfile=None,
ssl_keyfile_password="",
use_colors=True,
access_log=True,
)
mocked_run = mocker.patch("uvicorn.run")
start_uvicorn(configuration)
mocked_run.assert_called_once_with(
"app.main:app",
host="x.y.com",
port=1234,
workers=10,
log_level=20,
ssl_certfile=None,
ssl_keyfile=None,
ssl_keyfile_password="",
use_colors=True,
access_log=True,
)


def test_start_uvicorn_empty_tls_configuration() -> None:
def test_start_uvicorn_empty_tls_configuration(mocker: MockerFixture) -> None:
"""Test the function to start Uvicorn server using empty TLS configuration."""
tls_config = TLSConfiguration()
configuration = ServiceConfiguration(
host="x.y.com", port=1234, workers=10, tls_config=tls_config
)

# don't start real Uvicorn server
with patch("uvicorn.run") as mocked_run:
start_uvicorn(configuration)
mocked_run.assert_called_once_with(
"app.main:app",
host="x.y.com",
port=1234,
workers=10,
log_level=20,
ssl_certfile=None,
ssl_keyfile=None,
ssl_keyfile_password="",
use_colors=True,
access_log=True,
)
mocked_run = mocker.patch("uvicorn.run")
start_uvicorn(configuration)
mocked_run.assert_called_once_with(
"app.main:app",
host="x.y.com",
port=1234,
workers=10,
log_level=20,
ssl_certfile=None,
ssl_keyfile=None,
ssl_keyfile_password="",
use_colors=True,
access_log=True,
)


def test_start_uvicorn_tls_configuration() -> None:
def test_start_uvicorn_tls_configuration(mocker: MockerFixture) -> None:
"""Test the function to start Uvicorn server using custom TLS configuration."""
tls_config = TLSConfiguration(
tls_certificate_path=Path("tests/configuration/server.crt"),
Expand All @@ -86,17 +86,17 @@ def test_start_uvicorn_tls_configuration() -> None:
)

# don't start real Uvicorn server
with patch("uvicorn.run") as mocked_run:
start_uvicorn(configuration)
mocked_run.assert_called_once_with(
"app.main:app",
host="x.y.com",
port=1234,
workers=10,
log_level=20,
ssl_certfile=Path("tests/configuration/server.crt"),
ssl_keyfile=Path("tests/configuration/server.key"),
ssl_keyfile_password="tests/configuration/password",
use_colors=True,
access_log=True,
)
mocked_run = mocker.patch("uvicorn.run")
start_uvicorn(configuration)
mocked_run.assert_called_once_with(
"app.main:app",
host="x.y.com",
port=1234,
workers=10,
log_level=20,
ssl_certfile=Path("tests/configuration/server.crt"),
ssl_keyfile=Path("tests/configuration/server.key"),
ssl_keyfile_password="tests/configuration/password",
use_colors=True,
access_log=True,
)
7 changes: 3 additions & 4 deletions tests/unit/utils/auth_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Helper functions for mocking authorization in tests."""

from unittest.mock import AsyncMock, Mock
from pytest_mock import MockerFixture

from models.config import Action
Expand All @@ -18,10 +17,10 @@ def mock_authorization_resolvers(mocker: MockerFixture) -> None:
mock_resolvers = mocker.patch(
"authorization.middleware.get_authorization_resolvers"
)
mock_role_resolver = AsyncMock()
mock_access_resolver = Mock()
mock_role_resolver = mocker.AsyncMock()
mock_access_resolver = mocker.Mock()
mock_role_resolver.resolve_roles.return_value = set()
mock_access_resolver.check_access.return_value = True
# get_actions should be synchronous, not async
mock_access_resolver.get_actions = Mock(return_value=set(Action))
mock_access_resolver.get_actions = mocker.Mock(return_value=set(Action))
mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver)
Loading
Loading