Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 82 additions & 4 deletions infra/auth_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,26 @@
import random
import subprocess
import uuid
from dataclasses import dataclass

from azure.identity.aio import AzureDeveloperCliCredential
from dotenv_azd import load_azd_env
from kiota_abstractions.api_error import APIError
from kiota_abstractions.base_request_configuration import RequestConfiguration
from msgraph import GraphServiceClient
from msgraph.generated.applications.item.add_password.add_password_post_request_body import (
AddPasswordPostRequestBody,
)
from msgraph.generated.models.api_application import ApiApplication
from msgraph.generated.models.application import Application
from msgraph.generated.models.o_auth2_permission_grant import OAuth2PermissionGrant
from msgraph.generated.models.password_credential import PasswordCredential
from msgraph.generated.models.permission_scope import PermissionScope
from msgraph.generated.models.service_principal import ServicePrincipal
from msgraph.generated.models.web_application import WebApplication
from msgraph.graph_service_client import GraphServiceClient
from msgraph.generated.oauth2_permission_grants.oauth2_permission_grants_request_builder import (
Oauth2PermissionGrantsRequestBuilder,
)


async def get_application(graph_client: GraphServiceClient, app_id: str) -> str | None:
Expand Down Expand Up @@ -136,7 +143,7 @@ def update_app_with_identifier_uri(client_id: str) -> Application:
)


async def create_or_update_fastmcp_app(graph_client: GraphServiceClient) -> None:
async def create_or_update_fastmcp_app(graph_client: GraphServiceClient):
"""Create or update a FastMCP app registration."""
app_id_env_var = "ENTRA_PROXY_AZURE_CLIENT_ID"
app_secret_env_var = "ENTRA_PROXY_AZURE_CLIENT_SECRET"
Expand Down Expand Up @@ -171,6 +178,73 @@ async def create_or_update_fastmcp_app(graph_client: GraphServiceClient) -> None
update_azd_env(app_secret_env_var, client_secret)
print("Client secret created and saved to environment.")

return app_id


@dataclass
class GrantDefinition:
principal_id: str
resource_app_id: str
scopes: list[str]
target_label: str

def scope_string(self) -> str:
return " ".join(self.scopes)


async def grant_application_admin_consent(graph_client: GraphServiceClient, server_app_id: str):
server_principal = await graph_client.service_principals_with_app_id(server_app_id).get()
if server_principal is None or server_principal.id is None:
raise ValueError("Unable to locate service principal for server application")

grant_definitions = [
GrantDefinition(
principal_id=server_principal.id,
resource_app_id="00000003-0000-0000-c000-000000000000",
scopes=["User.Read", "email", "offline_access", "openid", "profile"],
target_label="server application",
)
]

for grant in grant_definitions:
resource_principal = await graph_client.service_principals_with_app_id(grant.resource_app_id).get()
if resource_principal is None or resource_principal.id is None:
raise ValueError(f"Unable to locate service principal for resource {grant.resource_app_id}")

desired_scope = grant.scope_string()
filter_query = f"clientId eq '{grant.principal_id}' and resourceId eq '{resource_principal.id}'"
query_params = Oauth2PermissionGrantsRequestBuilder.Oauth2PermissionGrantsRequestBuilderGetQueryParameters(
filter=filter_query
)
request_config = RequestConfiguration[
Oauth2PermissionGrantsRequestBuilder.Oauth2PermissionGrantsRequestBuilderGetQueryParameters
](query_parameters=query_params)
existing_grants = await graph_client.oauth2_permission_grants.get(request_configuration=request_config)

current_grant = existing_grants.value[0] if existing_grants and existing_grants.value else None

if current_grant:
print(f"Admin consent already granted for {desired_scope} on the {grant.target_label}")
continue

try:
await graph_client.oauth2_permission_grants.post(
OAuth2PermissionGrant(
client_id=grant.principal_id,
consent_type="AllPrincipals",
resource_id=resource_principal.id,
scope=desired_scope,
)
)
print(f"Granted admin consent for {desired_scope} on the {grant.target_label}")
except APIError as error:
status_code = error.response_status_code
if status_code in {401, 403}:
print(f"Failed to grant admin consent: {error.message}")
return
else:
raise


