diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index d7629a02e..60539b43e 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -175,12 +175,26 @@ def select_model_and_provider_id(
     models: ModelListResponse, query_request: QueryRequest
 ) -> tuple[str, str | None]:
     """Select the model ID and provider ID based on the request or available models."""
+    # If model_id and provider_id are provided in the request, use them
    model_id = query_request.model
     provider_id = query_request.provider
 
-    # TODO(lucasagomes): support default model selection via configuration
-    if not model_id:
-        logger.info("No model specified in request, using the first available LLM")
+    # If the model or provider is missing from the request, check the configuration
+    if not model_id or not provider_id:
+        logger.debug(
+            "No model ID or provider ID specified in request, checking configuration"
+        )
+        model_id = configuration.inference.default_model  # type: ignore[reportAttributeAccessIssue]
+        provider_id = (
+            configuration.inference.default_provider  # type: ignore[reportAttributeAccessIssue]
+        )
+
+    # If no model is specified in the request or configuration, use the first available LLM
+    if not model_id or not provider_id:
+        logger.debug(
+            "No model ID or provider ID specified in request or configuration, "
+            "using the first available LLM"
+        )
         try:
             model = next(
                 m
@@ -202,7 +216,8 @@ def select_model_and_provider_id(
             },
         ) from e
 
-    logger.info("Searching for model: %s, provider: %s", model_id, provider_id)
+    # Validate that the model_id and provider_id are in the available models
+    logger.debug("Searching for model: %s, provider: %s", model_id, provider_id)
     if not any(
         m.identifier == model_id and m.provider_id == provider_id for m in models
     ):
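The change replaces the old "first available LLM" shortcut with a three-step precedence: explicit request values win, then the configured defaults, then the first LLM in the model list. A minimal standalone sketch of that precedence (a hypothetical helper for illustration, not the code above; `models` stands for the Llama Stack model list):

```python
def resolve_model_and_provider(query_request, configuration, models):
    """Sketch of the selection precedence; names mirror the diff above."""
    # 1. Explicit values from the request win.
    model_id = query_request.model
    provider_id = query_request.provider
    # 2. Otherwise fall back to the configured defaults.
    if not model_id or not provider_id:
        model_id = configuration.inference.default_model
        provider_id = configuration.inference.default_provider
    # 3. Otherwise fall back to the first available LLM.
    if not model_id or not provider_id:
        first_llm = next(m for m in models if m.model_type == "llm")
        model_id, provider_id = first_llm.identifier, first_llm.provider_id
    return model_id, provider_id
```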
diff --git a/src/configuration.py b/src/configuration.py
index 21eca543e..d374840b3 100644
--- a/src/configuration.py
+++ b/src/configuration.py
@@ -12,6 +12,7 @@
     ServiceConfiguration,
     ModelContextProtocolServer,
     AuthenticationConfiguration,
+    InferenceConfiguration,
 )
 
 logger = logging.getLogger(__name__)
@@ -99,5 +100,13 @@ def customization(self) -> Optional[Customization]:
         ), "logic error: configuration is not loaded"
         return self._configuration.customization
 
+    @property
+    def inference(self) -> Optional[InferenceConfiguration]:
+        """Return inference configuration."""
+        assert (
+            self._configuration is not None
+        ), "logic error: configuration is not loaded"
+        return self._configuration.inference
+
 
 configuration: AppConfig = AppConfig()
diff --git a/src/metrics/utils.py b/src/metrics/utils.py
index 62f94bb04..29e2bcce4 100644
--- a/src/metrics/utils.py
+++ b/src/metrics/utils.py
@@ -1,5 +1,6 @@
 """Utility functions for metrics handling."""
 
+from configuration import configuration
 from client import LlamaStackClientHolder
 from log import get_logger
 import metrics
@@ -7,9 +8,6 @@
 
 logger = get_logger(__name__)
 
-# TODO(lucasagomes): Change this metric once we are allowed to set the the
-# default model/provider via the configuration.The default provider/model
-# will be set to 1, and the rest will be set to 0.
 def setup_model_metrics() -> None:
     """Perform setup of all metrics related to LLM model and provider."""
     client = LlamaStackClientHolder().get_client()
@@ -19,14 +17,29 @@ def setup_model_metrics() -> None:
         if model.model_type == "llm"  # pyright: ignore[reportAttributeAccessIssue]
     ]
 
+    default_model_label = (
+        configuration.inference.default_provider,  # type: ignore[reportAttributeAccessIssue]
+        configuration.inference.default_model,  # type: ignore[reportAttributeAccessIssue]
+    )
+
     for model in models:
         provider = model.provider_id
         model_name = model.identifier
         if provider and model_name:
+            # If the model/provider combination is the default, set the metric value to 1
+            # Otherwise, set it to 0
+            default_model_value = 0
             label_key = (provider, model_name)
-            metrics.provider_model_configuration.labels(*label_key).set(1)
+            if label_key == default_model_label:
+                default_model_value = 1
+
+            # Set the metric for the provider/model configuration
+            metrics.provider_model_configuration.labels(*label_key).set(
+                default_model_value
+            )
             logger.debug(
-                "Set provider/model configuration for %s/%s to 1",
+                "Set provider/model configuration for %s/%s to %d",
                 provider,
                 model_name,
+                default_model_value,
             )
diff --git a/src/models/config.py b/src/models/config.py
index 90a532757..fec40575e 100644
--- a/src/models/config.py
+++ b/src/models/config.py
@@ -185,6 +185,26 @@ def check_customization_model(self) -> Self:
         return self
 
 
+class InferenceConfiguration(BaseModel):
+    """Inference configuration."""
+
+    default_model: Optional[str] = None
+    default_provider: Optional[str] = None
+
+    @model_validator(mode="after")
+    def check_default_model_and_provider(self) -> Self:
+        """Check default model and provider."""
+        if self.default_model is None and self.default_provider is not None:
+            raise ValueError(
+                "Default model must be specified when default provider is set"
+            )
+        if self.default_model is not None and self.default_provider is None:
+            raise ValueError(
+                "Default provider must be specified when default model is set"
+            )
+        return self
+
+
 class Configuration(BaseModel):
     """Global service configuration."""
 
@@ -197,6 +217,7 @@ class Configuration(BaseModel):
         AuthenticationConfiguration()
     )
     customization: Optional[Customization] = None
+    inference: Optional[InferenceConfiguration] = InferenceConfiguration()
 
     def dump(self, filename: str = "configuration.json") -> None:
         """Dump actual configuration into JSON file."""
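The validator makes `default_model` and `default_provider` an all-or-nothing pair. A quick illustration of the rule (a sketch, assuming `models.config` is importable as in the diff; Pydantic's `ValidationError` is a `ValueError` subclass, which is why the tests below can match on `ValueError`):

```python
from models.config import InferenceConfiguration

InferenceConfiguration()  # both unset: valid, defaults simply disabled
InferenceConfiguration(default_provider="p1", default_model="m1")  # both set: valid

try:
    InferenceConfiguration(default_provider="p1")  # only one set: rejected
except ValueError as err:
    # message contains "Default model must be specified when default provider is set"
    print(err)
```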
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index 8ace106ab..7e824e995 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -179,30 +179,70 @@ def test_query_endpoint_handler_store_transcript(mocker):
     _test_query_endpoint_handler(mocker, store_transcript_to_file=True)
 
 
-def test_select_model_and_provider_id(mocker):
+def test_select_model_and_provider_id_from_request(mocker):
     """Test the select_model_and_provider_id function."""
-    mock_client = mocker.Mock()
-    mock_client.models.list.return_value = [
+    mocker.patch(
+        "metrics.utils.configuration.inference.default_provider",
+        "default_provider",
+    )
+    mocker.patch(
+        "metrics.utils.configuration.inference.default_model",
+        "default_model",
+    )
+
+    model_list = [
         mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
         mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"),
+        mocker.Mock(
+            identifier="default_model", model_type="llm", provider_id="default_provider"
+        ),
     ]
 
+    # Create a query request with model and provider specified
     query_request = QueryRequest(
-        query="What is OpenStack?", model="model1", provider="provider1"
+        query="What is OpenStack?", model="model2", provider="provider2"
     )
 
-    model_id, provider_id = select_model_and_provider_id(
-        mock_client.models.list(), query_request
+    # Assert the model and provider from the request take precedence over the configured ones
+    model_id, provider_id = select_model_and_provider_id(model_list, query_request)
+
+    assert model_id == "model2"
+    assert provider_id == "provider2"
+
+
+def test_select_model_and_provider_id_from_configuration(mocker):
+    """Test select_model_and_provider_id with the defaults from the configuration."""
+    mocker.patch(
+        "metrics.utils.configuration.inference.default_provider",
+        "default_provider",
+    )
+    mocker.patch(
+        "metrics.utils.configuration.inference.default_model",
+        "default_model",
     )
 
-    assert model_id == "model1"
-    assert provider_id == "provider1"
+    model_list = [
+        mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
+        mocker.Mock(
+            identifier="default_model", model_type="llm", provider_id="default_provider"
+        ),
+    ]
+
+    # Create a query request without model and provider specified
+    query_request = QueryRequest(
+        query="What is OpenStack?",
+    )
+
+    model_id, provider_id = select_model_and_provider_id(model_list, query_request)
+
+    # Assert that the default model and provider from the configuration are returned
+    assert model_id == "default_model"
+    assert provider_id == "default_provider"
 
 
-def test_select_model_and_provider_id_no_model(mocker):
+def test_select_model_and_provider_id_first_from_list(mocker):
     """Test the select_model_and_provider_id function when no model is specified."""
-    mock_client = mocker.Mock()
-    mock_client.models.list.return_value = [
+    model_list = [
         mocker.Mock(
             identifier="not_llm_type", model_type="embedding", provider_id="provider1"
         ),
@@ -216,11 +256,10 @@
 
     query_request = QueryRequest(query="What is OpenStack?")
 
-    model_id, provider_id = select_model_and_provider_id(
-        mock_client.models.list(), query_request
-    )
+    model_id, provider_id = select_model_and_provider_id(model_list, query_request)
 
-    # Assert return the first available LLM model
+    # Assert the first available LLM model is returned when no model/provider is
+    # specified in the request or in the configuration
     assert model_id == "first_model"
     assert provider_id == "provider1"
default + model_1 = mocker.Mock( + provider_id="test_provider-1", + identifier="test_model-1", + model_type="llm", + ) + # Mock a model that is not an LLM type, should be ignored + not_llm_model = mocker.Mock( + provider_id="not-llm-provider", + identifier="not-llm-model", + model_type="not-llm", + ) + + # Mock the list of models returned by the client + mock_client.models.list.return_value = [ + model_0, + model_default, + not_llm_model, + model_1, + ] setup_model_metrics() - # Assert that the metric was set correctly - mock_metric.labels("test_provider", "test_model").set.assert_called_once_with(1) + # Check that the provider_model_configuration metric was set correctly + # The default model should have a value of 1, others should be 0 + assert mock_metric.labels.call_count == 3 + mock_metric.assert_has_calls( + [ + mocker.call.labels("test_provider-0", "test_model-0"), + mocker.call.labels().set(0), + mocker.call.labels("default_provider", "default_model"), + mocker.call.labels().set(1), + mocker.call.labels("test_provider-1", "test_model-1"), + mocker.call.labels().set(0), + ], + any_order=False, # Order matters here + ) diff --git a/tests/unit/models/test_config.py b/tests/unit/models/test_config.py index 293f48a1b..38a36119d 100644 --- a/tests/unit/models/test_config.py +++ b/tests/unit/models/test_config.py @@ -23,6 +23,7 @@ TLSConfiguration, ModelContextProtocolServer, DataCollectorConfiguration, + InferenceConfiguration, ) from utils.checks import InvalidConfigurationError @@ -131,6 +132,53 @@ def test_llama_stack_wrong_configuration_no_config_file() -> None: LlamaStackConfiguration(use_as_library_client=True) +def test_inference_constructor() -> None: + """ + Test the InferenceConfiguration constructor with valid + parameters. + """ + # Test with no default provider or model, as they are optional + inference_config = InferenceConfiguration() + assert inference_config is not None + assert inference_config.default_provider is None + assert inference_config.default_model is None + + # Test with default provider and model + inference_config = InferenceConfiguration( + default_provider="default_provider", + default_model="default_model", + ) + assert inference_config is not None + assert inference_config.default_provider == "default_provider" + assert inference_config.default_model == "default_model" + + +def test_inference_default_model_missing() -> None: + """ + Test case where only default provider is set, should fail + """ + with pytest.raises( + ValueError, + match="Default model must be specified when default provider is set", + ): + InferenceConfiguration( + default_provider="default_provider", + ) + + +def test_inference_default_provider_missing() -> None: + """ + Test case where only default model is set, should fail + """ + with pytest.raises( + ValueError, + match="Default provider must be specified when default model is set", + ): + InferenceConfiguration( + default_model="default_model", + ) + + def test_user_data_collection_feedback_enabled() -> None: """Test the UserDataCollection constructor for feedback.""" # correct configuration @@ -426,6 +474,10 @@ def test_dump_configuration(tmp_path) -> None: ), mcp_servers=[], customization=None, + inference=InferenceConfiguration( + default_provider="default_provider", + default_model="default_model", + ), ) assert cfg is not None dump_file = tmp_path / "test.json" @@ -443,6 +495,8 @@ def test_dump_configuration(tmp_path) -> None: assert "user_data_collection" in content assert "mcp_servers" in content assert "authentication" 
diff --git a/tests/unit/models/test_config.py b/tests/unit/models/test_config.py
index 293f48a1b..38a36119d 100644
--- a/tests/unit/models/test_config.py
+++ b/tests/unit/models/test_config.py
@@ -23,6 +23,7 @@
     TLSConfiguration,
     ModelContextProtocolServer,
     DataCollectorConfiguration,
+    InferenceConfiguration,
 )
 from utils.checks import InvalidConfigurationError
 
@@ -131,6 +132,53 @@ def test_llama_stack_wrong_configuration_no_config_file() -> None:
         LlamaStackConfiguration(use_as_library_client=True)
 
 
+def test_inference_constructor() -> None:
+    """
+    Test the InferenceConfiguration constructor with valid
+    parameters.
+    """
+    # Test with no default provider or model, as they are optional
+    inference_config = InferenceConfiguration()
+    assert inference_config is not None
+    assert inference_config.default_provider is None
+    assert inference_config.default_model is None
+
+    # Test with default provider and model
+    inference_config = InferenceConfiguration(
+        default_provider="default_provider",
+        default_model="default_model",
+    )
+    assert inference_config is not None
+    assert inference_config.default_provider == "default_provider"
+    assert inference_config.default_model == "default_model"
+
+
+def test_inference_default_model_missing() -> None:
+    """
+    Test case where only the default provider is set; should fail.
+    """
+    with pytest.raises(
+        ValueError,
+        match="Default model must be specified when default provider is set",
+    ):
+        InferenceConfiguration(
+            default_provider="default_provider",
+        )
+
+
+def test_inference_default_provider_missing() -> None:
+    """
+    Test case where only the default model is set; should fail.
+    """
+    with pytest.raises(
+        ValueError,
+        match="Default provider must be specified when default model is set",
+    ):
+        InferenceConfiguration(
+            default_model="default_model",
+        )
+
+
 def test_user_data_collection_feedback_enabled() -> None:
     """Test the UserDataCollection constructor for feedback."""
     # correct configuration
@@ -426,6 +474,10 @@ def test_dump_configuration(tmp_path) -> None:
         ),
         mcp_servers=[],
         customization=None,
+        inference=InferenceConfiguration(
+            default_provider="default_provider",
+            default_model="default_model",
+        ),
     )
     assert cfg is not None
     dump_file = tmp_path / "test.json"
@@ -443,6 +495,8 @@ def test_dump_configuration(tmp_path) -> None:
     assert "user_data_collection" in content
     assert "mcp_servers" in content
     assert "authentication" in content
+    assert "customization" in content
+    assert "inference" in content
 
     # check the whole deserialized JSON file content
     assert content == {
@@ -489,6 +543,10 @@ def test_dump_configuration(tmp_path) -> None:
             "k8s_cluster_api": None,
         },
         "customization": None,
+        "inference": {
+            "default_provider": "default_provider",
+            "default_model": "default_model",
+        },
     }
 
 
@@ -516,6 +574,10 @@ def test_dump_configuration_with_one_mcp_server(tmp_path) -> None:
         ),
         mcp_servers=mcp_servers,
         customization=None,
+        inference=InferenceConfiguration(
+            default_provider="default_provider",
+            default_model="default_model",
+        ),
     )
     dump_file = tmp_path / "test.json"
     cfg.dump(dump_file)
@@ -580,6 +642,10 @@ def test_dump_configuration_with_one_mcp_server(tmp_path) -> None:
             "k8s_cluster_api": None,
         },
         "customization": None,
+        "inference": {
+            "default_provider": "default_provider",
+            "default_model": "default_model",
+        },
     }
 
 
@@ -610,6 +676,10 @@ def test_dump_configuration_with_more_mcp_servers(tmp_path) -> None:
         ),
         mcp_servers=mcp_servers,
         customization=None,
+        inference=InferenceConfiguration(
+            default_provider="default_provider",
+            default_model="default_model",
+        ),
     )
     dump_file = tmp_path / "test.json"
     cfg.dump(dump_file)
@@ -690,6 +760,10 @@ def test_dump_configuration_with_more_mcp_servers(tmp_path) -> None:
             "k8s_cluster_api": None,
         },
         "customization": None,
+        "inference": {
+            "default_provider": "default_provider",
+            "default_model": "default_model",
+        },
     }
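For reference, the dumped configuration gains the section below (shape taken from the test expectations above; whether a YAML service config mirrors these keys one-to-one is an assumption):

```python
# New "inference" section as it appears in the dumped configuration JSON,
# shown here as the equivalent Python dict (placeholder values from the tests):
inference_section = {
    "inference": {
        "default_provider": "default_provider",
        "default_model": "default_model",
    },
}
```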