Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,26 @@ def select_model_and_provider_id(
models: ModelListResponse, query_request: QueryRequest
) -> tuple[str, str | None]:
"""Select the model ID and provider ID based on the request or available models."""
# If model_id and provider_id are provided in the request, use them
model_id = query_request.model
provider_id = query_request.provider

# TODO(lucasagomes): support default model selection via configuration
if not model_id:
logger.info("No model specified in request, using the first available LLM")
# If model_id is not provided in the request, check the configuration
if not model_id or not provider_id:
logger.debug(
"No model ID or provider ID specified in request, checking configuration"
)
model_id = configuration.inference.default_model # type: ignore[reportAttributeAccessIssue]
provider_id = (
configuration.inference.default_provider # type: ignore[reportAttributeAccessIssue]
)

# If no model is specified in the request or configuration, use the first available LLM
if not model_id or not provider_id:
logger.debug(
"No model ID or provider ID specified in request or configuration, "
"using the first available LLM"
)
try:
model = next(
m
Expand All @@ -202,7 +216,8 @@ def select_model_and_provider_id(
},
) from e

logger.info("Searching for model: %s, provider: %s", model_id, provider_id)
# Validate that the model_id and provider_id are in the available models
logger.debug("Searching for model: %s, provider: %s", model_id, provider_id)
if not any(
m.identifier == model_id and m.provider_id == provider_id for m in models
):
Expand Down
9 changes: 9 additions & 0 deletions src/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ServiceConfiguration,
ModelContextProtocolServer,
AuthenticationConfiguration,
InferenceConfiguration,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,5 +100,13 @@ def customization(self) -> Optional[Customization]:
), "logic error: configuration is not loaded"
return self._configuration.customization

@property
def inference(self) -> Optional[InferenceConfiguration]:
    """Return the inference configuration section of the loaded configuration.

    Raises:
        AssertionError: if no configuration has been loaded yet.
    """
    loaded = self._configuration
    # Accessing configuration before load_configuration() is a programming error.
    assert loaded is not None, "logic error: configuration is not loaded"
    return loaded.inference


