From 34a8644db175141c9595c0925b223486bcc9dbb4 Mon Sep 17 00:00:00 2001 From: Maxim Svistunov Date: Tue, 21 Oct 2025 21:37:01 +0200 Subject: [PATCH 1/2] Convert unittest mocking to pytest mocking --- .../app/endpoints/test_conversations_v2.py | 15 +-- tests/unit/app/endpoints/test_health.py | 7 +- tests/unit/app/endpoints/test_providers.py | 14 +- .../models/requests/test_query_request.py | 6 +- tests/unit/runners/test_uvicorn_runner.py | 122 +++++++++--------- tests/unit/utils/auth_helpers.py | 7 +- tests/unit/utils/test_checks.py | 26 ++-- tests/unit/utils/test_common.py | 41 +++--- tests/unit/utils/test_mcp_headers.py | 66 +++++----- tests/unit/utils/test_types.py | 14 +- 10 files changed, 162 insertions(+), 156 deletions(-) diff --git a/tests/unit/app/endpoints/test_conversations_v2.py b/tests/unit/app/endpoints/test_conversations_v2.py index 37fb92b8b..baafadb83 100644 --- a/tests/unit/app/endpoints/test_conversations_v2.py +++ b/tests/unit/app/endpoints/test_conversations_v2.py @@ -2,9 +2,8 @@ """Unit tests for the /conversations REST API endpoints.""" -from unittest.mock import Mock -import pytest from pytest_mock import MockerFixture +import pytest from fastapi import HTTPException, status from app.endpoints.conversations_v2 import ( @@ -123,10 +122,10 @@ def test_transform_message_with_empty_referenced_documents(self) -> None: @pytest.fixture -def mock_configuration(): +def mock_configuration(mocker: MockerFixture): """Mock configuration with conversation cache.""" - mock_config = Mock() - mock_cache = Mock() + mock_config = mocker.Mock() + mock_cache = mocker.Mock() mock_config.conversation_cache = mock_cache return mock_config @@ -157,7 +156,7 @@ class TestCheckConversationExistence: def test_conversation_exists(self, mocker, mock_configuration): """Test when conversation exists.""" mock_configuration.conversation_cache.list.return_value = [ - Mock(conversation_id=VALID_CONVERSATION_ID) + mocker.Mock(conversation_id=VALID_CONVERSATION_ID) ] mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) @@ -221,7 +220,7 @@ async def test_invalid_conversation_id_format( async def test_conversation_cache_not_configured(self, mocker: MockerFixture): """Test the endpoint when conversation cache is not configured.""" mock_authorization_resolvers(mocker) - mock_config = Mock() + mock_config = mocker.Mock() mock_config.conversation_cache = None mocker.patch("app.endpoints.conversations_v2.configuration", mock_config) mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) @@ -269,7 +268,7 @@ async def test_successful_update(self, mocker: MockerFixture, mock_configuration mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) mock_configuration.conversation_cache.list.return_value = [ - Mock(conversation_id=VALID_CONVERSATION_ID) + mocker.Mock(conversation_id=VALID_CONVERSATION_ID) ] update_request = ConversationUpdateRequest(topic_summary="New topic summary") diff --git a/tests/unit/app/endpoints/test_health.py b/tests/unit/app/endpoints/test_health.py index d3ec8f492..2be703bcf 100644 --- a/tests/unit/app/endpoints/test_health.py +++ b/tests/unit/app/endpoints/test_health.py @@ -1,9 +1,8 @@ """Unit tests for the /health REST API endpoint.""" -from unittest.mock import Mock +from pytest_mock import MockerFixture import pytest -from pytest_mock import MockerFixture from llama_stack.providers.datatypes import HealthStatus from app.endpoints.health import ( readiness_probe_get_method, @@ -32,7 +31,7 @@ async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker: MockerFi ] # Mock the Response object and auth - mock_response = Mock() + mock_response = mocker.Mock() auth = ("test_user", "token", {}) response = await readiness_probe_get_method(auth=auth, response=mock_response) @@ -68,7 +67,7 @@ async def test_readiness_probe_success_when_all_providers_healthy( ] # Mock the Response object and auth - mock_response = Mock() + mock_response = mocker.Mock() auth = ("test_user", "token", {}) response = await readiness_probe_get_method(auth=auth, response=mock_response) diff --git a/tests/unit/app/endpoints/test_providers.py b/tests/unit/app/endpoints/test_providers.py index 167400cc6..bff7bb1bc 100644 --- a/tests/unit/app/endpoints/test_providers.py +++ b/tests/unit/app/endpoints/test_providers.py @@ -1,7 +1,5 @@ """Unit tests for the /providers REST API endpoints.""" -from unittest.mock import AsyncMock - import pytest from fastapi import HTTPException, Request, status from llama_stack_client import APIConnectionError @@ -27,7 +25,7 @@ async def test_providers_endpoint_configuration_not_loaded(mocker): @pytest.mark.asyncio async def test_providers_endpoint_connection_error(mocker): """Test that /providers endpoint raises HTTP 500 if Llama Stack connection fails.""" - mock_client = AsyncMock() + mock_client = mocker.AsyncMock() mock_client.providers.list.side_effect = APIConnectionError(request=None) mocker.patch( "app.endpoints.providers.AsyncLlamaStackClientHolder" @@ -62,7 +60,7 @@ async def test_providers_endpoint_success(mocker): "provider_type": "remote::huggingface", }, ] - mock_client = AsyncMock() + mock_client = mocker.AsyncMock() mock_client.providers.list.return_value = provider_list mocker.patch( "app.endpoints.providers.AsyncLlamaStackClientHolder" @@ -80,7 +78,7 @@ async def test_providers_endpoint_success(mocker): @pytest.mark.asyncio async def test_get_provider_not_found(mocker): """Test that /providers/{provider_id} endpoint raises HTTP 404 if the provider is not found.""" - mock_client = AsyncMock() + mock_client = mocker.AsyncMock() mock_client.providers.list.return_value = [] mocker.patch( "app.endpoints.providers.AsyncLlamaStackClientHolder" @@ -107,7 +105,7 @@ async def test_get_provider_success(mocker): "config": {"api_key": "*****"}, "health": {"status": "OK", "message": "Healthy"}, } - mock_client = AsyncMock() + mock_client = mocker.AsyncMock() mock_client.providers.list.return_value = [provider] mocker.patch( "app.endpoints.providers.AsyncLlamaStackClientHolder" @@ -126,7 +124,7 @@ async def test_get_provider_success(mocker): @pytest.mark.asyncio async def test_get_provider_connection_error(mocker): """Test that /providers/{provider_id} raises HTTP 500 if Llama Stack connection fails.""" - mock_client = AsyncMock() + mock_client = mocker.AsyncMock() mock_client.providers.list.side_effect = APIConnectionError(request=None) mocker.patch( "app.endpoints.providers.AsyncLlamaStackClientHolder" @@ -146,7 +144,7 @@ async def test_get_provider_connection_error(mocker): @pytest.mark.asyncio async def test_get_provider_unexpected_exception(mocker): """Test that /providers/{provider_id} endpoint raises HTTP 500 for unexpected exceptions.""" - mock_client = AsyncMock() + mock_client = mocker.AsyncMock() mock_client.providers.list.side_effect = Exception("boom") mocker.patch( "app.endpoints.providers.AsyncLlamaStackClientHolder" diff --git a/tests/unit/models/requests/test_query_request.py b/tests/unit/models/requests/test_query_request.py index 551763370..119e3b36b 100644 --- a/tests/unit/models/requests/test_query_request.py +++ b/tests/unit/models/requests/test_query_request.py @@ -1,7 +1,7 @@ """Unit tests for QueryRequest model.""" from logging import Logger -from unittest.mock import Mock +from pytest_mock import MockerFixture import pytest @@ -134,11 +134,11 @@ def test_validate_provider_and_model(self) -> None: ): QueryRequest(query="Tell me about Kubernetes", provider="OpenAI") - def test_validate_media_type(self, mocker) -> None: + def test_validate_media_type(self, mocker: MockerFixture) -> None: """Test the validate_media_type method.""" # Mock the logger - mock_logger = Mock(spec=Logger) + mock_logger = mocker.Mock(spec=Logger) mocker.patch("models.requests.logger", mock_logger) qr = QueryRequest( diff --git a/tests/unit/runners/test_uvicorn_runner.py b/tests/unit/runners/test_uvicorn_runner.py index 1568fed7b..fc44811ed 100644 --- a/tests/unit/runners/test_uvicorn_runner.py +++ b/tests/unit/runners/test_uvicorn_runner.py @@ -1,56 +1,56 @@ """Unit tests for runners.""" from pathlib import Path -from unittest.mock import patch +from pytest_mock import MockerFixture from models.config import ServiceConfiguration, TLSConfiguration from runners.uvicorn import start_uvicorn -def test_start_uvicorn() -> None: +def test_start_uvicorn(mocker: MockerFixture) -> None: """Test the function to start Uvicorn server using de-facto default configuration.""" configuration = ServiceConfiguration(host="localhost", port=8080, workers=1) # don't start real Uvicorn server - with patch("uvicorn.run") as mocked_run: - start_uvicorn(configuration) - mocked_run.assert_called_once_with( - "app.main:app", - host="localhost", - port=8080, - workers=1, - log_level=20, - ssl_certfile=None, - ssl_keyfile=None, - ssl_keyfile_password="", - use_colors=True, - access_log=True, - ) + mocked_run = mocker.patch("uvicorn.run") + start_uvicorn(configuration) + mocked_run.assert_called_once_with( + "app.main:app", + host="localhost", + port=8080, + workers=1, + log_level=20, + ssl_certfile=None, + ssl_keyfile=None, + ssl_keyfile_password="", + use_colors=True, + access_log=True, + ) -def test_start_uvicorn_different_host_port() -> None: +def test_start_uvicorn_different_host_port(mocker: MockerFixture) -> None: """Test the function to start Uvicorn server using custom configuration.""" configuration = ServiceConfiguration(host="x.y.com", port=1234, workers=10) # don't start real Uvicorn server - with patch("uvicorn.run") as mocked_run: - start_uvicorn(configuration) - mocked_run.assert_called_once_with( - "app.main:app", - host="x.y.com", - port=1234, - workers=10, - log_level=20, - ssl_certfile=None, - ssl_keyfile=None, - ssl_keyfile_password="", - use_colors=True, - access_log=True, - ) + mocked_run = mocker.patch("uvicorn.run") + start_uvicorn(configuration) + mocked_run.assert_called_once_with( + "app.main:app", + host="x.y.com", + port=1234, + workers=10, + log_level=20, + ssl_certfile=None, + ssl_keyfile=None, + ssl_keyfile_password="", + use_colors=True, + access_log=True, + ) -def test_start_uvicorn_empty_tls_configuration() -> None: +def test_start_uvicorn_empty_tls_configuration(mocker: MockerFixture) -> None: """Test the function to start Uvicorn server using empty TLS configuration.""" tls_config = TLSConfiguration() configuration = ServiceConfiguration( @@ -58,23 +58,23 @@ def test_start_uvicorn_empty_tls_configuration() -> None: ) # don't start real Uvicorn server - with patch("uvicorn.run") as mocked_run: - start_uvicorn(configuration) - mocked_run.assert_called_once_with( - "app.main:app", - host="x.y.com", - port=1234, - workers=10, - log_level=20, - ssl_certfile=None, - ssl_keyfile=None, - ssl_keyfile_password="", - use_colors=True, - access_log=True, - ) + mocked_run = mocker.patch("uvicorn.run") + start_uvicorn(configuration) + mocked_run.assert_called_once_with( + "app.main:app", + host="x.y.com", + port=1234, + workers=10, + log_level=20, + ssl_certfile=None, + ssl_keyfile=None, + ssl_keyfile_password="", + use_colors=True, + access_log=True, + ) -def test_start_uvicorn_tls_configuration() -> None: +def test_start_uvicorn_tls_configuration(mocker: MockerFixture) -> None: """Test the function to start Uvicorn server using custom TLS configuration.""" tls_config = TLSConfiguration( tls_certificate_path=Path("tests/configuration/server.crt"), @@ -86,17 +86,17 @@ def test_start_uvicorn_tls_configuration() -> None: ) # don't start real Uvicorn server - with patch("uvicorn.run") as mocked_run: - start_uvicorn(configuration) - mocked_run.assert_called_once_with( - "app.main:app", - host="x.y.com", - port=1234, - workers=10, - log_level=20, - ssl_certfile=Path("tests/configuration/server.crt"), - ssl_keyfile=Path("tests/configuration/server.key"), - ssl_keyfile_password="tests/configuration/password", - use_colors=True, - access_log=True, - ) + mocked_run = mocker.patch("uvicorn.run") + start_uvicorn(configuration) + mocked_run.assert_called_once_with( + "app.main:app", + host="x.y.com", + port=1234, + workers=10, + log_level=20, + ssl_certfile=Path("tests/configuration/server.crt"), + ssl_keyfile=Path("tests/configuration/server.key"), + ssl_keyfile_password="tests/configuration/password", + use_colors=True, + access_log=True, + ) diff --git a/tests/unit/utils/auth_helpers.py b/tests/unit/utils/auth_helpers.py index c569b970e..a675058c0 100644 --- a/tests/unit/utils/auth_helpers.py +++ b/tests/unit/utils/auth_helpers.py @@ -1,6 +1,5 @@ """Helper functions for mocking authorization in tests.""" -from unittest.mock import AsyncMock, Mock from pytest_mock import MockerFixture from models.config import Action @@ -18,10 +17,10 @@ def mock_authorization_resolvers(mocker: MockerFixture) -> None: mock_resolvers = mocker.patch( "authorization.middleware.get_authorization_resolvers" ) - mock_role_resolver = AsyncMock() - mock_access_resolver = Mock() + mock_role_resolver = mocker.AsyncMock() + mock_access_resolver = mocker.Mock() mock_role_resolver.resolve_roles.return_value = set() mock_access_resolver.check_access.return_value = True # get_actions should be synchronous, not async - mock_access_resolver.get_actions = Mock(return_value=set(Action)) + mock_access_resolver.get_actions = mocker.Mock(return_value=set(Action)) mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver) diff --git a/tests/unit/utils/test_checks.py b/tests/unit/utils/test_checks.py index dc450717a..deb227e55 100644 --- a/tests/unit/utils/test_checks.py +++ b/tests/unit/utils/test_checks.py @@ -3,10 +3,10 @@ import os from pathlib import Path from types import ModuleType -from unittest.mock import patch - from typing import Any +from pytest_mock import MockerFixture + import pytest from utils import checks @@ -84,11 +84,11 @@ def test_file_check_non_existing_file() -> None: checks.file_check(Path("does-not-exists"), "description") -def test_file_check_not_readable_file(input_file: str) -> None: +def test_file_check_not_readable_file(mocker: MockerFixture, input_file: str) -> None: """Test the function file_check for not readable file.""" - with patch("os.access", return_value=False): - with pytest.raises(checks.InvalidConfigurationError): - checks.file_check(input_file, "description") + mocker.patch("os.access", return_value=False) + with pytest.raises(checks.InvalidConfigurationError): + checks.file_check(input_file, "description") def test_directory_check_non_existing_directory() -> None: @@ -120,13 +120,15 @@ def test_directory_check_non_a_directory(input_file: str) -> None: ) -def test_directory_check_existing_non_writable_directory(input_directory: str) -> None: +def test_directory_check_existing_non_writable_directory( + mocker: MockerFixture, input_directory: str +) -> None: """Test the function directory_check checks directory.""" - with patch("os.access", return_value=False): - with pytest.raises(checks.InvalidConfigurationError): - checks.directory_check( - input_directory, must_exists=True, must_be_writable=True, desc="foobar" - ) + mocker.patch("os.access", return_value=False) + with pytest.raises(checks.InvalidConfigurationError): + checks.directory_check( + input_directory, must_exists=True, must_be_writable=True, desc="foobar" + ) def test_import_python_module_success() -> None: diff --git a/tests/unit/utils/test_common.py b/tests/unit/utils/test_common.py index 353b8a18b..95fea4706 100644 --- a/tests/unit/utils/test_common.py +++ b/tests/unit/utils/test_common.py @@ -1,7 +1,7 @@ """Test module for utils/common.py.""" -from unittest.mock import Mock, AsyncMock from logging import Logger + from pytest_mock import MockerFixture import pytest @@ -21,8 +21,7 @@ @pytest.mark.asyncio async def test_register_mcp_servers_empty_list(mocker: MockerFixture) -> None: """Test register_mcp_servers with empty MCP servers list.""" - # Mock the logger - mock_logger = Mock(spec=Logger) + mock_logger = mocker.Mock(spec=Logger) # Mock the LlamaStack client (shouldn't be called since no MCP servers) mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") @@ -55,13 +54,13 @@ async def test_register_mcp_servers_single_server_not_registered( ) -> None: """Test register_mcp_servers with single MCP server that is not yet registered.""" # Mock the logger - mock_logger = Mock(spec=Logger) + mock_logger = mocker.Mock(spec=Logger) # Mock the LlamaStack client - mock_client = AsyncMock() + mock_client = mocker.AsyncMock() mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client - mock_tool = Mock() + mock_tool = mocker.Mock() mock_tool.provider_resource_id = "existing-server" mock_client.toolgroups.list.return_value = [mock_tool] mock_client.toolgroups.register.return_value = None @@ -102,11 +101,11 @@ async def test_register_mcp_servers_single_server_already_registered( ) -> None: """Test register_mcp_servers with single MCP server that is already registered.""" # Mock the logger - mock_logger = Mock(spec=Logger) + mock_logger = mocker.Mock(spec=Logger) # Mock the LlamaStack client - mock_client = AsyncMock() - mock_tool = Mock() + mock_client = mocker.AsyncMock() + mock_tool = mocker.Mock() mock_tool.provider_resource_id = "existing-server" mock_client.toolgroups.list.return_value = [mock_tool] mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") @@ -142,15 +141,15 @@ async def test_register_mcp_servers_multiple_servers_mixed_registration( ) -> None: """Test register_mcp_servers with multiple MCP servers - some registered, some not.""" # Mock the logger - mock_logger = Mock(spec=Logger) + mock_logger = mocker.Mock(spec=Logger) # Mock the LlamaStack client - mock_client = AsyncMock() + mock_client = mocker.AsyncMock() mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_client - mock_tool1 = Mock() + mock_tool1 = mocker.Mock() mock_tool1.provider_resource_id = "existing-server" - mock_tool2 = Mock() + mock_tool2 = mocker.Mock() mock_tool2.provider_resource_id = "another-existing" mock_client.toolgroups.list.return_value = [mock_tool1, mock_tool2] mock_client.toolgroups.register.return_value = None @@ -204,10 +203,10 @@ async def test_register_mcp_servers_multiple_servers_mixed_registration( async def test_register_mcp_servers_with_custom_provider(mocker: MockerFixture) -> None: """Test register_mcp_servers with MCP server using custom provider.""" # Mock the logger - mock_logger = Mock(spec=Logger) + mock_logger = mocker.Mock(spec=Logger) # Mock the LlamaStack client - mock_client = AsyncMock() + mock_client = mocker.AsyncMock() mock_client.toolgroups.list.return_value = [] mock_client.toolgroups.register.return_value = None mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") @@ -255,19 +254,19 @@ async def test_register_mcp_servers_async_with_library_client( library client. """ # Mock the logger - mock_logger = Mock(spec=Logger) + mock_logger = mocker.Mock(spec=Logger) # Mock the LlamaStackAsLibraryClient - mock_async_client = AsyncMock() - mock_async_client.initialize = AsyncMock() + mock_async_client = mocker.AsyncMock() + mock_async_client.initialize = mocker.AsyncMock() mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") mock_lsc.return_value = mock_async_client # Mock tools.list to return empty list - mock_tool = Mock() + mock_tool = mocker.Mock() mock_tool.provider_resource_id = "existing-tool" - mock_async_client.toolgroups.list = AsyncMock(return_value=[mock_tool]) - mock_async_client.toolgroups.register = AsyncMock() + mock_async_client.toolgroups.list = mocker.AsyncMock(return_value=[mock_tool]) + mock_async_client.toolgroups.register = mocker.AsyncMock() # Create configuration with library client enabled mcp_server = ModelContextProtocolServer( diff --git a/tests/unit/utils/test_mcp_headers.py b/tests/unit/utils/test_mcp_headers.py index a5415197a..943f692fe 100644 --- a/tests/unit/utils/test_mcp_headers.py +++ b/tests/unit/utils/test_mcp_headers.py @@ -1,6 +1,6 @@ """Unit tests for MCP headers utility functions.""" -from unittest.mock import Mock +from pytest_mock import MockerFixture import pytest from fastapi import Request @@ -8,9 +8,9 @@ from utils import mcp_headers -def test_extract_mcp_headers_empty_headers() -> None: +def test_extract_mcp_headers_empty_headers(mocker: MockerFixture) -> None: """Test the extract_mcp_headers function for request without any headers.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # no headers request.headers = {} @@ -18,9 +18,9 @@ def test_extract_mcp_headers_empty_headers() -> None: assert result == {} -def test_extract_mcp_headers_mcp_headers_empty() -> None: +def test_extract_mcp_headers_mcp_headers_empty(mocker: MockerFixture) -> None: """Test the extract_mcp_headers function for request with empty MCP-HEADERS header.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # empty MCP-HEADERS request.headers = {"MCP-HEADERS": ""} @@ -29,9 +29,9 @@ def test_extract_mcp_headers_mcp_headers_empty() -> None: assert result == {} -def test_extract_mcp_headers_valid_mcp_header() -> None: +def test_extract_mcp_headers_valid_mcp_header(mocker: MockerFixture) -> None: """Test the extract_mcp_headers function for request with valid MCP-HEADERS header.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # valid MCP-HEADERS request.headers = {"MCP-HEADERS": '{"http://www.redhat.com": {"auth": "token123"}}'} @@ -41,9 +41,9 @@ def test_extract_mcp_headers_valid_mcp_header() -> None: assert result == expected -def test_extract_mcp_headers_valid_mcp_headers() -> None: +def test_extract_mcp_headers_valid_mcp_headers(mocker: MockerFixture) -> None: """Test the extract_mcp_headers function for request with valid MCP-HEADERS headers.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # valid MCP-HEADERS header1 = '"http://www.redhat.com": {"auth": "token123"}' header2 = '"http://www.example.com": {"auth": "tokenXYZ"}' @@ -59,9 +59,9 @@ def test_extract_mcp_headers_valid_mcp_headers() -> None: assert result == expected -def test_extract_mcp_headers_invalid_json_mcp_header() -> None: +def test_extract_mcp_headers_invalid_json_mcp_header(mocker: MockerFixture) -> None: """Test the extract_mcp_headers function for request with invalid MCP-HEADERS header.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # invalid MCP-HEADERS - not a JSON request.headers = {"MCP-HEADERS": "this-is-invalid"} @@ -70,9 +70,9 @@ def test_extract_mcp_headers_invalid_json_mcp_header() -> None: assert result == {} -def test_extract_mcp_headers_invalid_mcp_header_type() -> None: +def test_extract_mcp_headers_invalid_mcp_header_type(mocker: MockerFixture) -> None: """Test the extract_mcp_headers function for request with invalid MCP-HEADERS header type.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # invalid MCP-HEADERS - not a dict request.headers = {"MCP-HEADERS": "[]"} @@ -81,9 +81,11 @@ def test_extract_mcp_headers_invalid_mcp_header_type() -> None: assert result == {} -def test_extract_mcp_headers_invalid_mcp_header_null_value() -> None: +def test_extract_mcp_headers_invalid_mcp_header_null_value( + mocker: MockerFixture, +) -> None: """Test the extract_mcp_headers function for request with invalid MCP-HEADERS header type.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # invalid MCP-HEADERS - not a dict request.headers = {"MCP-HEADERS": "null"} @@ -93,9 +95,9 @@ def test_extract_mcp_headers_invalid_mcp_header_null_value() -> None: @pytest.mark.asyncio -async def test_mcp_headers_dependency_empty_headers() -> None: +async def test_mcp_headers_dependency_empty_headers(mocker: MockerFixture) -> None: """Test the mcp_headers_dependency function for request with empty MCP-HEADERS header.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # empty MCP-HEADERS request.headers = {"MCP-HEADERS": ""} @@ -105,9 +107,9 @@ async def test_mcp_headers_dependency_empty_headers() -> None: @pytest.mark.asyncio -async def test_mcp_headers_dependency_mcp_headers_empty() -> None: +async def test_mcp_headers_dependency_mcp_headers_empty(mocker: MockerFixture) -> None: """Test the mcp_headers_dependency function for request with empty MCP-HEADERS header.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # empty MCP-HEADERS request.headers = {"MCP-HEADERS": ""} @@ -117,9 +119,9 @@ async def test_mcp_headers_dependency_mcp_headers_empty() -> None: @pytest.mark.asyncio -async def test_mcp_headers_dependency_valid_mcp_header() -> None: +async def test_mcp_headers_dependency_valid_mcp_header(mocker: MockerFixture) -> None: """Test the mcp_headers_dependency function for request with valid MCP-HEADERS header.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # valid MCP-HEADERS request.headers = {"MCP-HEADERS": '{"http://www.redhat.com": {"auth": "token123"}}'} @@ -130,9 +132,9 @@ async def test_mcp_headers_dependency_valid_mcp_header() -> None: @pytest.mark.asyncio -async def test_mcp_headers_dependency_valid_mcp_headers() -> None: +async def test_mcp_headers_dependency_valid_mcp_headers(mocker: MockerFixture) -> None: """Test the mcp_headers_dependency function for request with valid MCP-HEADERS headers.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # valid MCP-HEADERS header1 = '"http://www.redhat.com": {"auth": "token123"}' header2 = '"http://www.example.com": {"auth": "tokenXYZ"}' @@ -149,9 +151,11 @@ async def test_mcp_headers_dependency_valid_mcp_headers() -> None: @pytest.mark.asyncio -async def test_mcp_headers_dependency_invalid_json_mcp_header() -> None: +async def test_mcp_headers_dependency_invalid_json_mcp_header( + mocker: MockerFixture, +) -> None: """Test the mcp_headers_dependency function for request with invalid MCP-HEADERS header.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # invalid MCP-HEADERS - not a JSON request.headers = {"MCP-HEADERS": "this-is-invalid"} @@ -161,9 +165,11 @@ async def test_mcp_headers_dependency_invalid_json_mcp_header() -> None: @pytest.mark.asyncio -async def test_mcp_headers_dependency_invalid_mcp_header_type() -> None: +async def test_mcp_headers_dependency_invalid_mcp_header_type( + mocker: MockerFixture, +) -> None: """Test the mcp_headers_dependency function for request with invalid MCP-HEADERS header type.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # invalid MCP-HEADERS - not a dict request.headers = {"MCP-HEADERS": "[]"} @@ -173,9 +179,11 @@ async def test_mcp_headers_dependency_invalid_mcp_header_type() -> None: @pytest.mark.asyncio -async def test_mcp_headers_dependency_invalid_mcp_header_null_value() -> None: +async def test_mcp_headers_dependency_invalid_mcp_header_null_value( + mocker: MockerFixture, +) -> None: """Test the mcp_headers_dependency function for request with invalid MCP-HEADERS header type.""" - request = Mock(spec=Request) + request = mocker.Mock(spec=Request) # invalid MCP-HEADERS - not a dict request.headers = {"MCP-HEADERS": "null"} diff --git a/tests/unit/utils/test_types.py b/tests/unit/utils/test_types.py index 7ea7081ee..a2429baa5 100644 --- a/tests/unit/utils/test_types.py +++ b/tests/unit/utils/test_types.py @@ -1,6 +1,6 @@ """Unit tests for functions defined in utils/types.py.""" -from unittest.mock import Mock +from pytest_mock import MockerFixture from utils.types import GraniteToolParser @@ -31,24 +31,26 @@ def test_get_tool_calls_from_completion_message_when_none(self) -> None: assert tool_parser is not None, "tool parser was not returned" assert tool_parser.get_tool_calls(None) == [], "get_tool_calls should return []" - def test_get_tool_calls_from_completion_message_when_not_none(self) -> None: + def test_get_tool_calls_from_completion_message_when_not_none( + self, mocker: MockerFixture + ) -> None: """Test that get_tool_calls returns an empty array when CompletionMessage has no tool_calls.""" # pylint: disable=line-too-long tool_parser = GraniteToolParser.get_parser("granite-3.3-8b-instruct") assert tool_parser is not None, "tool parser was not returned" - completion_message = Mock() + completion_message = mocker.Mock() completion_message.tool_calls = [] assert not tool_parser.get_tool_calls( completion_message ), "get_tool_calls should return []" def test_get_tool_calls_from_completion_message_when_message_has_tool_calls( - self, + self, mocker: MockerFixture ) -> None: """Test that get_tool_calls returns the tool_calls when CompletionMessage has tool_calls.""" tool_parser = GraniteToolParser.get_parser("granite-3.3-8b-instruct") assert tool_parser is not None, "tool parser was not returned" - completion_message = Mock() - tool_calls = [Mock(tool_name="tool-1"), Mock(tool_name="tool-2")] + completion_message = mocker.Mock() + tool_calls = [mocker.Mock(tool_name="tool-1"), mocker.Mock(tool_name="tool-2")] completion_message.tool_calls = tool_calls assert ( tool_parser.get_tool_calls(completion_message) == tool_calls From aa7b22b384b4756fe30f767e21675dbe4212d49f Mon Sep 17 00:00:00 2001 From: Maxim Svistunov Date: Tue, 21 Oct 2025 22:39:35 +0200 Subject: [PATCH 2/2] Disallow importing unittest using ruff --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4d4d68fbc..1c4d4b934 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,3 +185,7 @@ build-backend = "pdm.backend" [tool.pylint."MESSAGES CONTROL"] disable = ["R0801"] + +[tool.ruff] +[tool.ruff.lint.flake8-tidy-imports] +banned-api = { "unittest" = { msg = "use pytest instead of unittest" }, "unittest.mock" = { msg = "use pytest-mock instead of unittest.mock" } }