async def main():
# Configuration - customize these as needed
Expand All @@ -182,8 +256,12 @@ async def main():
scopes = ["https://graph.microsoft.com/.default"]
graph_client = GraphServiceClient(credentials=credential, scopes=scopes)

await create_or_update_fastmcp_app(graph_client)
print("Setup complete!")
server_app_id = await create_or_update_fastmcp_app(graph_client)

print("Attempting to grant admin consent for the client and server applications...")
await grant_application_admin_consent(graph_client, server_app_id)

print("✅ Entra app registration setup is complete.")


if __name__ == "__main__":
Expand Down
4 changes: 4 additions & 0 deletions infra/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ param keycloakMcpServerAudience string = 'mcp-server'
@description('Flag to restrict ACR public network access (requires VPN for local image push when true)')
param usePrivateAcr bool = false

@description('Entra ID group ID for admin access to expense statistics (only used when mcpAuthProvider is entra_proxy)')
param entraAdminGroupId string = ''

@description('Flag to restrict Log Analytics public query access for increased security')
param usePrivateLogAnalytics bool = false

Expand Down Expand Up @@ -790,6 +793,7 @@ module server 'server.bicep' = {
entraProxyClientSecret: useEntraProxy ? entraProxyClientSecret : ''
entraProxyBaseUrl: useEntraProxy ? entraProxyMcpServerBaseUrl : ''
tenantId: useEntraProxy ? tenant().tenantId : ''
entraAdminGroupId: useEntraProxy ? entraAdminGroupId : ''
mcpAuthProvider: mcpAuthProvider
logfireToken: logfireToken
}
Expand Down
3 changes: 3 additions & 0 deletions infra/main.parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@
"entraProxyClientSecret": {
"value": "${ENTRA_PROXY_AZURE_CLIENT_SECRET}"
},
"entraAdminGroupId": {
"value": "${ENTRA_ADMIN_GROUP_ID}"
},
"logfireToken": {
"value": "${LOGFIRE_TOKEN}"
}
Expand Down
5 changes: 5 additions & 0 deletions infra/server.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ param entraProxyClientId string = ''
param entraProxyClientSecret string = ''
param entraProxyBaseUrl string = ''
param tenantId string = ''
param entraAdminGroupId string = ''
@secure()
param logfireToken string = ''
@allowed([
Expand Down Expand Up @@ -139,6 +140,10 @@ var entraProxyEnv = !empty(entraProxyClientId) ? [
name: 'AZURE_TENANT_ID'
value: tenantId
}
{
name: 'ENTRA_ADMIN_GROUP_ID'
value: entraAdminGroupId
}
] : []

// Secrets for sensitive values
Expand Down
1 change: 1 addition & 0 deletions infra/write_env.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ if ($ENTRA_PROXY_AZURE_CLIENT_ID -and $ENTRA_PROXY_AZURE_CLIENT_ID -ne "") {
Add-Content -Path $ENV_FILE_PATH -Value "ENTRA_PROXY_AZURE_CLIENT_ID=$ENTRA_PROXY_AZURE_CLIENT_ID"
Write-Env ENTRA_PROXY_AZURE_CLIENT_SECRET
Write-Env ENTRA_PROXY_MCP_SERVER_BASE_URL
Write-EnvIfSet ENTRA_ADMIN_GROUP_ID
}
Add-Content -Path $ENV_FILE_PATH -Value "MCP_ENTRY=$(azd env get-value MCP_ENTRY)"
Add-Content -Path $ENV_FILE_PATH -Value "MCP_SERVER_URL=$(azd env get-value MCP_SERVER_URL)"
Expand Down
1 change: 1 addition & 0 deletions infra/write_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ if [ -n "$ENTRA_PROXY_AZURE_CLIENT_ID" ]; then
echo "ENTRA_PROXY_AZURE_CLIENT_ID=${ENTRA_PROXY_AZURE_CLIENT_ID}" >> "$ENV_FILE_PATH"
write_env ENTRA_PROXY_AZURE_CLIENT_SECRET
write_env ENTRA_PROXY_MCP_SERVER_BASE_URL
write_env_if_set ENTRA_ADMIN_GROUP_ID
fi
echo "MCP_ENTRY=$(azd env get-value MCP_ENTRY)" >> "$ENV_FILE_PATH"
echo "MCP_SERVER_URL=$(azd env get-value MCP_SERVER_URL)" >> "$ENV_FILE_PATH"
Expand Down
78 changes: 78 additions & 0 deletions servers/auth_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from enum import Enum
from typing import Annotated

import httpx
import logfire
from azure.core.settings import settings
from azure.cosmos.aio import CosmosClient
Expand All @@ -20,6 +21,7 @@
from fastmcp.server.middleware import Middleware, MiddlewareContext
from key_value.aio.stores.memory import MemoryStore
from keycloak_provider import KeycloakAuthProvider
from msal import ConfidentialClientApplication, TokenCache
from opentelemetry.instrumentation.starlette import StarletteInstrumentor
from rich.console import Console
from rich.logging import RichHandler
Expand Down Expand Up @@ -124,6 +126,32 @@
logger.error("No authentication configured for MCP server, exiting.")
raise SystemExit(1)

confidential_client = ConfidentialClientApplication(
client_id=os.environ["ENTRA_PROXY_AZURE_CLIENT_ID"],
client_credential=os.environ["ENTRA_PROXY_AZURE_CLIENT_SECRET"],
authority=f"https://login.microsoftonline.com/{os.environ['AZURE_TENANT_ID']}",
token_cache=TokenCache(),
)


async def check_user_in_group(graph_token: str, group_id: str) -> bool:
"""Check if the authenticated user is a member of the specified group (including transitive membership)."""
async with httpx.AsyncClient() as client:
url = (
"https://graph.microsoft.com/v1.0/me/transitiveMemberOf/microsoft.graph.group"
f"?$filter=id eq '{group_id}'&$count=true"
)
response = await client.get(
url,
headers={
"Authorization": f"Bearer {graph_token}",
"ConsistencyLevel": "eventual",
},
)
response.raise_for_status()
data = response.json()
return data.get("@odata.count", 0) > 0


# Middleware to populate user_id in per-request context state
class UserAuthMiddleware(Middleware):
Expand Down Expand Up @@ -242,6 +270,56 @@ async def get_user_expenses(ctx: Context):
return f"Error: Unable to retrieve expense data - {str(e)}"


@mcp.tool
async def get_expense_stats(ctx: Context):
"""Get a statistical summary of expenses (count per category) for all users.
Only accessible to users in the authorized admin group.
"""
access_token = get_access_token()
if not access_token:
return "Error: Authentication required"

auth_token = access_token.token
try:
graph_resource_access_token = confidential_client.acquire_token_on_behalf_of(
user_assertion=auth_token, scopes=["https://graph.microsoft.com/.default"]
)
if "error" in graph_resource_access_token:
return "Error: Unable to obtain Graph API access token for authorization check"

graph_auth_token = graph_resource_access_token["access_token"]

# Check for the specific admin group ID using transitive membership
admin_group_id = os.environ.get("ENTRA_ADMIN_GROUP_ID", "")
if not admin_group_id:
return "Error: Admin group ID not configured. Set ENTRA_ADMIN_GROUP_ID environment variable."
is_admin = await check_user_in_group(graph_auth_token, admin_group_id)

if not is_admin:
return "Error: Unauthorized. You do not have permission to access expense statistics."

# Query Cosmos DB for stats across all users
# We fetch categories and aggregate in Python to avoid cross-partition GROUP BY limitations
query = "SELECT c.category FROM c"
stats = {}
async for item in cosmos_container.query_items(query=query):
category = item.get("category", "Unknown")
stats[category] = stats.get(category, 0) + 1

if not stats:
return "No expense data found to summarize."

summary = "Expense Statistics (Count per Category):\n"
for category, count in stats.items():
summary += f"- Category {category}: {count} expenses\n"

return summary

except Exception as e:
logger.error(f"Error retrieving expense stats: {str(e)}")
return f"Error: Unable to retrieve expense statistics - {str(e)}"


@mcp.custom_route("/health", methods=["GET"])
async def health_check(_request):
"""Health check endpoint for service availability."""
Expand Down