diff --git a/.github/workflows/e2e_tests.yaml b/.github/workflows/e2e_tests.yaml index 9450def8..03fc8e19 100644 --- a/.github/workflows/e2e_tests.yaml +++ b/.github/workflows/e2e_tests.yaml @@ -60,7 +60,7 @@ jobs: cp "${CONFIG_FILE}" lightspeed-stack.yaml echo "✅ Configuration loaded successfully" - + - name: Select and configure run.yaml env: CONFIG_ENVIRONMENT: ${{ matrix.environment || 'ci' }} @@ -100,7 +100,7 @@ jobs: echo "=== Configuration Summary ===" echo "Deployment mode: ${{ matrix.mode }}" echo "Environment: ${{ matrix.environment }}" - echo "Source config: tests/e2e/configs/run-ci.yaml" + echo "Source config: tests/e2e/configs/run-${{ matrix.environment }}.yaml" echo "" echo "=== Configuration Preview ===" echo "Providers: $(grep -c "provider_id:" run.yaml)" diff --git a/.github/workflows/e2e_tests_providers.yaml b/.github/workflows/e2e_tests_providers.yaml index 65099e3b..8209a3b0 100644 --- a/.github/workflows/e2e_tests_providers.yaml +++ b/.github/workflows/e2e_tests_providers.yaml @@ -52,6 +52,21 @@ jobs: echo "=== Recent commits ===" git log --oneline -5 + - name: Add Azure Entra ID config block to all test configs + if: matrix.environment == 'azure' + run: | + echo "Adding azure_entra_id configuration block to all test configs..." + for config in tests/e2e/configuration/*/lightspeed-stack*.yaml; do + if [ -f "$config" ]; then + echo "" >> "$config" + echo "azure_entra_id:" >> "$config" + echo " tenant_id: \${env.TENANT_ID}" >> "$config" + echo " client_id: \${env.CLIENT_ID}" >> "$config" + echo " client_secret: \${env.CLIENT_SECRET}" >> "$config" + echo "✅ Added to: $config" + fi + done + - name: Load lightspeed-stack.yaml configuration run: | MODE="${{ matrix.mode }}" @@ -66,32 +81,6 @@ jobs: cp "${CONFIG_FILE}" lightspeed-stack.yaml echo "✅ Configuration loaded successfully" - - - name: Get Azure API key (access token) - if: matrix.environment == 'azure' - id: azure_token - env: - CLIENT_ID: ${{ secrets.CLIENT_ID }} - CLIENT_SECRET: ${{ secrets.CLIENT_SECRET }} - TENANT_ID: ${{ secrets.TENANT_ID }} - run: | - echo "Requesting Azure API token..." - RESPONSE=$(curl -s -X POST \ - -H "Content-Type: application/x-www-form-urlencoded" \ - -d "client_id=$CLIENT_ID&scope=https://cognitiveservices.azure.com/.default&client_secret=$CLIENT_SECRET&grant_type=client_credentials" \ - "https://login.microsoftonline.com/$TENANT_ID/oauth2/v2.0/token") - - echo "Response received. Extracting access_token..." - ACCESS_TOKEN=$(echo "$RESPONSE" | jq -r '.access_token') - - if [ -z "$ACCESS_TOKEN" ] || [ "$ACCESS_TOKEN" == "null" ]; then - echo "❌ Failed to obtain Azure access token. Response:" - echo "$RESPONSE" - exit 1 - fi - - echo "✅ Successfully obtained Azure access token." 
- echo "AZURE_API_KEY=$ACCESS_TOKEN" >> $GITHUB_ENV - name: Save VertexAI service account key to file if: matrix.environment == 'vertexai' @@ -198,7 +187,9 @@ jobs: if: matrix.mode == 'server' env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - AZURE_API_KEY: ${{ env.AZURE_API_KEY }} + TENANT_ID: ${{ secrets.TENANT_ID }} + CLIENT_ID: ${{ secrets.CLIENT_ID }} + CLIENT_SECRET: ${{ secrets.CLIENT_SECRET }} VERTEX_AI_LOCATION: ${{ secrets.VERTEX_AI_LOCATION }} VERTEX_AI_PROJECT: ${{ secrets.VERTEX_AI_PROJECT }} GOOGLE_APPLICATION_CREDENTIALS: ${{ env.GOOGLE_APPLICATION_CREDENTIALS }} @@ -227,7 +218,9 @@ jobs: if: matrix.mode == 'library' env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - AZURE_API_KEY: ${{ env.AZURE_API_KEY }} + TENANT_ID: ${{ secrets.TENANT_ID }} + CLIENT_ID: ${{ secrets.CLIENT_ID }} + CLIENT_SECRET: ${{ secrets.CLIENT_SECRET }} VERTEX_AI_LOCATION: ${{ secrets.VERTEX_AI_LOCATION }} VERTEX_AI_PROJECT: ${{ secrets.VERTEX_AI_PROJECT }} GOOGLE_APPLICATION_CREDENTIALS: ${{ env.GOOGLE_APPLICATION_CREDENTIALS }} diff --git a/Makefile b/Makefile index 1651398b..a4cdccfb 100644 --- a/Makefile +++ b/Makefile @@ -10,8 +10,17 @@ PYTHON_REGISTRY = pypi TORCH_VERSION := 2.7.1 +# Default configuration files (override with: make run CONFIG=myconfig.yaml) +CONFIG ?= lightspeed-stack.yaml +LLAMA_STACK_CONFIG ?= run.yaml + run: ## Run the service locally - uv run src/lightspeed_stack.py + uv run src/lightspeed_stack.py -c $(CONFIG) + +run-llama-stack: ## Start Llama Stack with enriched config (for local service mode) + uv run src/llama_stack_configuration.py -c $(CONFIG) -i $(LLAMA_STACK_CONFIG) -o $(LLAMA_STACK_CONFIG) && \ + AZURE_API_KEY=$$(grep '^AZURE_API_KEY=' .env | cut -d'=' -f2-) \ + uv run llama stack run $(LLAMA_STACK_CONFIG) test-unit: ## Run the unit tests @echo "Running unit tests..." 
diff --git a/README.md b/README.md index ac6edce6..2a0a65e1 100644 --- a/README.md +++ b/README.md @@ -195,8 +195,8 @@ __Note__: Support for individual models is dependent on the specific inference p | RHOAI (vLLM)| meta-llama/Llama-3.2-1B-Instruct | Yes | remote::vllm | [1](tests/e2e-prow/rhoai/configs/run.yaml) | | RHAIIS (vLLM)| meta-llama/Llama-3.1-8B-Instruct | Yes | remote::vllm | [1](tests/e2e/configs/run-rhaiis.yaml) | | RHEL AI (vLLM)| meta-llama/Llama-3.1-8B-Instruct | Yes | remote::vllm | [1](tests/e2e/configs/run-rhelai.yaml) | -| Azure | gpt-5, gpt-5-mini, gpt-5-nano, gpt-5-chat, gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, o3-mini, o4-mini | Yes | remote::azure | [1](examples/azure-run.yaml) | -| Azure | o1, o1-mini | No | remote::azure | | +| Azure | gpt-5, gpt-5-mini, gpt-5-nano, gpt-4o-mini, o3-mini, o4-mini, o1 | Yes | remote::azure | [1](examples/azure-run.yaml) | +| Azure | gpt-5-chat, gpt-4.1, gpt-4.1-mini, gpt-4.1-nano, o1-mini | No or limited | remote::azure | | | VertexAI | google/gemini-2.0-flash, google/gemini-2.5-flash, google/gemini-2.5-pro [^1] | Yes | remote::vertexai | [1](examples/vertexai-run.yaml) | | WatsonX | meta-llama/llama-3-3-70b-instruct | Yes | remote::watsonx | [1](examples/watsonx-run.yaml) | diff --git a/docker-compose-library.yaml b/docker-compose-library.yaml index ea34cff8..37f1ad8a 100644 --- a/docker-compose-library.yaml +++ b/docker-compose-library.yaml @@ -14,13 +14,16 @@ services: - ./run.yaml:/app-root/run.yaml:Z - ${GCP_KEYS_PATH:-./tmp/.gcp-keys-dummy}:/opt/app-root/.gcp-keys:ro environment: + # LLM Provider API Keys - BRAVE_SEARCH_API_KEY=${BRAVE_SEARCH_API_KEY:-} - TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY:-} # OpenAI - OPENAI_API_KEY=${OPENAI_API_KEY} - E2E_OPENAI_MODEL=${E2E_OPENAI_MODEL:-gpt-4o-mini} - # Azure - - AZURE_API_KEY=${AZURE_API_KEY:-} + # Azure Entra ID credentials (AZURE_API_KEY is obtained dynamically in Python) + - TENANT_ID=${TENANT_ID:-} + - CLIENT_ID=${CLIENT_ID:-} + - CLIENT_SECRET=${CLIENT_SECRET:-} # RHAIIS - RHAIIS_URL=${RHAIIS_URL:-} - RHAIIS_API_KEY=${RHAIIS_API_KEY:-} diff --git a/docker-compose.yaml b/docker-compose.yaml index d34d662a..d97b0779 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -11,14 +11,17 @@ services: volumes: - ./run.yaml:/opt/app-root/run.yaml:Z - ${GCP_KEYS_PATH:-./tmp/.gcp-keys-dummy}:/opt/app-root/.gcp-keys:ro + - ./lightspeed-stack.yaml:/opt/app-root/lightspeed-stack.yaml:z environment: - BRAVE_SEARCH_API_KEY=${BRAVE_SEARCH_API_KEY:-} - TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY:-} # OpenAI - OPENAI_API_KEY=${OPENAI_API_KEY} - E2E_OPENAI_MODEL=${E2E_OPENAI_MODEL:-gpt-4o-mini} - # Azure - - AZURE_API_KEY=${AZURE_API_KEY} + # Azure Entra ID credentials (AZURE_API_KEY is passed via provider_data at request time) + - TENANT_ID=${TENANT_ID:-} + - CLIENT_ID=${CLIENT_ID:-} + - CLIENT_SECRET=${CLIENT_SECRET:-} # RHAIIS - RHAIIS_URL=${RHAIIS_URL} - RHAIIS_API_KEY=${RHAIIS_API_KEY} @@ -55,10 +58,13 @@ services: ports: - "8080:8080" volumes: - - ./lightspeed-stack.yaml:/app-root/lightspeed-stack.yaml:Z + - ./lightspeed-stack.yaml:/app-root/lightspeed-stack.yaml:z environment: - OPENAI_API_KEY=${OPENAI_API_KEY} - - AZURE_API_KEY=${AZURE_API_KEY} + # Azure Entra ID credentials (AZURE_API_KEY is obtained dynamically) + - TENANT_ID=${TENANT_ID:-} + - CLIENT_ID=${CLIENT_ID:-} + - CLIENT_SECRET=${CLIENT_SECRET:-} depends_on: llama-stack: condition: service_healthy
diff --git a/docs/providers.md b/docs/providers.md
index 0f3891e1..87dbee9d 100644
--- a/docs/providers.md
+++ b/docs/providers.md
@@ -65,6 +65,115 @@ Red Hat providers:
 | RHAIIS (vllm) | 3.2.3 (on RHEL 9.20250429.0.4) | remote | `openai` | ✅ |
 | RHEL AI (vllm) | 1.5.2 | remote | `openai` | ✅ |
 
+### Azure Provider - Entra ID Authentication Guide
+
+Lightspeed Core supports secure authentication using Microsoft Entra ID (formerly Azure Active Directory) for the Azure inference provider. This allows you to connect to Azure OpenAI without using API keys by authenticating through your organization's Azure identity.
+
+#### Lightspeed Core Configuration Requirements
+
+To enable Entra ID authentication, the `azure_entra_id` block must be included in your LCS configuration. The `tenant_id`, `client_id`, and `client_secret` attributes are required:
+
+| Attribute | Required | Description |
+|-----------|----------|-------------|
+| `tenant_id` | Yes | Azure AD tenant ID |
+| `client_id` | Yes | Application (client) ID |
+| `client_secret` | Yes | Client secret value |
+| `scope` | No | Token scope (default: `https://cognitiveservices.azure.com/.default`) |
+
+Example of an LCS config section:
+
+```yaml
+azure_entra_id:
+  tenant_id: ${env.TENANT_ID}
+  client_id: ${env.CLIENT_ID}
+  client_secret: ${env.CLIENT_SECRET}
+  # scope: "https://cognitiveservices.azure.com/.default"  # optional, this is the default
+```
+
+#### Llama Stack Configuration Requirements
+
+Because Lightspeed builds on top of Llama Stack, certain configuration fields are required to satisfy the base Llama Stack schema. The config block for the Azure inference provider **must** include `api_key`, `api_base`, and `api_version` — Llama Stack will fail to start if any of these are missing.
+
+**Important:** The `api_key` field must be set to `${env.AZURE_API_KEY}` exactly as shown below. This is not optional — Lightspeed uses this specific environment variable name as a placeholder for injection of the Entra ID access token. Using a different variable name will break the authentication flow.
+
+```yaml
+inference:
+  - provider_id: azure
+    provider_type: remote::azure
+    config:
+      api_key: ${env.AZURE_API_KEY}  # Must be exactly this - placeholder for Entra ID token
+      api_base: ${env.AZURE_API_BASE}
+      api_version: 2025-01-01-preview
+```
+
+**How it works:** At startup, Lightspeed acquires an Entra ID access token and stores it in the `AZURE_API_KEY` environment variable. When Llama Stack initializes, it reads the config, substitutes `${env.AZURE_API_KEY}` with the token value, and uses it to authenticate with Azure OpenAI. Llama Stack also calls `models.list()` during initialization to validate provider connectivity, which is why the token must be available before client initialization.
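+
+The substitution step can be illustrated with Llama Stack's own `replace_env_vars` helper, the same function the enrichment script in this PR imports from `llama_stack.core.stack`. A minimal illustration, assuming the token has already been placed in the environment:
+
+```python
+import os
+
+from llama_stack.core.stack import replace_env_vars
+
+# Normally set by Lightspeed at startup; hardcoded here only for illustration
+os.environ["AZURE_API_KEY"] = "<entra-id-access-token>"
+
+config = {"api_key": "${env.AZURE_API_KEY}", "api_version": "2025-01-01-preview"}
+resolved = replace_env_vars(config)
+print(resolved["api_key"])  # prints the token value, not the placeholder
+```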
+
+#### Access Token Lifecycle and Management
+
+**Library mode startup:**
+1. Lightspeed reads your Entra ID configuration
+2. Acquires an initial access token from Microsoft Entra ID
+3. Stores the token in the `AZURE_API_KEY` environment variable
+4. **Then** initializes the Llama Stack library client
+
+This ordering is critical because Llama Stack calls `models.list()` during initialization to validate provider connectivity. If the token is not set before client initialization, Azure requests will fail with authentication errors.
+
+**Service mode startup:**
+
+When running Llama Stack as a separate service, Lightspeed runs a pre-startup script that:
+1. Reads the Entra ID configuration
+2. Acquires an initial access token
+3. Writes the token to the `AZURE_API_KEY` environment variable
+4. **Then** starts the Llama Stack service
+
+This initial token is used solely for the `models.list()` validation call during Llama Stack startup. After startup, Lightspeed manages token refresh independently and passes fresh tokens via request headers.
+
+**During inference requests:**
+1. Before each request, Lightspeed checks whether the token has expired
+2. If expired, a new token is automatically acquired and the environment variable is updated
+3. For library mode: the Llama Stack client is reloaded to pick up the new token
+4. For service mode: the token is passed via `X-LlamaStack-Provider-Data` request headers
+
+**Token security:**
+- Access tokens are wrapped in `SecretStr` to prevent accidental logging
+- Tokens are stored only in the `AZURE_API_KEY` environment variable (single source of truth)
+- Each Uvicorn worker maintains its own token lifecycle independently
+
+**Token validity:**
+- Access tokens are typically valid for 1 hour
+- Lightspeed refreshes tokens proactively before expiration (with a safety margin)
+- Token refresh happens automatically in the background without manual intervention (see the sketch below)
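+
+The refresh logic described above boils down to an expiry timestamp check plus a client-credentials call. A condensed sketch; the real implementation is `AzureEntraIDManager` in `src/authorization/azure_token_manager.py` (this diff), and the leeway value mirrors its `TOKEN_EXPIRATION_LEEWAY` constant:
+
+```python
+import os
+import time
+
+from azure.identity import ClientSecretCredential
+
+LEEWAY = 30  # seconds; refresh slightly before the real expiry
+
+
+class TokenCache:
+    def __init__(self) -> None:
+        self.expires_on = 0  # Unix timestamp; 0 means "never fetched"
+
+    def is_expired(self) -> bool:
+        return self.expires_on == 0 or time.time() > self.expires_on
+
+    def refresh(self, tenant_id: str, client_id: str, client_secret: str) -> None:
+        credential = ClientSecretCredential(tenant_id, client_id, client_secret)
+        token = credential.get_token("https://cognitiveservices.azure.com/.default")
+        os.environ["AZURE_API_KEY"] = token.token    # single source of truth
+        self.expires_on = token.expires_on - LEEWAY  # proactive safety margin
+```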
+
+#### Local Deployment Examples
+
+**Prerequisites:** Export the required Azure Entra ID environment variables in your terminal(s):
+
+```bash
+export TENANT_ID="your-tenant-id"
+export CLIENT_ID="your-client-id"
+export CLIENT_SECRET="your-client-secret"
+```
+
+**Library mode** (Llama Stack embedded in Lightspeed):
+
+```bash
+# From project root
+make run CONFIG=examples/lightspeed-stack-azure-entraid-lib.yaml
+```
+
+**Service mode** (Llama Stack as separate service):
+
+```bash
+# Terminal 1: Start Llama Stack service with Azure Entra ID config
+make run-llama-stack CONFIG=examples/lightspeed-stack-azure-entraid-service.yaml LLAMA_STACK_CONFIG=examples/azure-run.yaml
+
+# Terminal 2: Start Lightspeed (after Llama Stack is ready)
+make run CONFIG=examples/lightspeed-stack-azure-entraid-service.yaml
+```
+
+**Note:** The `make run-llama-stack` command accepts two variables:
+- `CONFIG` - Lightspeed configuration file (default: `lightspeed-stack.yaml`)
+- `LLAMA_STACK_CONFIG` - Llama Stack configuration file to enrich and run (default: `run.yaml`)
 
 ---
 
diff --git a/examples/azure-run.yaml b/examples/azure-run.yaml
index a50301ad..37afcd18 100644
--- a/examples/azure-run.yaml
+++ b/examples/azure-run.yaml
@@ -1,128 +1,137 @@
-version: '2' -image_name: minimal-viable-llama-stack-configuration - apis: - agents + - batches - datasetio - eval - files - inference - - post_training - safety - scoring - - telemetry - tool_runtime - vector_io benchmarks: [] -container_image: null +conversations_store: + db_path: ~/.llama/storage/conversations.db + type: sqlite datasets: [] -external_providers_dir: null +image_name: starter inference_store: - db_path: .llama/distributions/ollama/inference_store.db + db_path: ~/.llama/storage/inference-store.db type: sqlite -logging: null metadata_store: - db_path: .llama/distributions/ollama/registry.db - namespace: null + db_path: ~/.llama/storage/registry.db type: sqlite providers: - files: - - provider_id: localfs - provider_type: inline::localfs - config: - storage_dir: /tmp/llama-stack-files - metadata_store: - type: sqlite - db_path: .llama/distributions/ollama/files_metadata.db agents: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - persistence_store: - db_path: .llama/distributions/ollama/agents_store.db - namespace: null - type: sqlite - responses_store: - db_path: .llama/distributions/ollama/responses_store.db - type: 
sqlite + - config: + persistence: + agent_state: + backend: kv_default + namespace: agents_state + responses: + backend: sql_default + table_name: agents_responses + provider_id: meta-reference + provider_type: inline::meta-reference + batches: + - config: + kvstore: + backend: kv_default + namespace: batches_store + provider_id: reference + provider_type: inline::reference datasetio: - - provider_id: huggingface - provider_type: remote::huggingface - config: - kvstore: - db_path: .llama/distributions/ollama/huggingface_datasetio.db - namespace: null - type: sqlite - - provider_id: localfs - provider_type: inline::localfs - config: - kvstore: - db_path: .llama/distributions/ollama/localfs_datasetio.db - namespace: null - type: sqlite + - config: + kvstore: + backend: kv_default + namespace: huggingface_datasetio + provider_id: huggingface + provider_type: remote::huggingface + - config: + kvstore: + backend: kv_default + namespace: localfs_datasetio + provider_id: localfs + provider_type: inline::localfs eval: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - kvstore: - db_path: .llama/distributions/ollama/meta_reference_eval.db - namespace: null - type: sqlite + - config: + kvstore: + backend: kv_default + namespace: eval_store + provider_id: meta-reference + provider_type: inline::meta-reference + files: + - config: + metadata_store: + backend: sql_default + table_name: files_metadata + storage_dir: ~/.llama/storage + provider_id: meta-reference-files + provider_type: inline::localfs inference: - - provider_id: azure - provider_type: remote::azure - config: - api_key: ${env.AZURE_API_KEY} + - config: api_base: https://ols-test.openai.azure.com/ - api_version: 2024-02-15-preview - api_type: ${env.AZURE_API_TYPE:=} - post_training: - - provider_id: huggingface - provider_type: inline::huggingface-gpu - config: - checkpoint_format: huggingface - device: cpu - distributed_backend: null - dpo_output_dir: "." 
+ api_key: ${env.AZURE_API_KEY} + api_version: 2025-01-01-preview + provider_id: azure + provider_type: remote::azure safety: - - provider_id: llama-guard - provider_type: inline::llama-guard - config: - excluded_categories: [] + - config: + excluded_categories: [] + provider_id: llama-guard + provider_type: inline::llama-guard scoring: - - provider_id: basic - provider_type: inline::basic - config: {} - - provider_id: llm-as-judge - provider_type: inline::llm-as-judge - config: {} - - provider_id: braintrust - provider_type: inline::braintrust - config: - openai_api_key: '********' - telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - service_name: 'lightspeed-stack-telemetry' - sinks: sqlite - sqlite_db_path: .llama/distributions/ollama/trace_store.db + - config: {} + provider_id: basic + provider_type: inline::basic + - config: {} + provider_id: llm-as-judge + provider_type: inline::llm-as-judge tool_runtime: - - provider_id: model-context-protocol - provider_type: remote::model-context-protocol - config: {} + - config: {} + provider_id: rag-runtime + provider_type: inline::rag-runtime + vector_io: + - config: + persistence: + backend: kv_default + namespace: faiss_store + provider_id: faiss + provider_type: inline::faiss +registered_resources: + benchmarks: [] + datasets: [] + models: [] + scoring_fns: [] + shields: [] + tool_groups: + - provider_id: rag-runtime + toolgroup_id: builtin::rag + vector_dbs: [] scoring_fns: [] server: - auth: null - host: null port: 8321 - quota: null - tls_cafile: null - tls_certfile: null - tls_keyfile: null -shields: [] -models: - - model_id: gpt-4o-mini - model_type: llm - provider_id: azure - provider_model_id: gpt-4o-mini \ No newline at end of file +storage: + backends: + kv_default: + db_path: ~/.llama/storage/kv_store.db + type: kv_sqlite + sql_default: + db_path: ~/.llama/storage/sql_store.db + type: sql_sqlite + stores: + conversations: + backend: sql_default + table_name: openai_conversations + inference: + backend: sql_default + max_write_queue_size: 10000 + num_writers: 4 + table_name: inference_store + metadata: + backend: kv_default + namespace: registry + prompts: + backend: kv_default + namespace: prompts +version: 2 diff --git a/examples/lightspeed-stack-azure-entraid-lib.yaml b/examples/lightspeed-stack-azure-entraid-lib.yaml new file mode 100644 index 00000000..47932ac3 --- /dev/null +++ b/examples/lightspeed-stack-azure-entraid-lib.yaml @@ -0,0 +1,30 @@ +name: Lightspeed Core Service (LCS) +service: + host: 0.0.0.0 + port: 8080 + auth_enabled: false + workers: 1 + color_log: true + access_log: true +llama_stack: + # Uses a remote llama-stack service + # The instance would have already been started with a llama-stack-run.yaml file + # use_as_library_client: false + # Alternative for "as library use" + use_as_library_client: true + library_client_config_path: examples/azure-run.yaml + # url: http://localhost:8321 + # api_key: xyzzy +user_data_collection: + feedback_enabled: true + feedback_storage: "/tmp/data/feedback" + transcripts_enabled: true + transcripts_storage: "/tmp/data/transcripts" + +authentication: + module: "noop" + +azure_entra_id: + tenant_id: ${env.TENANT_ID} + client_id: ${env.CLIENT_ID} + client_secret: ${env.CLIENT_SECRET} diff --git a/examples/lightspeed-stack-azure-entraid-service.yaml b/examples/lightspeed-stack-azure-entraid-service.yaml new file mode 100644 index 00000000..fcbbc121 --- /dev/null +++ b/examples/lightspeed-stack-azure-entraid-service.yaml @@ -0,0 +1,30 @@ 
+name: Lightspeed Core Service (LCS) +service: + host: 0.0.0.0 + port: 8080 + auth_enabled: false + workers: 1 + color_log: true + access_log: true +llama_stack: + # Uses a remote llama-stack service + # The instance would have already been started with a llama-stack-run.yaml file + use_as_library_client: false + # Alternative for "as library use" + # use_as_library_client: true + # library_client_config_path: + url: http://localhost:8321 + api_key: xyzzy +user_data_collection: + feedback_enabled: true + feedback_storage: "/tmp/data/feedback" + transcripts_enabled: true + transcripts_storage: "/tmp/data/transcripts" + +authentication: + module: "noop" + +azure_entra_id: + tenant_id: ${env.TENANT_ID} + client_id: ${env.CLIENT_ID} + client_secret: ${env.CLIENT_SECRET} diff --git a/pyproject.toml b/pyproject.toml index bfcfdfb9..e5670e7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,9 @@ dependencies = [ "urllib3==2.6.2", # Used for agent card configuration "PyYAML>=6.0.0", + # Used for Azure Entra ID token management + "azure-core", + "azure-identity", ] diff --git a/scripts/llama-stack-entrypoint.sh b/scripts/llama-stack-entrypoint.sh new file mode 100755 index 00000000..a7eeb797 --- /dev/null +++ b/scripts/llama-stack-entrypoint.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Entrypoint for llama-stack container. +# Enriches config with lightspeed dynamic values, then starts llama-stack. + +set -e + +INPUT_CONFIG="${LLAMA_STACK_CONFIG:-/opt/app-root/run.yaml}" +ENRICHED_CONFIG="/opt/app-root/run.yaml" +LIGHTSPEED_CONFIG="${LIGHTSPEED_CONFIG:-/opt/app-root/lightspeed-stack.yaml}" +ENV_FILE="/opt/app-root/.env" + +# Enrich config if lightspeed config exists +if [ -f "$LIGHTSPEED_CONFIG" ]; then + echo "Enriching llama-stack config..." + ENRICHMENT_FAILED=0 + python3 /opt/app-root/llama_stack_configuration.py \ + -c "$LIGHTSPEED_CONFIG" \ + -i "$INPUT_CONFIG" \ + -o "$ENRICHED_CONFIG" \ + -e "$ENV_FILE" 2>&1 || ENRICHMENT_FAILED=1 + + # Source .env if generated (contains AZURE_API_KEY) + if [ -f "$ENV_FILE" ]; then + # shellcheck source=/dev/null + set -a && . 
"$ENV_FILE" && set +a + fi + + if [ -f "$ENRICHED_CONFIG" ] && [ "$ENRICHMENT_FAILED" -eq 0 ]; then + echo "Using enriched config: $ENRICHED_CONFIG" + exec llama stack run "$ENRICHED_CONFIG" + fi +fi + +echo "Using original config: $INPUT_CONFIG" +exec llama stack run "$INPUT_CONFIG" diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 062a385f..702e8927 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -69,6 +69,8 @@ from utils.token_counter import TokenCounter, extract_and_update_token_metrics from utils.transcripts import store_transcript from utils.types import TurnSummary, content_to_str +from authorization.azure_token_manager import AzureEntraIDManager + logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["query"]) @@ -312,6 +314,28 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 user_conversation=user_conversation, query_request=query_request ), ) + + if ( + provider_id == "azure" + and AzureEntraIDManager().is_entra_id_configured + and AzureEntraIDManager().is_token_expired + and AzureEntraIDManager().refresh_token() + ): + if AsyncLlamaStackClientHolder().is_library_client: + client = await AsyncLlamaStackClientHolder().reload_library_client() + else: + azure_config = next( + p.config + for p in await client.providers.list() + if p.provider_type == "remote::azure" + ) + client = AsyncLlamaStackClientHolder().update_provider_data( + { + "azure_api_key": AzureEntraIDManager().access_token.get_secret_value(), + "azure_api_base": str(azure_config.get("api_base")), + } + ) + summary, conversation_id, referenced_documents, token_usage = ( await retrieve_response_func( client, diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py index e6b14522..35abe361 100644 --- a/src/app/endpoints/query_v2.py +++ b/src/app/endpoints/query_v2.py @@ -403,7 +403,6 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche response = await client.responses.create(**create_kwargs) response = cast(OpenAIResponseObject, response) - logger.info("Response: %s", response) logger.debug( "Received response with ID: %s, conversation ID: %s, output items: %d", response.id, diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 9bc5b7a7..cbfc2afa 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -49,6 +49,7 @@ from authentication import get_auth_dependency from authentication.interface import AuthTuple from authorization.middleware import authorize +from authorization.azure_token_manager import AzureEntraIDManager from client import AsyncLlamaStackClientHolder from configuration import configuration from constants import DEFAULT_RAG_TOOL, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT @@ -890,6 +891,28 @@ async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-loc user_conversation=user_conversation, query_request=query_request ), ) + + if ( + provider_id == "azure" + and AzureEntraIDManager().is_entra_id_configured + and AzureEntraIDManager().is_token_expired + and AzureEntraIDManager().refresh_token() + ): + if AsyncLlamaStackClientHolder().is_library_client: + client = await AsyncLlamaStackClientHolder().reload_library_client() + else: + azure_config = next( + p.config + for p in await client.providers.list() + if p.provider_type == "remote::azure" + ) + client = AsyncLlamaStackClientHolder().update_provider_data( + { + "azure_api_key": 
AzureEntraIDManager().access_token.get_secret_value(), + "azure_api_base": str(azure_config.get("api_base")), + } + ) + response, conversation_id = await retrieve_response_func( client, llama_stack_model_id, diff --git a/src/app/main.py b/src/app/main.py index 6c161004..0257c600 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse from starlette.routing import Mount, Route, WebSocketRoute +from authorization.azure_token_manager import AzureEntraIDManager import metrics import version from app import routers @@ -16,8 +17,8 @@ from client import AsyncLlamaStackClientHolder from configuration import configuration from log import get_logger -from models.responses import InternalServerErrorResponse from a2a_storage import A2AStorageFactory +from models.responses import InternalServerErrorResponse from utils.common import register_mcp_servers_async from utils.llama_stack_version import check_llama_stack_version @@ -39,6 +40,16 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: logger, and database before serving requests. """ configuration.load_configuration(os.environ["LIGHTSPEED_STACK_CONFIG_PATH"]) + + azure_config = configuration.configuration.azure_entra_id + if azure_config is not None: + AzureEntraIDManager().set_config(azure_config) + if not AzureEntraIDManager().refresh_token(): + logger.warning( + "Failed to refresh Azure token at startup. " + "Token refresh will be retried on next Azure request." + ) + await AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack) client = AsyncLlamaStackClientHolder().get_client() # check if the Llama Stack version is supported by the service diff --git a/src/authorization/azure_token_manager.py b/src/authorization/azure_token_manager.py new file mode 100644 index 00000000..69b3eca5 --- /dev/null +++ b/src/authorization/azure_token_manager.py @@ -0,0 +1,102 @@ +"""Azure Entra ID token manager for Azure OpenAI authentication.""" + +import logging +import os +import time +from typing import Optional + +from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError +from azure.identity import ClientSecretCredential, CredentialUnavailableError +from pydantic import SecretStr + +from configuration import AzureEntraIdConfiguration +from utils.types import Singleton + +logger = logging.getLogger(__name__) + +# Refresh token before actual expiration to avoid edge cases +TOKEN_EXPIRATION_LEEWAY = 30 # seconds + + +class AzureEntraIDManager(metaclass=Singleton): + """Manages Azure Entra ID access tokens for Azure OpenAI provider. + + This singleton class handles: + - Token caching and expiration tracking + - Token refresh using client credentials flow + - Configuration management for Entra ID authentication + + The access token is passed via request headers to authenticate + with Azure OpenAI services. 
+ """ + + def __init__(self) -> None: + """Initialize the token manager with empty state.""" + self._expires_on: int = 0 + self._entra_id_config: Optional[AzureEntraIdConfiguration] = None + + def set_config(self, azure_config: AzureEntraIdConfiguration) -> None: + """Set the Azure Entra ID configuration.""" + self._entra_id_config = azure_config + logger.debug("Azure Entra ID configuration set") + + @property + def is_entra_id_configured(self) -> bool: + """Check if Entra ID configuration has been set.""" + return self._entra_id_config is not None + + @property + def is_token_expired(self) -> bool: + """Check if the cached token has expired or is not available.""" + return self._expires_on == 0 or time.time() > self._expires_on + + @property + def access_token(self) -> SecretStr: + """Return the access token from environment variable as SecretStr.""" + return SecretStr(os.environ.get("AZURE_API_KEY", "")) + + def refresh_token(self) -> bool: + """Refresh the cached Azure access token. + + Returns: + bool: True if token was successfully refreshed, False otherwise. + + Raises: + ValueError: If Entra ID configuration has not been set. + """ + if self._entra_id_config is None: + raise ValueError("Azure Entra ID configuration not set") + + logger.info("Refreshing Azure access token") + token_obj = self._retrieve_access_token() + if token_obj: + self._update_access_token(token_obj.token, token_obj.expires_on) + return True + return False + + def _update_access_token(self, token: str, expires_on: int) -> None: + """Update the token in env var and track expiration time.""" + self._expires_on = expires_on - TOKEN_EXPIRATION_LEEWAY + os.environ["AZURE_API_KEY"] = token + expiry_time = time.strftime( + "%Y-%m-%d %H:%M:%S", time.localtime(self._expires_on) + ) + logger.info("Azure access token refreshed, expires at %s", expiry_time) + + def _retrieve_access_token(self) -> Optional[AccessToken]: + """Retrieve a new access token from Azure.""" + if not self._entra_id_config: + return None + + try: + credential = ClientSecretCredential( + tenant_id=self._entra_id_config.tenant_id.get_secret_value(), + client_id=self._entra_id_config.client_id.get_secret_value(), + client_secret=self._entra_id_config.client_secret.get_secret_value(), + ) + return credential.get_token(self._entra_id_config.scope) + + except (ClientAuthenticationError, CredentialUnavailableError): + logger.warning("Failed to retrieve Azure access token") + return None diff --git a/src/client.py b/src/client.py index fd7e3d1d..f17d5fe1 100644 --- a/src/client.py +++ b/src/client.py @@ -1,17 +1,20 @@ """Llama Stack client retrieval class.""" +import json import logging - +import os +import tempfile from typing import Optional -from llama_stack import ( - AsyncLlamaStackAsLibraryClient, # type: ignore -) +import yaml +from llama_stack import AsyncLlamaStackAsLibraryClient # type: ignore from llama_stack_client import AsyncLlamaStackClient # type: ignore + +from configuration import configuration +from llama_stack_configuration import enrich_byok_rag, YamlDumper from models.config import LlamaStackConfiguration from utils.types import Singleton - logger = logging.getLogger(__name__) @@ -19,53 +22,81 @@ class AsyncLlamaStackClientHolder(metaclass=Singleton): """Container for an initialised AsyncLlamaStackClient.""" _lsc: Optional[AsyncLlamaStackClient] = None + _config_path: Optional[str] = None + + @property + def is_library_client(self) -> bool: + """Check if using library mode client.""" + return isinstance(self._lsc, 
AsyncLlamaStackAsLibraryClient) async def load(self, llama_stack_config: LlamaStackConfiguration) -> None: + """Initialize the Llama Stack client based on configuration.""" + if self._lsc is not None: # early stopping - client already initialized + return + + if llama_stack_config.use_as_library_client: + await self._load_library_client(llama_stack_config) + else: + self._load_service_client(llama_stack_config) + + async def _load_library_client(self, config: LlamaStackConfiguration) -> None: + """Initialize client in library mode. + + Stores the final config path for use in reload. """ - Load and initialize the holder's AsyncLlamaStackClient according to the provided config. + if config.library_client_config_path is None: + raise ValueError( + "Configuration problem: library_client_config_path is not set" + ) + logger.info("Using Llama stack as library client") - If `llama_stack_config.use_as_library_client` is set to True, a - library-mode client is created using - `llama_stack_config.library_client_config_path` and initialized before - being stored. + byok_rag = [b.model_dump() for b in configuration.configuration.byok_rag] - Otherwise, a service-mode client is created using - `llama_stack_config.url` and optional `llama_stack_config.api_key`. - The created client is stored on the instance for later retrieval via - `get_client()`. + if byok_rag: # BYOK RAG configured - enrich and store enriched path + self._config_path = self._enrich_library_config( + config.library_client_config_path, byok_rag + ) + else: # No RAG - store original path + self._config_path = config.library_client_config_path - Parameters: - llama_stack_config (LlamaStackConfiguration): Configuration that - selects client mode and provides either a library client config - path or service connection details (URL and optional API key). + client = AsyncLlamaStackAsLibraryClient(self._config_path) + await client.initialize() + self._lsc = client - Raises: - ValueError: If `use_as_library_client` is True but - `library_client_config_path` is not set. + def _load_service_client(self, config: LlamaStackConfiguration) -> None: + """Initialize client in service mode (remote HTTP).""" + logger.info("Using Llama stack running as a service") + api_key = config.api_key.get_secret_value() if config.api_key else None + self._lsc = AsyncLlamaStackClient(base_url=config.url, api_key=api_key) + + def _enrich_library_config( + self, input_config_path: str, byok_rag: list[dict] + ) -> str: + """Enrich llama-stack config with BYOK RAG settings. + + Only called when BYOK RAG is configured. """ - if llama_stack_config.use_as_library_client is True: - if llama_stack_config.library_client_config_path is not None: - logger.info("Using Llama stack as library client") - client = AsyncLlamaStackAsLibraryClient( - llama_stack_config.library_client_config_path - ) - await client.initialize() - self._lsc = client - else: - msg = "Configuration problem: library_client_config_path option is not set" - logger.error(msg) - # tisnik: use custom exception there - with cause etc. 
- raise ValueError(msg) - else: - logger.info("Using Llama stack running as a service") - self._lsc = AsyncLlamaStackClient( - base_url=llama_stack_config.url, - api_key=( - llama_stack_config.api_key.get_secret_value() - if llama_stack_config.api_key is not None - else None - ), - ) + try: + with open(input_config_path, "r", encoding="utf-8") as f: + ls_config = yaml.safe_load(f) + except (OSError, yaml.YAMLError) as e: + logger.warning("Failed to read llama-stack config: %s", e) + return input_config_path + + enrich_byok_rag(ls_config, byok_rag) + + enriched_path = os.path.join( + tempfile.gettempdir(), "llama_stack_enriched_config.yaml" + ) + + try: + with open(enriched_path, "w", encoding="utf-8") as f: + yaml.dump(ls_config, f, Dumper=YamlDumper, default_flow_style=False) + logger.info("Wrote enriched llama-stack config to %s", enriched_path) + return enriched_path + except OSError as e: + logger.warning("Failed to write enriched config: %s", e) + return input_config_path def get_client(self) -> AsyncLlamaStackClient: """ @@ -82,3 +113,52 @@ def get_client(self) -> AsyncLlamaStackClient: "AsyncLlamaStackClient has not been initialised. Ensure 'load(..)' has been called." ) return self._lsc + + async def reload_library_client(self) -> AsyncLlamaStackClient: + """Reload library client to pick up env var changes. + + For use with library mode only. + + Returns: + The reloaded client instance. + """ + if not self._config_path: + raise RuntimeError("Cannot reload: config path not set") + + client = AsyncLlamaStackAsLibraryClient(self._config_path) + await client.initialize() + self._lsc = client + return client + + def update_provider_data(self, updates: dict[str, str]) -> AsyncLlamaStackClient: + """Update provider data headers for service client. + + For use with service mode only. + + Args: + updates: Key-value pairs to merge into provider data header. + + Returns: + The updated client instance. + """ + if not self._lsc: + raise RuntimeError( + "AsyncLlamaStackClient has not been initialised. Ensure 'load(..)' has been called." 
+ ) + + current_headers = self._lsc.default_headers or {} + provider_data_json = current_headers.get("X-LlamaStack-Provider-Data") + + try: + provider_data = json.loads(provider_data_json) if provider_data_json else {} + except (json.JSONDecodeError, TypeError): + provider_data = {} + + provider_data.update(updates) + + updated_headers = { + **current_headers, + "X-LlamaStack-Provider-Data": json.dumps(provider_data), + } + self._lsc = self._lsc.copy(set_default_headers=updated_headers) # type: ignore + return self._lsc diff --git a/src/configuration.py b/src/configuration.py index ef19c619..b8c1a02c 100644 --- a/src/configuration.py +++ b/src/configuration.py @@ -11,6 +11,7 @@ from models.config import ( A2AStateConfiguration, AuthorizationConfiguration, + AzureEntraIdConfiguration, Configuration, Customization, LlamaStackConfiguration, @@ -70,7 +71,6 @@ def load_configuration(self, filename: str) -> None: with open(filename, encoding="utf-8") as fin: config_dict = yaml.safe_load(fin) config_dict = replace_env_vars(config_dict) - logger.info("Loaded configuration: %s", config_dict) self.init_from_dict(config_dict) def init_from_dict(self, config_dict: dict[Any, Any]) -> None: @@ -341,5 +341,12 @@ def token_usage_history(self) -> Optional[TokenUsageHistory]: ) return self._token_usage_history + @property + def azure_entra_id(self) -> Optional[AzureEntraIdConfiguration]: + """Return Azure Entra ID configuration, or None if not provided.""" + if self._configuration is None: + raise LogicError("logic error: configuration is not loaded") + return self._configuration.azure_entra_id + configuration: AppConfig = AppConfig() diff --git a/src/lightspeed_stack.py b/src/lightspeed_stack.py index 314b6382..fb6f4f8f 100644 --- a/src/lightspeed_stack.py +++ b/src/lightspeed_stack.py @@ -12,7 +12,6 @@ from log import get_logger from configuration import configuration -from llama_stack_configuration import generate_configuration from runners.uvicorn import start_uvicorn from runners.quota_scheduler import start_quota_scheduler @@ -63,28 +62,6 @@ def create_argument_parser() -> ArgumentParser: help="path to configuration file (default: lightspeed-stack.yaml)", default="lightspeed-stack.yaml", ) - parser.add_argument( - "-g", - "--generate-llama-stack-configuration", - dest="generate_llama_stack_configuration", - help="generate Llama Stack configuration based on LCORE configuration", - action="store_true", - default=False, - ) - parser.add_argument( - "-i", - "--input-config-file", - dest="input_config_file", - help="Llama Stack input configuration file", - default="run.yaml", - ) - parser.add_argument( - "-o", - "--output-config-file", - dest="output_config_file", - help="Llama Stack output configuration file", - default="run_.yaml", - ) return parser @@ -128,26 +105,8 @@ def main() -> None: raise SystemExit(1) from e return - # -g or --generate-llama-stack-configuration CLI flags are used to (re)generate - # configuration for Llama Stack - if args.generate_llama_stack_configuration: - try: - generate_configuration( - args.input_config_file, - args.output_config_file, - configuration.configuration, - ) - logger.info( - "Llama Stack configuration generated and stored into %s", - args.output_config_file, - ) - except Exception as e: - logger.error("Failed to generate Llama Stack configuration: %s", e) - raise SystemExit(1) from e - return - # Store config path in env so each uvicorn worker can load it - # (step is needed because process context isn’t shared). 
+ # (step is needed because process context isn't shared). os.environ["LIGHTSPEED_STACK_CONFIG_PATH"] = args.config_file # start the runners diff --git a/src/llama_stack_configuration.py b/src/llama_stack_configuration.py index ca56fa45..0e45e061 100644 --- a/src/llama_stack_configuration.py +++ b/src/llama_stack_configuration.py @@ -1,18 +1,26 @@ -"""Llama Stack configuration handling.""" +"""Llama Stack configuration enrichment. -from typing import Any +This module can be used in two ways: +1. As a script: `python llama_stack_configuration.py -c config.yaml` +2. As a module: `from llama_stack_configuration import generate_configuration` +""" -import yaml +import logging +import os +from argparse import ArgumentParser +from pathlib import Path +from typing import Any -from log import get_logger +from azure.core.exceptions import ClientAuthenticationError +from azure.identity import ClientSecretCredential, CredentialUnavailableError -from models.config import Configuration, ByokRag +import yaml +from llama_stack.core.stack import replace_env_vars -logger = get_logger(__name__) +logger = logging.getLogger(__name__) -# pylint: disable=too-many-ancestors -class YamlDumper(yaml.Dumper): +class YamlDumper(yaml.Dumper): # pylint: disable=too-many-ancestors """Custom YAML dumper with proper indentation levels.""" def increase_indent(self, flow: bool = False, indentless: bool = False) -> None: @@ -31,53 +39,84 @@ def increase_indent(self, flow: bool = False, indentless: bool = False) -> None: return super().increase_indent(flow, False) -def generate_configuration( - input_file: str, output_file: str, config: Configuration -) -> None: - """Generate new Llama Stack configuration. +# ============================================================================= +# Enrichment: Azure Entra ID +# ============================================================================= - Update a Llama Stack YAML configuration file by inserting BYOK RAG vector - DB and provider entries when present. - Reads the YAML configuration from `input_file`, and if `config.byok_rag` - contains items, updates or creates the `vector_dbs` and - `providers.vector_io` sections (preserving any existing entries) based on - that BYOK RAG data, then writes the resulting configuration to - `output_file`. If `config.byok_rag` is empty, the input configuration is - written unchanged to `output_file`. +def setup_azure_entra_id_token( + azure_config: dict[str, Any] | None, env_file: str +) -> None: + """Generate Azure Entra ID access token and write to .env file. - Parameters: - input_file (str): Path to the existing Llama Stack YAML configuration to read. - output_file (str): Path where the updated YAML configuration will be written. - config (Configuration): Configuration object whose `byok_rag` list - supplies BYOK RAG entries to be added. + Skips generation if AZURE_API_KEY is already set (e.g., orchestrator-injected). 
""" - logger.info("Reading Llama Stack configuration from file %s", input_file) - - with open(input_file, "r", encoding="utf-8") as file: - ls_config = yaml.safe_load(file) + # Skip if already injected by orchestrator (secure production setup) + if os.environ.get("AZURE_API_KEY"): + logger.info("Azure Entra ID: AZURE_API_KEY already set, skipping generation") + return + + if azure_config is None: + logger.info("Azure Entra ID: Not configured, skipping") + return + + tenant_id = azure_config.get("tenant_id") + client_id = azure_config.get("client_id") + client_secret = azure_config.get("client_secret") + scope = azure_config.get("scope", "https://cognitiveservices.azure.com/.default") + + if not all([tenant_id, client_id, client_secret]): + logger.warning( + "Azure Entra ID: Missing required fields (tenant_id, client_id, client_secret)" + ) + return - if len(config.byok_rag) == 0: - logger.info("BYOK RAG is not configured: finishing") - else: - logger.info("Processing Llama Stack configuration") - # create or update configuration section vector_dbs - ls_config["vector_dbs"] = construct_vector_dbs_section( - ls_config, config.byok_rag + try: + credential = ClientSecretCredential( + tenant_id=str(tenant_id), + client_id=str(client_id), + client_secret=str(client_secret), ) - # create or update configuration section providers/vector_io - ls_config["providers"]["vector_io"] = construct_vector_io_providers_section( - ls_config, config.byok_rag + + token = credential.get_token(scope) + + # Write to .env file + # Create file if it doesn't exist + Path(env_file).touch() + + lines = [] + with open(env_file, "r", encoding="utf-8") as f: + lines = f.readlines() + + # Update or add AZURE_API_KEY + key_found = False + for i, line in enumerate(lines): + if line.startswith("AZURE_API_KEY="): + lines[i] = f"AZURE_API_KEY={token.token}\n" + key_found = True + break + + if not key_found: + lines.append(f"AZURE_API_KEY={token.token}\n") + + with open(env_file, "w", encoding="utf-8") as f: + f.writelines(lines) + + logger.info( + "Azure Entra ID: Access token set in env and written to %s", env_file ) - logger.info("Writing Llama Stack configuration into file %s", output_file) + except (ClientAuthenticationError, CredentialUnavailableError) as e: + logger.error("Azure Entra ID: Failed to generate token: %s", e) - with open(output_file, "w", encoding="utf-8") as file: - yaml.dump(ls_config, file, Dumper=YamlDumper, default_flow_style=False) + +# ============================================================================= +# Enrichment: BYOK RAG +# ============================================================================= def construct_vector_dbs_section( - ls_config: dict[str, Any], byok_rag: list[ByokRag] + ls_config: dict[str, Any], byok_rag: list[dict[str, Any]] ) -> list[dict[str, Any]]: """Construct vector_dbs section in Llama Stack configuration file. @@ -87,7 +126,7 @@ def construct_vector_dbs_section( ls_config (dict[str, Any]): Existing Llama Stack configuration mapping used as the base; existing `vector_dbs` entries are preserved if present. - byok_rag (list[ByokRag]): List of BYOK RAG definitions to be added to + byok_rag (list[dict[str, Any]]): List of BYOK RAG definitions to be added to the `vector_dbs` section. 
Returns: @@ -107,10 +146,10 @@ def construct_vector_dbs_section( for brag in byok_rag: output.append( { - "vector_db_id": brag.vector_db_id, - "provider_id": "byok_" + brag.vector_db_id, - "embedding_model": brag.embedding_model, - "embedding_dimension": brag.embedding_dimension, + "vector_db_id": brag.get("vector_db_id", ""), + "provider_id": "byok_" + brag.get("vector_db_id", ""), + "embedding_model": brag.get("embedding_model", ""), + "embedding_dimension": brag.get("embedding_dimension"), } ) logger.info( @@ -122,7 +161,7 @@ def construct_vector_dbs_section( def construct_vector_io_providers_section( - ls_config: dict[str, Any], byok_rag: list[ByokRag] + ls_config: dict[str, Any], byok_rag: list[dict[str, Any]] ) -> list[dict[str, Any]]: """Construct providers/vector_io section in Llama Stack configuration file. @@ -134,7 +173,7 @@ def construct_vector_io_providers_section( ls_config (dict[str, Any]): Existing Llama Stack configuration dictionary; if it contains providers.vector_io, those entries are used as the starting list. - byok_rag (list[ByokRag]): List of BYOK RAG specifications to convert + byok_rag (list[dict[str, Any]]): List of BYOK RAG specifications to convert into provider entries. Returns: @@ -148,18 +187,18 @@ def construct_vector_io_providers_section( output = [] # fill-in existing vector_io entries - if "vector_io" in ls_config["providers"]: + if "providers" in ls_config and "vector_io" in ls_config["providers"]: output = ls_config["providers"]["vector_io"] # append new vector_io entries for brag in byok_rag: output.append( { - "provider_id": "byok_" + brag.vector_db_id, - "provider_type": brag.rag_type, + "provider_id": "byok_" + brag.get("vector_db_id", ""), + "provider_type": brag.get("rag_type", "inline::faiss"), "config": { "kvstore": { - "db_path": ".llama/" + brag.vector_db_id + ".db", + "db_path": ".llama/" + brag.get("vector_db_id", "") + ".db", "namespace": None, "type": "sqlite", } @@ -172,3 +211,107 @@ def construct_vector_io_providers_section( len(output), ) return output + + +def enrich_byok_rag(ls_config: dict[str, Any], byok_rag: list[dict[str, Any]]) -> None: + """Enrich Llama Stack config with BYOK RAG settings. + + Args: + ls_config: Llama Stack configuration dict (modified in place) + byok_rag: List of BYOK RAG configurations + """ + if len(byok_rag) == 0: + logger.info("BYOK RAG is not configured: skipping") + return + + logger.info("Enriching Llama Stack config with BYOK RAG") + ls_config["vector_dbs"] = construct_vector_dbs_section(ls_config, byok_rag) + + if "providers" not in ls_config: + ls_config["providers"] = {} + ls_config["providers"]["vector_io"] = construct_vector_io_providers_section( + ls_config, byok_rag + ) + + +# ============================================================================= +# Main Generation Function (service/container mode only) +# ============================================================================= + + +def generate_configuration( + input_file: str, + output_file: str, + config: dict[str, Any], + env_file: str = ".env", +) -> None: + """Generate enriched Llama Stack configuration for service/container mode. 
+ + Args: + input_file: Path to input Llama Stack config + output_file: Path to write enriched config + config: Lightspeed config dict (from YAML) + env_file: Path to .env file + """ + logger.info("Reading Llama Stack configuration from file %s", input_file) + + with open(input_file, "r", encoding="utf-8") as file: + ls_config = yaml.safe_load(file) + + # Enrichment: Azure Entra ID token + setup_azure_entra_id_token(config.get("azure_entra_id"), env_file) + + # Enrichment: BYOK RAG + enrich_byok_rag(ls_config, config.get("byok_rag", [])) + + logger.info("Writing Llama Stack configuration into file %s", output_file) + + with open(output_file, "w", encoding="utf-8") as file: + yaml.dump(ls_config, file, Dumper=YamlDumper, default_flow_style=False) + + +# ============================================================================= +# CLI Entry Point +# ============================================================================= + + +def main() -> None: + """CLI entry point.""" + parser = ArgumentParser( + description="Enrich Llama Stack config with Lightspeed values", + ) + parser.add_argument( + "-c", + "--config", + default="lightspeed-stack.yaml", + help="Lightspeed config file (default: lightspeed-stack.yaml)", + ) + parser.add_argument( + "-i", + "--input", + default="run.yaml", + help="Input Llama Stack config (default: run.yaml)", + ) + parser.add_argument( + "-o", + "--output", + default="run_.yaml", + help="Output enriched config (default: run_.yaml)", + ) + parser.add_argument( + "-e", + "--env-file", + default=".env", + help="Path to .env file for AZURE_API_KEY (default: .env)", + ) + args = parser.parse_args() + + with open(args.config, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + config = replace_env_vars(config) + + generate_configuration(args.input, args.output, config, args.env_file) + + +if __name__ == "__main__": + main() diff --git a/src/models/config.py b/src/models/config.py index 1e8335d5..9356464c 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -1534,6 +1534,20 @@ class QuotaHandlersConfiguration(ConfigurationBase): ) +class AzureEntraIdConfiguration(ConfigurationBase): + """Microsoft Entra ID authentication attributes for Azure.""" + + tenant_id: SecretStr + client_id: SecretStr + client_secret: SecretStr + scope: str = Field( + "https://cognitiveservices.azure.com/.default", + title="Token scope", + description="Azure Cognitive Services scope for token requests. 
" + "Override only if using a different Azure service.", + ) + + class Configuration(ConfigurationBase): """Global service configuration.""" @@ -1643,6 +1657,7 @@ class Configuration(ConfigurationBase): title="Quota handlers", description="Quota handlers configuration", ) + azure_entra_id: Optional[AzureEntraIdConfiguration] = None @model_validator(mode="after") def validate_mcp_auth_headers(self) -> Self: diff --git a/test.containerfile b/test.containerfile index 1ec104d4..a1966069 100644 --- a/test.containerfile +++ b/test.containerfile @@ -3,12 +3,20 @@ FROM quay.io/rhoai/odh-llama-stack-core-rhel9:rhoai-3.2 # Install missing dependencies and create required directories USER root -RUN pip install faiss-cpu==1.11.0 && \ +RUN pip install faiss-cpu==1.11.0 azure-identity && \ mkdir -p /app-root && \ chown -R 1001:0 /app-root && \ chmod -R 775 /app-root && \ mkdir -p /opt/app-root/src/.llama/storage /opt/app-root/src/.llama/providers.d && \ chown -R 1001:0 /opt/app-root/src/.llama +# Copy enrichment scripts for runtime config enrichment +COPY src/llama_stack_configuration.py /opt/app-root/llama_stack_configuration.py +COPY scripts/llama-stack-entrypoint.sh /opt/app-root/enrich-entrypoint.sh +RUN chmod +x /opt/app-root/enrich-entrypoint.sh && \ + chown 1001:0 /opt/app-root/enrich-entrypoint.sh /opt/app-root/llama_stack_configuration.py + # Switch back to the original user USER 1001 + +ENTRYPOINT ["/opt/app-root/enrich-entrypoint.sh"] diff --git a/tests/e2e/configs/run-azure.yaml b/tests/e2e/configs/run-azure.yaml index 08004a1d..89adfa5b 100644 --- a/tests/e2e/configs/run-azure.yaml +++ b/tests/e2e/configs/run-azure.yaml @@ -87,6 +87,7 @@ providers: - config: {} provider_id: sentence-transformers provider_type: inline::sentence-transformers + safety: - config: excluded_categories: [] @@ -152,4 +153,4 @@ registered_resources: benchmarks: [] tool_groups: - toolgroup_id: builtin::rag - provider_id: rag-runtime \ No newline at end of file + provider_id: rag-runtime diff --git a/tests/e2e/configuration/library-mode/lightspeed-stack-auth-noop-token.yaml b/tests/e2e/configuration/library-mode/lightspeed-stack-auth-noop-token.yaml index 777421f7..4ab62b2b 100644 --- a/tests/e2e/configuration/library-mode/lightspeed-stack-auth-noop-token.yaml +++ b/tests/e2e/configuration/library-mode/lightspeed-stack-auth-noop-token.yaml @@ -23,4 +23,3 @@ conversation_cache: authentication: module: "noop-with-token" - diff --git a/tests/e2e/configuration/library-mode/lightspeed-stack-invalid-feedback-storage.yaml b/tests/e2e/configuration/library-mode/lightspeed-stack-invalid-feedback-storage.yaml index 1a39ad1e..f29418fb 100644 --- a/tests/e2e/configuration/library-mode/lightspeed-stack-invalid-feedback-storage.yaml +++ b/tests/e2e/configuration/library-mode/lightspeed-stack-invalid-feedback-storage.yaml @@ -17,4 +17,3 @@ user_data_collection: authentication: module: "noop-with-token" - diff --git a/tests/e2e/configuration/library-mode/lightspeed-stack-no-cache.yaml b/tests/e2e/configuration/library-mode/lightspeed-stack-no-cache.yaml index d8a0214d..d54a8ab5 100644 --- a/tests/e2e/configuration/library-mode/lightspeed-stack-no-cache.yaml +++ b/tests/e2e/configuration/library-mode/lightspeed-stack-no-cache.yaml @@ -19,4 +19,3 @@ user_data_collection: authentication: module: "noop-with-token" - diff --git a/tests/e2e/configuration/library-mode/lightspeed-stack.yaml b/tests/e2e/configuration/library-mode/lightspeed-stack.yaml index e6d02d3a..47257bfb 100644 --- 
a/tests/e2e/configuration/library-mode/lightspeed-stack.yaml +++ b/tests/e2e/configuration/library-mode/lightspeed-stack.yaml @@ -16,4 +16,4 @@ user_data_collection: transcripts_enabled: true transcripts_storage: "/tmp/data/transcripts" authentication: - module: "noop" \ No newline at end of file + module: "noop" diff --git a/tests/e2e/configuration/server-mode/lightspeed-stack-auth-noop-token.yaml b/tests/e2e/configuration/server-mode/lightspeed-stack-auth-noop-token.yaml index b415c588..960919ed 100644 --- a/tests/e2e/configuration/server-mode/lightspeed-stack-auth-noop-token.yaml +++ b/tests/e2e/configuration/server-mode/lightspeed-stack-auth-noop-token.yaml @@ -29,4 +29,3 @@ conversation_cache: authentication: module: "noop-with-token" - diff --git a/tests/e2e/configuration/server-mode/lightspeed-stack-no-cache.yaml b/tests/e2e/configuration/server-mode/lightspeed-stack-no-cache.yaml index 334884fa..03ae32ab 100644 --- a/tests/e2e/configuration/server-mode/lightspeed-stack-no-cache.yaml +++ b/tests/e2e/configuration/server-mode/lightspeed-stack-no-cache.yaml @@ -25,4 +25,3 @@ user_data_collection: authentication: module: "noop-with-token" - diff --git a/tests/e2e/configuration/server-mode/lightspeed-stack.yaml b/tests/e2e/configuration/server-mode/lightspeed-stack.yaml index adc5b482..cc699ba8 100644 --- a/tests/e2e/configuration/server-mode/lightspeed-stack.yaml +++ b/tests/e2e/configuration/server-mode/lightspeed-stack.yaml @@ -17,4 +17,4 @@ user_data_collection: transcripts_enabled: true transcripts_storage: "/tmp/data/transcripts" authentication: - module: "noop" \ No newline at end of file + module: "noop" diff --git a/tests/e2e/features/steps/info.py b/tests/e2e/features/steps/info.py index 59212668..b4ec37af 100644 --- a/tests/e2e/features/steps/info.py +++ b/tests/e2e/features/steps/info.py @@ -105,18 +105,13 @@ def check_shield_structure(context: Context) -> None: assert found_shield is not None, "No shield found in response" - expected_model = context.default_model - expected_provider = context.default_provider - # Validate structure and values assert found_shield["type"] == "shield", "type should be 'shield'" assert ( found_shield["provider_id"] == "llama-guard" ), "provider_id should be 'llama-guard'" - assert ( - found_shield["provider_resource_id"] == f"{expected_provider}/{expected_model}" - ), ( - f"provider_resource_id should be '{expected_provider}/{expected_model}', " + assert found_shield["provider_resource_id"] == "openai/gpt-4o-mini", ( + f"provider_resource_id should be 'openai/gpt-4o-mini', " f"but is '{found_shield['provider_resource_id']}'" ) assert ( diff --git a/tests/unit/authentication/test_api_key_token.py b/tests/unit/authentication/test_api_key_token.py index b16be195..212801b4 100644 --- a/tests/unit/authentication/test_api_key_token.py +++ b/tests/unit/authentication/test_api_key_token.py @@ -2,6 +2,7 @@ """Unit tests for functions defined in authentication/api_key_token.py""" + import pytest from fastapi import HTTPException, Request from pydantic import SecretStr diff --git a/tests/unit/authorization/test_azure_token_manager.py b/tests/unit/authorization/test_azure_token_manager.py new file mode 100644 index 00000000..f393cc62 --- /dev/null +++ b/tests/unit/authorization/test_azure_token_manager.py @@ -0,0 +1,155 @@ +"""Unit test for Authentication with Azure Entra ID Credentials.""" + +# pylint: disable=protected-access + +import time +from typing import Any, Generator + +import pytest +from azure.core.credentials import AccessToken +from 
azure.core.exceptions import ClientAuthenticationError
+from pydantic import SecretStr
+from pytest_mock import MockerFixture
+
+from authorization.azure_token_manager import (
+    AzureEntraIDManager,
+    TOKEN_EXPIRATION_LEEWAY,
+)
+from configuration import AzureEntraIdConfiguration
+
+
+@pytest.fixture(name="dummy_config")
+def dummy_config_fixture() -> AzureEntraIdConfiguration:
+    """Return a dummy AzureEntraIdConfiguration for testing."""
+    return AzureEntraIdConfiguration(
+        tenant_id=SecretStr("tenant"),
+        client_id=SecretStr("client"),
+        client_secret=SecretStr("secret"),
+    )
+
+
+@pytest.fixture(autouse=True)
+def reset_singleton() -> Generator[None, None, None]:
+    """Reset the singleton instance before each test."""
+    AzureEntraIDManager._instances = {}  # type: ignore[attr-defined]
+    yield
+
+
+@pytest.fixture(name="token_manager")
+def token_manager_fixture() -> AzureEntraIDManager:
+    """Return a fresh AzureEntraIDManager instance."""
+    return AzureEntraIDManager()
+
+
+class TestAzureEntraIDManager:
+    """Unit tests for AzureEntraIDManager."""
+
+    def test_singleton_behavior(self, token_manager: AzureEntraIDManager) -> None:
+        """Verify the singleton returns the same instance."""
+        manager2 = AzureEntraIDManager()
+        assert token_manager is manager2
+
+    def test_initial_state(
+        self, token_manager: AzureEntraIDManager, mocker: MockerFixture
+    ) -> None:
+        """Check the initial token manager state."""
+        mocker.patch.dict("os.environ", {"AZURE_API_KEY": ""}, clear=False)
+        assert token_manager.access_token.get_secret_value() == ""
+        assert token_manager.is_token_expired
+        assert not token_manager.is_entra_id_configured
+
+    def test_set_config(
+        self,
+        token_manager: AzureEntraIDManager,
+        dummy_config: AzureEntraIdConfiguration,
+    ) -> None:
+        """Set the Azure configuration on the token manager."""
+        token_manager.set_config(dummy_config)
+        assert token_manager.is_entra_id_configured
+
+    def test_token_expiration_logic(self, token_manager: AzureEntraIDManager) -> None:
+        """Verify token expiration logic works correctly."""
+        token_manager._expires_on = int(time.time()) + 100
+        assert not token_manager.is_token_expired
+
+        token_manager._expires_on = 0
+        assert token_manager.is_token_expired
+
+    def test_refresh_token_raises_without_config(
+        self, token_manager: AzureEntraIDManager
+    ) -> None:
+        """Raise ValueError when refresh_token is called without config."""
+        with pytest.raises(ValueError, match="Azure Entra ID configuration not set"):
+            token_manager.refresh_token()
+
+    def test_update_access_token_sets_token_and_expiration(
+        self, token_manager: AzureEntraIDManager
+    ) -> None:
+        """Update the token and its expiration in the token manager."""
+        expires_on = int(time.time()) + 3600
+        token_manager._update_access_token("test-token", expires_on)
+        assert token_manager.access_token.get_secret_value() == "test-token"
+        assert token_manager._expires_on == expires_on - TOKEN_EXPIRATION_LEEWAY
+
+    def test_refresh_token_success(
+        self,
+        token_manager: AzureEntraIDManager,
+        dummy_config: AzureEntraIdConfiguration,
+        mocker: MockerFixture,
+    ) -> None:
+        """Refresh the token successfully using the Azure credential mock."""
+        token_manager.set_config(dummy_config)
+        dummy_access_token = AccessToken("token_value", int(time.time()) + 3600)
+
+        mock_credential_instance = mocker.Mock()
+        mock_credential_instance.get_token.return_value = dummy_access_token
+
+        mocker.patch(
+            "authorization.azure_token_manager.ClientSecretCredential",
+            
return_value=mock_credential_instance, + ) + + result = token_manager.refresh_token() + + assert result is True + assert token_manager.access_token.get_secret_value() == "token_value" + assert not token_manager.is_token_expired + mock_credential_instance.get_token.assert_called_once_with(dummy_config.scope) + + def test_refresh_token_failure_logs_error( + self, + token_manager: AzureEntraIDManager, + dummy_config: AzureEntraIdConfiguration, + mocker: MockerFixture, + caplog: Any, + ) -> None: + """Log error when token retrieval fails due to ClientAuthenticationError.""" + token_manager.set_config(dummy_config) + mock_credential_instance = mocker.Mock() + mock_credential_instance.get_token.side_effect = ClientAuthenticationError( + "fail" + ) + mocker.patch( + "authorization.azure_token_manager.ClientSecretCredential", + return_value=mock_credential_instance, + ) + + with caplog.at_level("WARNING"): + result = token_manager.refresh_token() + assert result is False + assert "Failed to retrieve Azure access token" in caplog.text + + def test_token_expired_property_dynamic( + self, token_manager: AzureEntraIDManager, mocker: MockerFixture + ) -> None: + """Simulate time passage to test token expiration property.""" + now = 1000000 + token_manager._expires_on = now + 10 + + mocker.patch("authorization.azure_token_manager.time.time", return_value=now) + assert not token_manager.is_token_expired + + mocker.patch( + "authorization.azure_token_manager.time.time", return_value=now + 20 + ) + assert token_manager.is_token_expired diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index 286cae17..86edcd48 100644 --- a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ -101,6 +101,7 @@ def test_dump_configuration(tmp_path: Path) -> None: assert "database" in content assert "byok_rag" in content assert "quota_handlers" in content + assert "azure_entra_id" in content # check the whole deserialized JSON file content assert content == { @@ -202,6 +203,7 @@ def test_dump_configuration(tmp_path: Path) -> None: "sqlite": None, "postgres": None, }, + "azure_entra_id": None, } @@ -419,6 +421,7 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None: assert "database" in content assert "byok_rag" in content assert "quota_handlers" in content + assert "azure_entra_id" in content # check the whole deserialized JSON file content assert content == { @@ -535,6 +538,7 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None: "sqlite": None, "postgres": None, }, + "azure_entra_id": None, } @@ -751,6 +755,7 @@ def test_dump_configuration_with_quota_limiters_different_values( "sqlite": None, "postgres": None, }, + "azure_entra_id": None, } @@ -941,6 +946,7 @@ def test_dump_configuration_byok(tmp_path: Path) -> None: "sqlite": None, "postgres": None, }, + "azure_entra_id": None, } @@ -1117,4 +1123,5 @@ def test_dump_configuration_pg_namespace(tmp_path: Path) -> None: "sqlite": None, "postgres": None, }, + "azure_entra_id": None, } diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5405092f..435d3b46 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1,9 +1,19 @@ """Unit tests for functions defined in src/client.py.""" -import pytest +# pylint: disable=protected-access + +import json +import pytest from client import AsyncLlamaStackClientHolder from models.config import LlamaStackConfiguration +from utils.types 
import Singleton + + +@pytest.fixture(autouse=True) +def reset_singleton() -> None: + """Reset singleton state between tests.""" + Singleton._instances = {} def test_async_client_get_client_method() -> None: @@ -66,8 +76,77 @@ async def test_get_async_llama_stack_wrong_configuration() -> None: ) cfg.library_client_config_path = None with pytest.raises( - Exception, - match="Configuration problem: library_client_config_path option is not set", + ValueError, + match="Configuration problem: library_client_config_path is not set", ): client = AsyncLlamaStackClientHolder() await client.load(cfg) + + +@pytest.mark.asyncio +async def test_update_provider_data_service_client() -> None: + """Test that update_provider_data updates headers for service clients.""" + cfg = LlamaStackConfiguration( + url="http://localhost:8321", + api_key=None, + use_as_library_client=False, + library_client_config_path=None, + ) + holder = AsyncLlamaStackClientHolder() + await holder.load(cfg) + + original_client = holder.get_client() + assert not holder.is_library_client + + # Pre-populate with existing provider data via headers + original_client._custom_headers["X-LlamaStack-Provider-Data"] = json.dumps( + { + "existing_field": "keep_this", + "azure_api_key": "old_token", + } + ) + + updated_client = holder.update_provider_data( + { + "azure_api_key": "new_token", + "azure_api_base": "https://new.example.com", + } + ) + + # Returns new client and updates holder + assert updated_client is not original_client + assert holder.get_client() is updated_client + + # Verify headers on updated client + provider_data_json = updated_client.default_headers.get( + "X-LlamaStack-Provider-Data" + ) + assert provider_data_json is not None + provider_data = json.loads(provider_data_json) + + # Existing fields preserved, new fields updated + assert provider_data["existing_field"] == "keep_this" + assert provider_data["azure_api_key"] == "new_token" + assert provider_data["azure_api_base"] == "https://new.example.com" + + +@pytest.mark.asyncio +async def test_reload_library_client() -> None: + """Test that reload_library_client reloads and returns new client.""" + cfg = LlamaStackConfiguration( + url=None, + api_key=None, + use_as_library_client=True, + library_client_config_path="./tests/configuration/minimal-stack.yaml", + ) + holder = AsyncLlamaStackClientHolder() + await holder.load(cfg) + + original_client = holder.get_client() + assert holder.is_library_client + + reloaded_client = await holder.reload_library_client() + + # Returns new client and updates holder + assert reloaded_client is not original_client + assert holder.get_client() is reloaded_client diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py index b2083a03..d231e853 100644 --- a/tests/unit/test_configuration.py +++ b/tests/unit/test_configuration.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Any, Generator +from pydantic import ValidationError import pytest from configuration import AppConfig, LogicError @@ -806,3 +807,71 @@ def test_configuration_with_quota_handlers(tmpdir: Path) -> None: # check the scheduler configuration assert cfg.quota_handlers_configuration.scheduler.period == 1 + + +def test_load_configuration_with_azure_entra_id(tmpdir: Path) -> None: + """Return Azure Entra ID configuration when provided in configuration.""" + cfg_filename = tmpdir / "config.yaml" + with open(cfg_filename, "w", encoding="utf-8") as fout: + fout.write( + """ +name: test service +service: + host: localhost + port: 8080 + 
auth_enabled: false + workers: 1 + color_log: true + access_log: true +llama_stack: + api_key: test-key + url: http://localhost:8321 + use_as_library_client: false +user_data_collection: + feedback_enabled: false +azure_entra_id: + tenant_id: tenant + client_id: client + client_secret: secret + """ + ) + + cfg = AppConfig() + cfg.load_configuration(str(cfg_filename)) + + azure_conf = cfg.azure_entra_id + assert azure_conf is not None + assert azure_conf.tenant_id.get_secret_value() == "tenant" + assert azure_conf.client_id.get_secret_value() == "client" + assert azure_conf.client_secret.get_secret_value() == "secret" + + +def test_load_configuration_with_incomplete_azure_entra_id_raises(tmpdir: Path) -> None: + """Raise error if Azure Entra ID block is incomplete in configuration.""" + cfg_filename = tmpdir / "config.yaml" + with open(cfg_filename, "w", encoding="utf-8") as fout: + fout.write( + """ +name: test service +service: + host: localhost + port: 8080 + auth_enabled: false + workers: 1 + color_log: true + access_log: true +llama_stack: + api_key: test-key + url: http://localhost:8321 + use_as_library_client: false +user_data_collection: + feedback_enabled: false +azure_entra_id: + tenant_id: tenant + client_id: client + """ + ) + + cfg = AppConfig() + with pytest.raises(ValidationError): + cfg.load_configuration(str(cfg_filename)) diff --git a/tests/unit/test_llama_stack_configuration.py b/tests/unit/test_llama_stack_configuration.py index cd396077..840911fa 100644 --- a/tests/unit/test_llama_stack_configuration.py +++ b/tests/unit/test_llama_stack_configuration.py @@ -1,383 +1,183 @@ -"""Unit tests for functions defined in src/llama_stack_configuration.py.""" +"""Unit tests for src/llama_stack_configuration.py.""" from pathlib import Path - from typing import Any import pytest import yaml -from pydantic import SecretStr - +from llama_stack_configuration import ( + generate_configuration, + construct_vector_dbs_section, + construct_vector_io_providers_section, +) from models.config import ( - ByokRag, Configuration, ServiceConfiguration, LlamaStackConfiguration, UserDataCollection, + InferenceConfiguration, ) -from constants import ( - DEFAULT_EMBEDDING_MODEL, - DEFAULT_EMBEDDING_DIMENSION, -) -from llama_stack_configuration import ( - generate_configuration, - construct_vector_dbs_section, - construct_vector_io_providers_section, -) +# ============================================================================= +# Test construct_vector_dbs_section +# ============================================================================= -def test_construct_vector_dbs_section_init() -> None: - """Test the function construct_vector_dbs_section for no vector_dbs configured before.""" +def test_construct_vector_dbs_section_empty() -> None: + """Test with no BYOK RAG config.""" ls_config: dict[str, Any] = {} - byok_rag: list[ByokRag] = [] + byok_rag: list[dict[str, Any]] = [] output = construct_vector_dbs_section(ls_config, byok_rag) assert len(output) == 0 -def test_construct_vector_dbs_section_init_with_existing_data() -> None: - """Test the function construct_vector_dbs_section for vector_dbs configured before.""" +def test_construct_vector_dbs_section_preserves_existing() -> None: + """Test preserves existing vector_dbs entries.""" ls_config = { "vector_dbs": [ - { - "vector_db_id": "vector_db_id_1", - "provider_id": "provier_id_1", - "embedding_model": "embedding_model_1", - "embedding_dimension": 1, - }, - { - "vector_db_id": "vector_db_id_2", - "provider_id": "provier_id_2", - 
"embedding_model": "embedding_model_2", - "embedding_dimension": 2, - }, + {"vector_db_id": "existing", "provider_id": "existing_provider"}, ] } - byok_rag: list[ByokRag] = [] + byok_rag: list[dict[str, Any]] = [] output = construct_vector_dbs_section(ls_config, byok_rag) - assert len(output) == 2 - assert output[0] == { - "vector_db_id": "vector_db_id_1", - "provider_id": "provier_id_1", - "embedding_model": "embedding_model_1", - "embedding_dimension": 1, - } - assert output[1] == { - "vector_db_id": "vector_db_id_2", - "provider_id": "provier_id_2", - "embedding_model": "embedding_model_2", - "embedding_dimension": 2, - } + assert len(output) == 1 + assert output[0]["vector_db_id"] == "existing" -def test_construct_vector_dbs_section_append() -> None: - """Test the function construct_vector_dbs_section for no vector_dbs configured before.""" +def test_construct_vector_dbs_section_adds_new() -> None: + """Test adds new BYOK RAG entries.""" ls_config: dict[str, Any] = {} - byok_rag: list[ByokRag] = [ - ByokRag( - rag_id="rag_id_1", - vector_db_id="vector_db_id_1", - db_path=Path("tests/configuration/rag.txt"), - ), - ByokRag( - rag_id="rag_id_2", - vector_db_id="vector_db_id_2", - db_path=Path("tests/configuration/rag.txt"), - ), + byok_rag = [ + { + "rag_id": "rag1", + "vector_db_id": "db1", + "embedding_model": "test-model", + "embedding_dimension": 512, + }, ] output = construct_vector_dbs_section(ls_config, byok_rag) - assert len(output) == 2 - assert output[0] == { - "vector_db_id": "vector_db_id_1", - "provider_id": "byok_vector_db_id_1", - "embedding_model": DEFAULT_EMBEDDING_MODEL, - "embedding_dimension": DEFAULT_EMBEDDING_DIMENSION, - } - assert output[1] == { - "vector_db_id": "vector_db_id_2", - "provider_id": "byok_vector_db_id_2", - "embedding_model": DEFAULT_EMBEDDING_MODEL, - "embedding_dimension": DEFAULT_EMBEDDING_DIMENSION, - } + assert len(output) == 1 + assert output[0]["vector_db_id"] == "db1" + assert output[0]["provider_id"] == "byok_db1" + assert output[0]["embedding_model"] == "test-model" + assert output[0]["embedding_dimension"] == 512 -def test_construct_vector_dbs_section_full_merge() -> None: - """Test the function construct_vector_dbs_section for vector_dbs configured before.""" - ls_config = { - "vector_dbs": [ - { - "vector_db_id": "vector_db_id_1", - "provider_id": "provier_id_1", - "embedding_model": "embedding_model_1", - "embedding_dimension": 1, - }, - { - "vector_db_id": "vector_db_id_2", - "provider_id": "provier_id_2", - "embedding_model": "embedding_model_2", - "embedding_dimension": 2, - }, - ] - } - byok_rag = [ - ByokRag( - rag_id="rag_id_1", - vector_db_id="vector_db_id_1", - db_path=Path("tests/configuration/rag.txt"), - ), - ByokRag( - rag_id="rag_id_2", - vector_db_id="vector_db_id_2", - db_path=Path("tests/configuration/rag.txt"), - ), - ] +def test_construct_vector_dbs_section_merge() -> None: + """Test merges existing and new entries.""" + ls_config = {"vector_dbs": [{"vector_db_id": "existing"}]} + byok_rag = [{"vector_db_id": "new_db"}] output = construct_vector_dbs_section(ls_config, byok_rag) - assert len(output) == 4 - assert output[0] == { - "vector_db_id": "vector_db_id_1", - "provider_id": "provier_id_1", - "embedding_model": "embedding_model_1", - "embedding_dimension": 1, - } - assert output[1] == { - "vector_db_id": "vector_db_id_2", - "provider_id": "provier_id_2", - "embedding_model": "embedding_model_2", - "embedding_dimension": 2, - } - assert output[2] == { - "vector_db_id": "vector_db_id_1", - "provider_id": 
"byok_vector_db_id_1", - "embedding_model": DEFAULT_EMBEDDING_MODEL, - "embedding_dimension": DEFAULT_EMBEDDING_DIMENSION, - } - assert output[3] == { - "vector_db_id": "vector_db_id_2", - "provider_id": "byok_vector_db_id_2", - "embedding_model": DEFAULT_EMBEDDING_MODEL, - "embedding_dimension": DEFAULT_EMBEDDING_DIMENSION, - } + assert len(output) == 2 + +# ============================================================================= +# Test construct_vector_io_providers_section +# ============================================================================= -def test_construct_vector_io_providers_section_init() -> None: - """Test construct_vector_io_providers_section for no vector_io_providers configured before.""" + +def test_construct_vector_io_providers_section_empty() -> None: + """Test with no BYOK RAG config.""" ls_config: dict[str, Any] = {"providers": {}} - byok_rag: list[ByokRag] = [] + byok_rag: list[dict[str, Any]] = [] output = construct_vector_io_providers_section(ls_config, byok_rag) assert len(output) == 0 -def test_construct_vector_io_providers_section_init_with_existing_data() -> None: - """Test construct_vector_io_providers_section for vector_io_providers configured before.""" - ls_config = { - "providers": { - "vector_io": [ - { - "provider_id": "faiss_1", - "provider_type": "inline::faiss", - }, - { - "provider_id": "faiss_2", - "provider_type": "inline::faiss", - }, - ] - } - } - byok_rag: list[ByokRag] = [] +def test_construct_vector_io_providers_section_preserves_existing() -> None: + """Test preserves existing vector_io entries.""" + ls_config = {"providers": {"vector_io": [{"provider_id": "existing"}]}} + byok_rag: list[dict[str, Any]] = [] output = construct_vector_io_providers_section(ls_config, byok_rag) - assert len(output) == 2 - assert output[0] == { - "provider_id": "faiss_1", - "provider_type": "inline::faiss", - } - assert output[1] == { - "provider_id": "faiss_2", - "provider_type": "inline::faiss", - } + assert len(output) == 1 + assert output[0]["provider_id"] == "existing" -def test_construct_vector_io_providers_section_append() -> None: - """Test construct_vector_io_providers_section for no vector_io_providers configured before.""" +def test_construct_vector_io_providers_section_adds_new() -> None: + """Test adds new BYOK RAG entries.""" ls_config: dict[str, Any] = {"providers": {}} byok_rag = [ - ByokRag( - rag_id="rag_id_1", - vector_db_id="vector_db_id_1", - db_path=Path("tests/configuration/rag.txt"), - ), - ByokRag( - rag_id="rag_id_2", - vector_db_id="vector_db_id_2", - db_path=Path("tests/configuration/rag.txt"), - ), + { + "vector_db_id": "db1", + "rag_type": "inline::faiss", + }, ] output = construct_vector_io_providers_section(ls_config, byok_rag) - assert len(output) == 2 - assert output[0] == { - "provider_id": "byok_vector_db_id_1", - "provider_type": "inline::faiss", - "config": { - "kvstore": { - "db_path": ".llama/vector_db_id_1.db", - "namespace": None, - "type": "sqlite", - }, - }, - } - assert output[1] == { - "provider_id": "byok_vector_db_id_2", - "provider_type": "inline::faiss", - "config": { - "kvstore": { - "db_path": ".llama/vector_db_id_2.db", - "namespace": None, - "type": "sqlite", - }, - }, - } + assert len(output) == 1 + assert output[0]["provider_id"] == "byok_db1" + assert output[0]["provider_type"] == "inline::faiss" -def test_construct_vector_io_providers_section_full_merge() -> None: - """Test construct_vector_io_providers_section for vector_io_providers configured before.""" - ls_config = { - "providers": { - 
"vector_io": [ - { - "provider_id": "faiss_1", - "provider_type": "inline::faiss", - }, - { - "provider_id": "faiss_2", - "provider_type": "inline::faiss", - }, - ] - } - } - byok_rag = [ - ByokRag( - rag_id="rag_id_1", - vector_db_id="vector_db_id_1", - db_path=Path("tests/configuration/rag.txt"), - ), - ByokRag( - rag_id="rag_id_2", - vector_db_id="vector_db_id_2", - db_path=Path("tests/configuration/rag.txt"), - ), - ] - output = construct_vector_io_providers_section(ls_config, byok_rag) - assert len(output) == 4 - assert output[0] == { - "provider_id": "faiss_1", - "provider_type": "inline::faiss", - } - assert output[1] == { - "provider_id": "faiss_2", - "provider_type": "inline::faiss", - } - assert output[2] == { - "provider_id": "byok_vector_db_id_1", - "provider_type": "inline::faiss", - "config": { - "kvstore": { - "db_path": ".llama/vector_db_id_1.db", - "namespace": None, - "type": "sqlite", - }, - }, - } - assert output[3] == { - "provider_id": "byok_vector_db_id_2", - "provider_type": "inline::faiss", - "config": { - "kvstore": { - "db_path": ".llama/vector_db_id_2.db", - "namespace": None, - "type": "sqlite", - }, - }, - } +# ============================================================================= +# Test generate_configuration +# ============================================================================= -def test_generate_configuration_no_input_file(tmpdir: Path) -> None: - """Test the function to generate configuration when input file does not exist.""" - cfg = Configuration( - name="test_name", - service=ServiceConfiguration(), - llama_stack=LlamaStackConfiguration( +def test_generate_configuration_no_input_file(tmp_path: Path) -> None: + """Test generate_configuration when input file does not exist.""" + config: dict[str, Any] = {} + outfile = tmp_path / "output.yaml" + + with pytest.raises(FileNotFoundError): + generate_configuration("/nonexistent/file.yaml", str(outfile), config) + + +def test_generate_configuration_with_dict(tmp_path: Path) -> None: + """Test generate_configuration accepts dict.""" + config: dict[str, Any] = {"byok_rag": []} + outfile = tmp_path / "output.yaml" + + generate_configuration("tests/configuration/run.yaml", str(outfile), config) + + assert outfile.exists() + with open(outfile, encoding="utf-8") as f: + result = yaml.safe_load(f) + assert "providers" in result + + +def test_generate_configuration_with_pydantic_model(tmp_path: Path) -> None: + """Test generate_configuration accepts Pydantic model via model_dump().""" + cfg = Configuration( # type: ignore[call-arg] + name="test", + service=ServiceConfiguration(), # type: ignore[call-arg] + llama_stack=LlamaStackConfiguration( # type: ignore[call-arg] use_as_library_client=True, - library_client_config_path="tests/configuration/run.yaml", - api_key=SecretStr("whatever"), - ), - user_data_collection=UserDataCollection( - feedback_enabled=False, feedback_storage=None + library_client_config_path="run.yaml", ), + user_data_collection=UserDataCollection(), # type: ignore[call-arg] + inference=InferenceConfiguration(), # type: ignore[call-arg] ) - outfile = tmpdir / "run.xml" - # try to generate new configuration file - with pytest.raises(FileNotFoundError, match="No such file"): - generate_configuration("/does/not/exist", str(outfile), cfg) - - -def test_generate_configuration_proper_input_file_no_byok(tmpdir: Path) -> None: - """Test the function to generate configuration when input file exists.""" - cfg = Configuration( - name="test_name", - service=ServiceConfiguration(), - 
llama_stack=LlamaStackConfiguration( - use_as_library_client=True, - library_client_config_path="tests/configuration/run.yaml", - api_key=SecretStr("whatever"), - ), - user_data_collection=UserDataCollection( - feedback_enabled=False, feedback_storage=None - ), + outfile = tmp_path / "output.yaml" + + # generate_configuration expects dict, so convert Pydantic model + generate_configuration( + "tests/configuration/run.yaml", str(outfile), cfg.model_dump() ) - outfile = tmpdir / "run.xml" - # try to generate new configuration file - generate_configuration("tests/configuration/run.yaml", str(outfile), cfg) - - with open(outfile, "r", encoding="utf-8") as fin: - generated = yaml.safe_load(fin) - assert "vector_dbs" in generated - assert "providers" in generated - assert "vector_io" in generated["providers"] - - -def test_generate_configuration_proper_input_file_configured_byok(tmpdir: Path) -> None: - """Test the function to generate configuration when BYOK RAG should be added.""" - cfg = Configuration( - name="test_name", - service=ServiceConfiguration(), - llama_stack=LlamaStackConfiguration( - use_as_library_client=True, - library_client_config_path="tests/configuration/run.yaml", - api_key=SecretStr("whatever"), - ), - user_data_collection=UserDataCollection( - feedback_enabled=False, feedback_storage=None - ), - byok_rag=[ - ByokRag( - rag_id="rag_id_1", - vector_db_id="vector_db_id_1", - db_path=Path("tests/configuration/rag.txt"), - ), - ByokRag( - rag_id="rag_id_2", - vector_db_id="vector_db_id_2", - db_path=Path("tests/configuration/rag.txt"), - ), + + assert outfile.exists() + + +def test_generate_configuration_with_byok(tmp_path: Path) -> None: + """Test generate_configuration adds BYOK entries.""" + config = { + "byok_rag": [ + { + "rag_id": "rag1", + "vector_db_id": "db1", + "embedding_model": "test-model", + "embedding_dimension": 256, + "rag_type": "inline::faiss", + }, ], - ) - outfile = tmpdir / "run.xml" - # try to generate new configuration file - generate_configuration("tests/configuration/run.yaml", str(outfile), cfg) - - with open(outfile, "r", encoding="utf-8") as fin: - generated = yaml.safe_load(fin) - assert "vector_dbs" in generated - assert "providers" in generated - assert "vector_io" in generated["providers"] + } + outfile = tmp_path / "output.yaml" + + generate_configuration("tests/configuration/run.yaml", str(outfile), config) + + with open(outfile, encoding="utf-8") as f: + result = yaml.safe_load(f) + + db_ids = [db["vector_db_id"] for db in result["vector_dbs"]] + assert "db1" in db_ids diff --git a/uv.lock b/uv.lock index 9c4b39c5..776ab5e0 100644 --- a/uv.lock +++ b/uv.lock @@ -225,6 +225,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/a0/f59dd73e8582c59672cf1f4e5f3ec60d1ee312f8f2a56ae54af5293173c7/autoevals-0.0.130-py3-none-any.whl", hash = "sha256:ffb7b3a21070d2a4e593bb118180c04e43531e608bffd854624377bd857ceec0", size = 56034, upload-time = "2025-09-08T05:29:59.908Z" }, ] +[[package]] +name = "azure-core" +version = "1.36.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" }, +] + +[[package]] +name = "azure-identity" +version = "1.25.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-core" }, + { name = "cryptography" }, + { name = "msal" }, + { name = "msal-extensions" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/8d/1a6c41c28a37eab26dc85ab6c86992c700cd3f4a597d9ed174b0e9c69489/azure_identity-1.25.1.tar.gz", hash = "sha256:87ca8328883de6036443e1c37b40e8dc8fb74898240f61071e09d2e369361456", size = 279826, upload-time = "2025-10-06T20:30:02.194Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/7b/5652771e24fff12da9dde4c20ecf4682e606b104f26419d139758cc935a6/azure_identity-1.25.1-py3-none-any.whl", hash = "sha256:e9edd720af03dff020223cd269fa3a61e8f345ea75443858273bcb44844ab651", size = 191317, upload-time = "2025-10-06T20:30:04.251Z" }, +] + [[package]] name = "bandit" version = "1.9.2" @@ -1260,6 +1289,8 @@ dependencies = [ { name = "aiosqlite" }, { name = "asyncpg" }, { name = "authlib" }, + { name = "azure-core" }, + { name = "azure-identity" }, { name = "cachetools" }, { name = "email-validator" }, { name = "fastapi" }, @@ -1343,6 +1374,8 @@ requires-dist = [ { name = "aiosqlite", specifier = ">=0.21.0" }, { name = "asyncpg", specifier = ">=0.31.0" }, { name = "authlib", specifier = ">=1.6.0" }, + { name = "azure-core" }, + { name = "azure-identity" }, { name = "cachetools", specifier = ">=6.1.0" }, { name = "email-validator", specifier = ">=2.2.0" }, { name = "fastapi", specifier = ">=0.115.12" }, @@ -1662,6 +1695,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, ] +[[package]] +name = "msal" +version = "1.34.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cf/0e/c857c46d653e104019a84f22d4494f2119b4fe9f896c92b4b864b3b045cc/msal-1.34.0.tar.gz", hash = "sha256:76ba83b716ea5a6d75b0279c0ac353a0e05b820ca1f6682c0eb7f45190c43c2f", size = 153961, upload-time = "2025-09-22T23:05:48.989Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/dc/18d48843499e278538890dc709e9ee3dea8375f8be8e82682851df1b48b5/msal-1.34.0-py3-none-any.whl", hash = "sha256:f669b1644e4950115da7a176441b0e13ec2975c29528d8b9e81316023676d6e1", size = 116987, upload-time = "2025-09-22T23:05:47.294Z" }, +] + +[[package]] +name = "msal-extensions" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "msal" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/01/99/5d239b6156eddf761a636bded1118414d161bd6b7b37a9335549ed159396/msal_extensions-1.3.1.tar.gz", hash = "sha256:c5b0fd10f65ef62b5f1d62f4251d51cbcaf003fcedae8c91b040a488614be1a4", size = 23315, upload-time = "2025-03-14T23:51:03.902Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/5e/75/bd9b7bb966668920f06b200e84454c8f3566b102183bc55c5473d96cb2b9/msal_extensions-1.3.1-py3-none-any.whl", hash = "sha256:96d3de4d034504e969ac5e85bae8106c8373b5c6568e4c8fa7af2eca9dbe6bca", size = 20583, upload-time = "2025-03-14T23:51:03.016Z" }, +] + [[package]] name = "multidict" version = "6.7.0"