configuration: AppConfig = AppConfig()
23 changes: 18 additions & 5 deletions src/metrics/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""Utility functions for metrics handling."""

from configuration import configuration
from client import LlamaStackClientHolder
from log import get_logger
import metrics

logger = get_logger(__name__)


# TODO(lucasagomes): Change this metric once we are allowed to set the
# default model/provider via the configuration. The default provider/model
# will be set to 1, and the rest will be set to 0.
def setup_model_metrics() -> None:
"""Perform setup of all metrics related to LLM model and provider."""
client = LlamaStackClientHolder().get_client()
Expand All @@ -19,14 +17,29 @@ def setup_model_metrics() -> None:
if model.model_type == "llm" # pyright: ignore[reportAttributeAccessIssue]
]

default_model_label = (
configuration.inference.default_provider, # type: ignore[reportAttributeAccessIssue]
configuration.inference.default_model, # type: ignore[reportAttributeAccessIssue]
)

for model in models:
provider = model.provider_id
model_name = model.identifier
if provider and model_name:
# If the model/provider combination is the default, set the metric value to 1
# Otherwise, set it to 0
default_model_value = 0
label_key = (provider, model_name)
metrics.provider_model_configuration.labels(*label_key).set(1)
if label_key == default_model_label:
default_model_value = 1

# Set the metric for the provider/model configuration
metrics.provider_model_configuration.labels(*label_key).set(
default_model_value
)
logger.debug(
"Set provider/model configuration for %s/%s to 1",
"Set provider/model configuration for %s/%s to %d",
provider,
model_name,
default_model_value,
)
21 changes: 21 additions & 0 deletions src/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,26 @@ def check_customization_model(self) -> Self:
return self


class InferenceConfiguration(BaseModel):
    """Inference configuration.

    Holds the optional default model/provider pair used when a query request
    does not specify them explicitly.
    """

    # Both fields are optional, but they must be set (or left unset) together;
    # the validator below enforces that invariant.
    default_model: Optional[str] = None
    default_provider: Optional[str] = None

    @model_validator(mode="after")
    def check_default_model_and_provider(self) -> Self:
        """Check default model and provider.

        Raises:
            ValueError: if exactly one of default_model/default_provider is set.
        """
        model_missing = self.default_model is None
        provider_missing = self.default_provider is None
        if model_missing and not provider_missing:
            raise ValueError(
                "Default model must be specified when default provider is set"
            )
        if provider_missing and not model_missing:
            raise ValueError(
                "Default provider must be specified when default model is set"
            )
        return self


class Configuration(BaseModel):
"""Global service configuration."""

Expand All @@ -197,6 +217,7 @@ class Configuration(BaseModel):
AuthenticationConfiguration()
)
customization: Optional[Customization] = None
inference: Optional[InferenceConfiguration] = InferenceConfiguration()

def dump(self, filename: str = "configuration.json") -> None:
"""Dump actual configuration into JSON file."""
Expand Down
69 changes: 54 additions & 15 deletions tests/unit/app/endpoints/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,30 +179,70 @@ def test_query_endpoint_handler_store_transcript(mocker):
_test_query_endpoint_handler(mocker, store_transcript_to_file=True)


def test_select_model_and_provider_id(mocker):
def test_select_model_and_provider_id_from_request(mocker):
"""Test the select_model_and_provider_id function."""
mock_client = mocker.Mock()
mock_client.models.list.return_value = [
mocker.patch(
"metrics.utils.configuration.inference.default_provider",
"default_provider",
)
mocker.patch(
"metrics.utils.configuration.inference.default_model",
"default_model",
)

model_list = [
mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"),
mocker.Mock(
identifier="default_model", model_type="llm", provider_id="default_provider"
),
]

# Create a query request with model and provider specified
query_request = QueryRequest(
query="What is OpenStack?", model="model1", provider="provider1"
query="What is OpenStack?", model="model2", provider="provider2"
)

model_id, provider_id = select_model_and_provider_id(
mock_client.models.list(), query_request
    # Assert that the model and provider from the request take precedence over the configured defaults
model_id, provider_id = select_model_and_provider_id(model_list, query_request)

assert model_id == "model2"
assert provider_id == "provider2"


def test_select_model_and_provider_id_from_configuration(mocker):
"""Test the select_model_and_provider_id function."""
mocker.patch(
"metrics.utils.configuration.inference.default_provider",
"default_provider",
)
mocker.patch(
"metrics.utils.configuration.inference.default_model",
"default_model",
)

assert model_id == "model1"
assert provider_id == "provider1"
model_list = [
mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
mocker.Mock(
identifier="default_model", model_type="llm", provider_id="default_provider"
),
]

# Create a query request without model and provider specified
query_request = QueryRequest(
query="What is OpenStack?",
)

model_id, provider_id = select_model_and_provider_id(model_list, query_request)

# Assert that the default model and provider from the configuration are returned
assert model_id == "default_model"
assert provider_id == "default_provider"


def test_select_model_and_provider_id_no_model(mocker):
def test_select_model_and_provider_id_first_from_list(mocker):
"""Test the select_model_and_provider_id function when no model is specified."""
mock_client = mocker.Mock()
mock_client.models.list.return_value = [
model_list = [
mocker.Mock(
identifier="not_llm_type", model_type="embedding", provider_id="provider1"
),
Expand All @@ -216,11 +256,10 @@ def test_select_model_and_provider_id_no_model(mocker):

query_request = QueryRequest(query="What is OpenStack?")

model_id, provider_id = select_model_and_provider_id(
mock_client.models.list(), query_request
)
model_id, provider_id = select_model_and_provider_id(model_list, query_request)

# Assert return the first available LLM model
    # Assert that the first available LLM model is returned when no model/provider
    # is specified in the request or in the configuration
assert model_id == "first_model"
assert provider_id == "provider1"

Expand Down
58 changes: 52 additions & 6 deletions tests/unit/metrics/test_utis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,62 @@ def test_setup_model_metrics(mocker):

# Mock the LlamaStackAsLibraryClient
mock_client = mocker.patch("client.LlamaStackClientHolder.get_client").return_value
mocker.patch(
"metrics.utils.configuration.inference.default_provider",
"default_provider",
)
mocker.patch(
"metrics.utils.configuration.inference.default_model",
"default_model",
)

mock_metric = mocker.patch("metrics.provider_model_configuration")
fake_model = mocker.Mock(
provider_id="test_provider",
identifier="test_model",
# Mock a model that is the default
model_default = mocker.Mock(
provider_id="default_provider",
identifier="default_model",
model_type="llm",
)
mock_client.models.list.return_value = [fake_model]
# Mock a model that is not the default
model_0 = mocker.Mock(
provider_id="test_provider-0",
identifier="test_model-0",
model_type="llm",
)
# Mock a second model which is not default
model_1 = mocker.Mock(
provider_id="test_provider-1",
identifier="test_model-1",
model_type="llm",
)
# Mock a model that is not an LLM type, should be ignored
not_llm_model = mocker.Mock(
provider_id="not-llm-provider",
identifier="not-llm-model",
model_type="not-llm",
)

# Mock the list of models returned by the client
mock_client.models.list.return_value = [
model_0,
model_default,
not_llm_model,
model_1,
]

setup_model_metrics()

# Assert that the metric was set correctly
mock_metric.labels("test_provider", "test_model").set.assert_called_once_with(1)
# Check that the provider_model_configuration metric was set correctly
# The default model should have a value of 1, others should be 0
assert mock_metric.labels.call_count == 3
mock_metric.assert_has_calls(
[
mocker.call.labels("test_provider-0", "test_model-0"),
mocker.call.labels().set(0),
mocker.call.labels("default_provider", "default_model"),
mocker.call.labels().set(1),
mocker.call.labels("test_provider-1", "test_model-1"),
mocker.call.labels().set(0),
],
any_order=False, # Order matters here
)
Loading