diff --git a/DSL/Resql/rag-search/POST/get-testing-connection.sql b/DSL/Resql/rag-search/POST/get-testing-connection.sql new file mode 100644 index 0000000..93e9149 --- /dev/null +++ b/DSL/Resql/rag-search/POST/get-testing-connection.sql @@ -0,0 +1,25 @@ +SELECT + id, + connection_name, + used_budget, + monthly_budget, + warn_budget_threshold, + stop_budget_threshold, + environment, + connection_status, + created_at, + llm_platform, + llm_model, + embedding_platform, + embedding_model, + CASE + WHEN used_budget IS NULL OR used_budget = 0 OR (used_budget::DECIMAL / monthly_budget::DECIMAL) < (warn_budget_threshold::DECIMAL / 100.0) THEN 'within_budget' + WHEN stop_budget_threshold != 0 AND (used_budget::DECIMAL / monthly_budget::DECIMAL) >= (stop_budget_threshold::DECIMAL / 100.0) THEN 'over_budget' + WHEN stop_budget_threshold = 0 AND (used_budget::DECIMAL / monthly_budget::DECIMAL) >= 1 THEN 'over_budget' + WHEN (used_budget::DECIMAL / monthly_budget::DECIMAL) >= (warn_budget_threshold::DECIMAL / 100.0) THEN 'close_to_exceed' + ELSE 'within_budget' + END AS budget_status +FROM llm_connections +WHERE environment = 'testing' +ORDER BY created_at DESC +LIMIT 1; \ No newline at end of file diff --git a/DSL/Resql/rag-search/POST/store-production-inference-result.sql b/DSL/Resql/rag-search/POST/store-inference-result.sql similarity index 90% rename from DSL/Resql/rag-search/POST/store-production-inference-result.sql rename to DSL/Resql/rag-search/POST/store-inference-result.sql index bb5b553..089e92d 100644 --- a/DSL/Resql/rag-search/POST/store-production-inference-result.sql +++ b/DSL/Resql/rag-search/POST/store-inference-result.sql @@ -7,6 +7,7 @@ INSERT INTO inference_results ( embedding_scores, final_answer, environment, + llm_connection_id, created_at ) VALUES ( :chat_id, @@ -17,6 +18,7 @@ INSERT INTO inference_results ( :embedding_scores::JSONB, :final_answer, :environment, + :llm_connection_id, :created_at::timestamp with time zone ) RETURNING id, @@ -28,4 +30,5 @@ INSERT INTO inference_results ( embedding_scores, final_answer, environment, + llm_connection_id, created_at; diff --git a/DSL/Ruuter.private/rag-search/POST/inference/results/test/store.yml b/DSL/Ruuter.private/rag-search/POST/inference/results/test/store.yml deleted file mode 100644 index c83203e..0000000 --- a/DSL/Ruuter.private/rag-search/POST/inference/results/test/store.yml +++ /dev/null @@ -1,94 +0,0 @@ -declaration: - call: declare - version: 0.1 - description: "Store inference result" - method: post - accepts: json - returns: json - namespace: rag-search - allowlist: - body: - - field: llm_connection_id - type: number - description: "LLM connection ID" - - field: user_question - type: string - description: "User's question/input" - - field: final_answer - type: string - description: "LLM's final generated answer" - -extract_request_data: - assign: - llm_connection_id: ${Number(incoming.body.llm_connection_id)} - user_question: ${incoming.body.user_question} - final_answer: ${incoming.body.final_answer} - created_at: ${new Date().toISOString()} - next: check_llm_connection_exists - -check_llm_connection_exists: - call: http.post - args: - url: "[#RAG_SEARCH_RESQL]/get-llm-connection" - body: - connection_id: ${llm_connection_id} - result: connection_result - next: validate_connection_exists - -validate_connection_exists: - switch: - - condition: "${connection_result.response.body.length > 0}" - next: store_inference_result - next: return_connection_not_found - -store_inference_result: - call: http.post - args: - url: 
"[#RAG_SEARCH_RESQL]/store-testing-inference-result" - body: - llm_connection_id: ${llm_connection_id} - user_question: ${user_question} - final_answer: ${final_answer} - environment: "testing" - created_at: ${created_at} - result: store_result - next: check_status - -check_status: - switch: - - condition: ${200 <= store_result.response.statusCodeValue && store_result.response.statusCodeValue < 300} - next: format_success_response - next: format_failed_response - -format_success_response: - assign: - data_success: { - data: '${store_result.response.body[0]}', - operationSuccess: true, - statusCode: 200 - } - next: return_success - -format_failed_response: - assign: - data_failed: { - data: '[]', - operationSuccess: false, - statusCode: 400 - } - next: return_bad_request - -return_success: - return: ${data_success} - status: 200 - next: end - -return_bad_request: - return: ${data_failed} - status: 400 - next: end - -return_connection_not_found: - status: 404 - return: "error: LLM connection not found" - next: end diff --git a/DSL/Ruuter.private/rag-search/POST/inference/results/production/store.yml b/DSL/Ruuter.public/rag-search/POST/inference/results/store.yml similarity index 85% rename from DSL/Ruuter.private/rag-search/POST/inference/results/production/store.yml rename to DSL/Ruuter.public/rag-search/POST/inference/results/store.yml index 32c5093..19d8adf 100644 --- a/DSL/Ruuter.private/rag-search/POST/inference/results/production/store.yml +++ b/DSL/Ruuter.public/rag-search/POST/inference/results/store.yml @@ -29,6 +29,12 @@ declaration: - field: final_answer type: string description: "LLM's final generated answer" + - field: environment + type: string + description: "Environment identifier (e.g., production, testing)" + - field: llm_connection_id + type: string + description: "Connection identifier" extract_request_data: assign: @@ -39,6 +45,8 @@ extract_request_data: ranked_chunks: ${JSON.stringify(incoming.body.ranked_chunks) || null} embedding_scores: ${JSON.stringify(incoming.body.embedding_scores) || null} final_answer: ${incoming.body.final_answer} + environment: ${incoming.body.environment} + llm_connection_id: ${incoming.body.llm_connection_id} created_at: ${new Date().toISOString()} next: validate_required_fields @@ -51,7 +59,7 @@ validate_required_fields: store_production_inference_result: call: http.post args: - url: "[#RAG_SEARCH_RESQL]/store-production-inference-result" + url: "[#RAG_SEARCH_RESQL]/store-inference-result" body: chat_id: ${chat_id} user_question: ${user_question} @@ -60,7 +68,8 @@ store_production_inference_result: ranked_chunks: ${ranked_chunks} embedding_scores: ${embedding_scores} final_answer: ${final_answer} - environment: "production" + environment: ${environment} + llm_connection_id: ${llm_connection_id} created_at: ${created_at} result: store_result next: check_status diff --git a/Dockerfile.llm_orchestration_service b/Dockerfile.llm_orchestration_service index 989177e..0a4f979 100644 --- a/Dockerfile.llm_orchestration_service +++ b/Dockerfile.llm_orchestration_service @@ -21,4 +21,4 @@ RUN uv sync --locked EXPOSE 8100 # Run the FastAPI app via uvicorn -CMD ["uv","run","uvicorn", "src.llm_orchestration_service_api:app", "--host", "0.0.0.0", "--port", "8100"] +CMD ["uv","run","uvicorn", "src.llm_orchestration_service_api:app", "--host", "0.0.0.0", "--port", "8100"] \ No newline at end of file diff --git a/src/llm_orchestration_service.py b/src/llm_orchestration_service.py index 2de809a..59417d5 100644 --- a/src/llm_orchestration_service.py +++ 
b/src/llm_orchestration_service.py @@ -1,14 +1,15 @@ """LLM Orchestration Service - Business logic for LLM orchestration.""" from typing import Optional, List, Dict, Union, Any, AsyncIterator -import json import os import time +import asyncio from loguru import logger from langfuse import Langfuse, observe import dspy from datetime import datetime import json as json_module +import threading from llm_orchestrator_config.llm_manager import LLMManager from models.request_models import ( @@ -19,11 +20,11 @@ ContextGenerationRequest, TestOrchestrationResponse, ChunkInfo, + DocumentReference, ) from prompt_refine_manager.prompt_refiner import PromptRefinerAgent from src.response_generator.response_generate import ResponseGeneratorAgent from src.response_generator.response_generate import stream_response_native -from src.vector_indexer.constants import ResponseGenerationConstants from src.llm_orchestrator_config.llm_ochestrator_constants import ( OUT_OF_SCOPE_MESSAGE, TECHNICAL_ISSUE_MESSAGE, @@ -32,12 +33,16 @@ GUARDRAILS_BLOCKED_PHRASES, TEST_DEPLOYMENT_ENVIRONMENT, STREAM_TOKEN_LIMIT_MESSAGE, + PRODUCTION_DEPLOYMENT_ENVIRONMENT, ) from src.llm_orchestrator_config.stream_config import StreamConfig +from src.vector_indexer.constants import ResponseGenerationConstants from src.utils.error_utils import generate_error_id, log_error_with_context from src.utils.stream_manager import stream_manager from src.utils.cost_utils import calculate_total_costs, get_lm_usage_since from src.utils.time_tracker import log_step_timings +from src.utils.budget_tracker import get_budget_tracker +from src.utils.production_store import get_production_store from src.guardrails import NeMoRailsAdapter, GuardrailCheckResult from src.contextual_retrieval import ContextualRetriever from src.llm_orchestrator_config.exceptions import ( @@ -133,6 +138,12 @@ def process_orchestration_request( # Log final costs and return response self._log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) + + # Update budget for the LLM connection + self._update_connection_budget( + request.connection_id, costs_dict, request.environment + ) + if self.langfuse_config.langfuse_client: langfuse = self.langfuse_config.langfuse_client total_costs = calculate_total_costs(costs_dict) @@ -184,6 +195,12 @@ def process_orchestration_request( langfuse.flush() self._log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) + + # Update budget even on error + self._update_connection_budget( + request.connection_id, costs_dict, request.environment + ) + return self._create_error_response(request) @observe(name="streaming_generation", as_type="generation", capture_output=False) @@ -393,7 +410,9 @@ async def bot_response_generator() -> AsyncIterator[str]: # Wrap entire streaming logic in try/except for proper error handling try: - # Track tokens in stream context + # Track tokens and accumulated response in stream context + accumulated_response = [] # Track the full response for production storage + if components["guardrails_adapter"]: # Use NeMo's stream_with_guardrails helper method # This properly integrates the external generator with NeMo's validation @@ -412,6 +431,9 @@ async def bot_response_generator() -> AsyncIterator[str]: chunk_tokens = len(validated_chunk) // 4 stream_ctx.token_count += chunk_tokens + # Accumulate response for production storage + accumulated_response.append(validated_chunk) + # Check token limit if ( stream_ctx.token_count @@ -482,7 +504,10 @@ async def bot_response_generator() -> AsyncIterator[str]: 
return # Cleanup happens in finally # Log first few chunks for debugging - if chunk_count <= 10: + if ( + chunk_count + <= ResponseGenerationConstants.DEFAULT_MAX_BLOCKS + ): logger.debug( f"[{request.chatId}] [{stream_ctx.stream_id}] Validated chunk {chunk_count}: {repr(validated_chunk)}" ) @@ -501,6 +526,31 @@ async def bot_response_generator() -> AsyncIterator[str]: f"[{request.chatId}] [{stream_ctx.stream_id}] Stream completed successfully " f"({chunk_count} chunks streamed)" ) + + # Send document references before END token + doc_references = self._extract_document_references( + relevant_chunks + ) + if doc_references: + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Sending {len(doc_references)} document references before END" + ) + references_data = [ + ref.model_dump() for ref in doc_references + ] + references_message = { + "chatId": request.chatId, + "payload": { + "type": "references", + "references": references_data, + }, + "timestamp": str( + int(datetime.now().timestamp() * 1000) + ), + "sentTo": [], + } + yield f"data: {json_module.dumps(references_message)}\n\n" + yield self._format_sse(request.chatId, "END") else: @@ -516,6 +566,9 @@ async def bot_response_generator() -> AsyncIterator[str]: token_estimate = len(token) // 4 stream_ctx.token_count += token_estimate + # Accumulate response for production storage + accumulated_response.append(token) + if ( stream_ctx.token_count > StreamConfig.MAX_TOKENS_PER_STREAM @@ -533,6 +586,30 @@ async def bot_response_generator() -> AsyncIterator[str]: yield self._format_sse(request.chatId, token) + # Send document references before END token + doc_references = self._extract_document_references( + relevant_chunks + ) + if doc_references: + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Sending {len(doc_references)} document references before END" + ) + references_data = [ + ref.model_dump() for ref in doc_references + ] + references_message = { + "chatId": request.chatId, + "payload": { + "type": "references", + "references": references_data, + }, + "timestamp": str( + int(datetime.now().timestamp() * 1000) + ), + "sentTo": [], + } + yield f"data: {json_module.dumps(references_message)}\n\n" + yield self._format_sse(request.chatId, "END") # Extract usage information after streaming completes @@ -558,6 +635,11 @@ async def bot_response_generator() -> AsyncIterator[str]: self._log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) + # Update budget for the LLM connection + self._update_connection_budget( + request.connection_id, costs_dict, request.environment + ) + if self.langfuse_config.langfuse_client: langfuse = self.langfuse_config.langfuse_client total_costs = calculate_total_costs(costs_dict) @@ -586,6 +668,24 @@ async def bot_response_generator() -> AsyncIterator[str]: ) langfuse.flush() + # Store inference data (for production and testing environments) + if request.environment in [ + PRODUCTION_DEPLOYMENT_ENVIRONMENT, + TEST_DEPLOYMENT_ENVIRONMENT, + ]: + try: + await self._store_production_inference_data_async( + request=request, + refined_output=refined_output, + relevant_chunks=relevant_chunks, + accumulated_response="".join(accumulated_response), + ) + except Exception as storage_error: + # Log storage error but don't fail the request + logger.error( + f"Storage failed for chat_id: {request.chatId}, environment: {request.environment} - {str(storage_error)}" + ) + # Mark stream as completed successfully stream_ctx.mark_completed() @@ -599,6 +699,11 @@ async def bot_response_generator() 
-> AsyncIterator[str]: costs_dict["streaming_generation"] = usage_info self._log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) + + # Update budget even on client disconnect + self._update_connection_budget( + request.connection_id, costs_dict, request.environment + ) raise except Exception as stream_error: error_id = generate_error_id() @@ -618,6 +723,11 @@ async def bot_response_generator() -> AsyncIterator[str]: self._log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) + # Update budget even on streaming error + self._update_connection_budget( + request.connection_id, costs_dict, request.environment + ) + except Exception as e: error_id = generate_error_id() stream_ctx.mark_error(error_id) @@ -631,6 +741,11 @@ async def bot_response_generator() -> AsyncIterator[str]: self._log_costs(costs_dict) log_step_timings(timing_dict, request.chatId) + # Update budget even on outer exception + self._update_connection_budget( + request.connection_id, costs_dict, request.environment + ) + if self.langfuse_config.langfuse_client: langfuse = self.langfuse_config.langfuse_client langfuse.update_current_generation( @@ -851,9 +966,27 @@ def _execute_orchestration_pipeline( components["guardrails_adapter"], generated_response, request, costs_dict ) timing_dict["output_guardrails_check"] = time.time() - start_time + + # Step 6: Store inference data (for production and testing environments) + if request.environment in [ + PRODUCTION_DEPLOYMENT_ENVIRONMENT, + TEST_DEPLOYMENT_ENVIRONMENT, + ]: + try: + self._store_production_inference_data( + request=request, + refined_output=refined_output, + relevant_chunks=relevant_chunks, + final_response=output_guardrails_response, + ) + except Exception as storage_error: + # Log storage error but don't fail the request + logger.error( + f"Storage failed for chat_id: {request.chatId}, environment: {request.environment} - {str(storage_error)}" + ) + return output_guardrails_response - @observe(name="safe_initialize_guardrails", as_type="span") def _safe_initialize_guardrails( self, environment: str, connection_id: Optional[str] ) -> Optional[NeMoRailsAdapter]: @@ -945,7 +1078,6 @@ def _safe_retrieve_contextual_chunks_sync( request: OrchestrationRequest, ) -> List[Dict[str, Union[str, float, Dict[str, Any]]]]: """Synchronous wrapper for _safe_retrieve_contextual_chunks for non-streaming pipeline.""" - import asyncio try: # Safely execute the async method in the sync context @@ -1079,6 +1211,175 @@ def _create_out_of_scope_response( content=OUT_OF_SCOPE_MESSAGE, ) + def _store_production_inference_data( + self, + request: OrchestrationRequest, + refined_output: PromptRefinerOutput, + relevant_chunks: List[Dict[str, Union[str, float, Dict[str, Any]]]], + final_response: OrchestrationResponse, + ) -> None: + """ + Store production inference data to Resql endpoint for analytics. 
+ + This method stores comprehensive inference data including: + - User question and refined questions + - Conversation history + - Retrieved chunks with rankings + - Embedding scores + - Final generated answer + + Args: + request: Original orchestration request + refined_output: Prompt refiner output with original and refined questions + relevant_chunks: Retrieved and ranked chunks + final_response: Final orchestration response with generated answer + """ + try: + # Only store if the service was active and response was generated successfully + if not final_response.llmServiceActive: + logger.debug( + f"Skipping production data storage for chat_id: {request.chatId} " + f"- LLM service was not active" + ) + return + + # Extract embedding scores from chunks + embedding_scores = [] + for chunk in relevant_chunks: + score_value = chunk.get("fused_score", chunk.get("score", 0.0)) + try: + if isinstance(score_value, (int, float)): + embedding_scores.append(float(score_value)) + else: + embedding_scores.append(0.0) + except (ValueError, TypeError): + embedding_scores.append(0.0) + + # Convert conversation history to list of dicts + conversation_history_list = [ + {"role": item.authorRole, "content": item.message} + for item in (request.conversationHistory or []) + ] + + # Get the production store instance + production_store = get_production_store() + + # Store the inference result asynchronously without blocking + + def store_async(): + """Run async storage in a new event loop in a separate thread.""" + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + result = loop.run_until_complete( + production_store.store_inference_result_async( + chat_id=request.chatId, + user_question=request.message, + refined_questions=refined_output.refined_questions, + conversation_history=conversation_history_list, + ranked_chunks=relevant_chunks, + embedding_scores=embedding_scores, + final_answer=final_response.content, + environment=request.environment, + ) + ) + loop.close() + + if result["success"]: + logger.info( + f"Successfully stored inference data for chat_id: {request.chatId}, environment: {request.environment}" + ) + else: + logger.warning( + f"Failed to store inference data for chat_id: {request.chatId}, environment: {request.environment} - " + f"Error: {result['error']}" + ) + except Exception as e: + logger.error(f"Error in async storage thread: {str(e)}") + + # Start storage in background thread (non-blocking) + storage_thread = threading.Thread(target=store_async, daemon=True) + storage_thread.start() + + except Exception as e: + # Log the error but don't fail the request + logger.error( + f"Error storing inference data for chat_id: {request.chatId}, environment: {request.environment} - {str(e)}" + ) + + async def _store_production_inference_data_async( + self, + request: OrchestrationRequest, + refined_output: PromptRefinerOutput, + relevant_chunks: List[Dict[str, Union[str, float, Dict[str, Any]]]], + accumulated_response: str, + ) -> None: + """ + Async version: Store production inference data to Resql endpoint for analytics. 
+ + This method stores comprehensive inference data including: + - User question and refined questions + - Conversation history + - Retrieved chunks with rankings + - Embedding scores + - Final generated answer (from streaming) + + Args: + request: Original orchestration request + refined_output: Prompt refiner output with original and refined questions + relevant_chunks: Retrieved and ranked chunks + accumulated_response: Complete streamed response + """ + try: + # Extract embedding scores from chunks + embedding_scores = [] + for chunk in relevant_chunks: + score_value = chunk.get("fused_score", chunk.get("score", 0.0)) + try: + if isinstance(score_value, (int, float)): + embedding_scores.append(float(score_value)) + else: + embedding_scores.append(0.0) + except (ValueError, TypeError): + embedding_scores.append(0.0) + + # Convert conversation history to list of dicts + conversation_history_list = [ + {"role": item.authorRole, "content": item.message} + for item in (request.conversationHistory or []) + ] + + # Get the production store instance + production_store = get_production_store() + + # Store the inference result (async) + result = await production_store.store_inference_result_async( + chat_id=request.chatId, + user_question=request.message, + refined_questions=refined_output.refined_questions, + conversation_history=conversation_history_list, + ranked_chunks=relevant_chunks, + embedding_scores=embedding_scores, + final_answer=accumulated_response, + environment=request.environment, + ) + + if result["success"]: + logger.info( + f"Successfully stored inference data (async) for chat_id: {request.chatId}, environment: {request.environment}" + ) + else: + logger.warning( + f"Failed to store inference data (async) for chat_id: {request.chatId}, environment: {request.environment} - " + f"Error: {result['error']}" + ) + + except Exception as e: + # Log the error but don't fail the request + logger.error( + f"Error storing inference data (async) for chat_id: {request.chatId}, environment: {request.environment} - {str(e)}" + ) + @observe(name="initialize_guardrails", as_type="span") def _initialize_guardrails( self, environment: str, connection_id: Optional[str] @@ -1395,6 +1696,70 @@ def _log_costs(self, costs_dict: Dict[str, Dict[str, Any]]) -> None: except Exception as e: logger.warning(f"Failed to log costs: {str(e)}") + def _update_connection_budget( + self, + connection_id: Optional[str], + costs_dict: Dict[str, Dict[str, Any]], + environment: str = "development", + ) -> None: + """ + Update the budget for an LLM connection based on usage costs. + For production environment, fetches the connection ID asynchronously if not provided. + + Args: + connection_id: The LLM connection ID (optional) + costs_dict: Dictionary of costs per component + environment: The deployment environment (production/testing/development) + """ + try: + budget_tracker = get_budget_tracker() + + # For production environment, fetch connection ID if not provided + if environment == "production" and not connection_id: + logger.debug( + "Production environment detected, fetching connection ID..." 
+ ) + try: + # Use synchronous fetch to avoid event loop issues + production_id = ( + budget_tracker.connection_fetcher.fetch_connection_id_sync( + "production" + ) + ) + if production_id: + connection_id = str(production_id) + logger.info(f"Using production connection_id: {connection_id}") + else: + logger.warning("Could not fetch production connection ID") + except Exception as fetch_error: + logger.error( + f"Error fetching production connection ID: {str(fetch_error)}" + ) + + result = budget_tracker.update_budget_from_costs(connection_id, costs_dict) + + if result.get("success"): + if result.get("budget_exceeded"): + logger.warning( + f"Budget threshold exceeded for connection_id={connection_id}. " + "Connection may have been deactivated." + ) + else: + logger.debug( + f"Budget updated successfully for connection_id={connection_id}" + ) + else: + reason = result.get("reason", "unknown") + if reason not in ["no_connection_id", "zero_or_negative_cost"]: + logger.warning( + f"Failed to update budget for connection_id={connection_id}. " + f"Reason: {reason}" + ) + + except Exception as e: + # Don't fail the orchestration if budget update fails + logger.error(f"Error updating budget: {str(e)}") + @observe(name="initialize_llm_manager", as_type="span") def _initialize_llm_manager( self, environment: str, connection_id: Optional[str] @@ -1620,17 +1985,13 @@ def _format_chunks_for_test_response( relevant_chunks: List of retrieved chunks with metadata Returns: - List of ChunkInfo objects with rank and content (limited to top 5), or None if no chunks + List of ChunkInfo objects with rank and content, or None if no chunks """ if not relevant_chunks: return None - # Limit to top-k chunks that are actually used in response generation - max_blocks = ResponseGenerationConstants.DEFAULT_MAX_BLOCKS - limited_chunks = relevant_chunks[:max_blocks] - formatted_chunks = [] - for rank, chunk in enumerate(limited_chunks, start=1): + for rank, chunk in enumerate(relevant_chunks, start=1): # Extract text content - prefer "text" key, fallback to "content" chunk_text = chunk.get("text", chunk.get("content", "")) if isinstance(chunk_text, str) and chunk_text.strip(): @@ -1638,6 +1999,63 @@ def _format_chunks_for_test_response( return formatted_chunks if formatted_chunks else None + @staticmethod + def _extract_document_references( + relevant_chunks: Optional[List[Dict[str, Union[str, float, Dict[str, Any]]]]], + ) -> Optional[List[DocumentReference]]: + """ + Extract unique document references from retrieved chunks. 
+ + Args: + relevant_chunks: List of retrieved chunks with metadata + + Returns: + List of DocumentReference objects, or None if no chunks + """ + if not relevant_chunks: + return None + + seen_urls: set[str] = set() + references: List[DocumentReference] = [] + + for rank, chunk in enumerate(relevant_chunks, start=1): + # Extract document_url - try multiple keys for robustness + doc_url = chunk.get("document_url") + if not doc_url: + # Fallback to metadata + meta = chunk.get("meta", {}) + if isinstance(meta, dict): + doc_url = ( + meta.get("document_url") + or meta.get("source_file") + or meta.get("source") + ) + + if doc_url and isinstance(doc_url, str) and doc_url.strip(): + # Only include unique URLs (deduplicate) + if doc_url not in seen_urls: + seen_urls.add(doc_url) + + # Extract score - try multiple keys, ensure it's a float + score_value = chunk.get("fused_score") or chunk.get("score", 0.0) + try: + if isinstance(score_value, (int, float)): + score = float(score_value) + else: + score = 0.0 + except (ValueError, TypeError): + score = 0.0 + + references.append( + DocumentReference( + document_url=doc_url, + chunk_rank=rank, + relevance_score=round(score, 4), + ) + ) + + return references if references else None + @observe(name="generate_rag_response", as_type="generation") def _generate_rag_response( self, @@ -1730,6 +2148,19 @@ def _generate_rag_response( ) if question_out_of_scope: logger.info("Question determined out-of-scope – sending fixed message.") + + # Extract document references even for out-of-scope + doc_references = self._extract_document_references(relevant_chunks) + + # Append references to content + content_with_refs = OUT_OF_SCOPE_MESSAGE + if doc_references: + refs_text = "\n\n**References:**\n" + "\n".join( + f"{i + 1}. {ref.document_url}" + for i, ref in enumerate(doc_references) + ) + content_with_refs += refs_text + if request.environment == TEST_DEPLOYMENT_ENVIRONMENT: logger.info( "Test environment detected – returning out-of-scope message." @@ -1738,7 +2169,7 @@ def _generate_rag_response( llmServiceActive=True, # service OK; insufficient context questionOutOfLLMScope=True, inputGuardFailed=False, - content=OUT_OF_SCOPE_MESSAGE, + content=content_with_refs, chunks=self._format_chunks_for_test_response(relevant_chunks), ) else: @@ -1747,18 +2178,29 @@ def _generate_rag_response( llmServiceActive=True, # service OK; insufficient context questionOutOfLLMScope=True, inputGuardFailed=False, - content=OUT_OF_SCOPE_MESSAGE, + content=content_with_refs, ) # In-scope: return the answer as-is (NO citations) logger.info("Returning in-scope answer without citations.") + + # Extract document references and append to content + doc_references = self._extract_document_references(relevant_chunks) + content_with_refs = answer + if doc_references: + refs_text = "\n\n**References:**\n" + "\n".join( + f"{i + 1}. 
{ref.document_url}" + for i, ref in enumerate(doc_references) + ) + content_with_refs += refs_text + if request.environment == TEST_DEPLOYMENT_ENVIRONMENT: logger.info("Test environment detected – returning generated answer.") return TestOrchestrationResponse( llmServiceActive=True, questionOutOfLLMScope=False, inputGuardFailed=False, - content=answer, + content=content_with_refs, chunks=self._format_chunks_for_test_response(relevant_chunks), ) else: @@ -1767,7 +2209,7 @@ def _generate_rag_response( llmServiceActive=True, questionOutOfLLMScope=False, inputGuardFailed=False, - content=answer, + content=content_with_refs, ) except Exception as e: @@ -1880,7 +2322,7 @@ def generate_context_for_chunks( raise def get_available_embedding_models_for_indexer( - self, environment: str = "production" + self, environment: str = PRODUCTION_DEPLOYMENT_ENVIRONMENT ) -> Dict[str, Any]: """Get available embedding models for vector indexer. diff --git a/src/llm_orchestrator_config/llm_ochestrator_constants.py b/src/llm_orchestrator_config/llm_ochestrator_constants.py index b534229..b53b3d7 100644 --- a/src/llm_orchestrator_config/llm_ochestrator_constants.py +++ b/src/llm_orchestrator_config/llm_ochestrator_constants.py @@ -4,7 +4,7 @@ ) TECHNICAL_ISSUE_MESSAGE = ( - "2. Technical issue with response generation\n" + "Technical issue with response generation\n" "I apologize, but I’m currently unable to generate a response due to a temporary technical issue. " "Please try again in a moment." ) @@ -25,6 +25,7 @@ # Streaming configuration STREAMING_ALLOWED_ENVS = {"production"} TEST_DEPLOYMENT_ENVIRONMENT = "testing" +PRODUCTION_DEPLOYMENT_ENVIRONMENT = "production" # Stream limit error messages STREAM_TIMEOUT_MESSAGE = ( @@ -86,3 +87,8 @@ VALIDATION_REQUIRED_FIELDS_MISSING = "Required information is missing from your request. Please ensure all required fields are provided." VALIDATION_GENERIC_ERROR = "I apologize, but I couldn't process your request. Please check your input and try again." 
+ +# Service endpoints +RAG_SEARCH_RESQL = "http://resql:8082/rag-search" +RAG_SEARCH_RUUTER_PUBLIC = "http://ruuter-public:8086/rag-search" +RAG_SEARCH_RUUTER_PRIVATE = "http://ruuter-private:8088/rag-search" diff --git a/src/models/request_models.py b/src/models/request_models.py index 2239425..f4a073c 100644 --- a/src/models/request_models.py +++ b/src/models/request_models.py @@ -129,6 +129,14 @@ def validate_payload_size(self) -> "OrchestrationRequest": return self +class DocumentReference(BaseModel): + """Model for document reference with URL.""" + + document_url: str = Field(..., description="Source document URL") + chunk_rank: int = Field(..., description="Rank of chunk in retrieval (1-based)") + relevance_score: float = Field(..., description="Relevance score (0-1)") + + class OrchestrationResponse(BaseModel): """Model for LLM orchestration response.""" diff --git a/src/utils/budget_tracker.py b/src/utils/budget_tracker.py new file mode 100644 index 0000000..134b034 --- /dev/null +++ b/src/utils/budget_tracker.py @@ -0,0 +1,223 @@ +"""Budget tracking utility for LLM connection usage.""" + +from typing import Optional, Dict, Any, cast, List +from loguru import logger +import requests + +from ..llm_orchestrator_config.llm_ochestrator_constants import RAG_SEARCH_RESQL +from .connection_id_fetcher import get_connection_id_fetcher + + +class BudgetTracker: + """Handles budget updates for LLM connections.""" + + def __init__(self): + """Initialize the budget tracker with Resql and Ruuter endpoints.""" + # Use Resql directly for budget updates + self.resql_base = RAG_SEARCH_RESQL + self.update_endpoint = f"{self.resql_base}/update-llm-connection-used-budget" + + self.timeout = 5 # seconds + + # Use centralized connection ID fetcher + self.connection_fetcher = get_connection_id_fetcher() + + def _validate_connection_id(self, connection_id: Optional[str]) -> Optional[int]: + """ + Validate and convert connection_id to integer. + + Args: + connection_id: The connection ID to validate + + Returns: + Integer connection ID, or None if invalid + """ + if not connection_id: + logger.debug("No connection_id provided, skipping budget update") + return None + + try: + return int(connection_id) + except (ValueError, TypeError): + logger.warning( + f"Connection ID '{connection_id}' is not numeric. " + f"Budget tracking requires numeric database IDs. " + f"Skipping budget update for this request." + ) + return None + + def _make_budget_update_request( + self, connection_id_int: int, usage_cost: float + ) -> Dict[str, Any]: + """ + Make the actual budget update API request. 
+ + Args: + connection_id_int: The integer connection ID + usage_cost: The cost to add + + Returns: + Dictionary containing the response or error + """ + payload = {"connection_id": connection_id_int, "usage": usage_cost} + logger.info( + f"Updating budget for connection_id={connection_id_int}, usage={usage_cost}" + ) + + response = requests.post( + self.update_endpoint, json=payload, timeout=self.timeout + ) + + if response.status_code == 200: + response_data: Any = response.json() + + # Resql returns a list, so get the first item + data: Any + if isinstance(response_data, list): + typed_list = cast(List[Any], response_data) + if len(typed_list) > 0: + data = typed_list[0] + else: + data = {} # Empty dict if list is empty + else: + data = response_data + + logger.info( + f"Budget updated successfully for connection_id={connection_id_int}" + ) + + # Check if budget was exceeded + budget_exceeded: bool = False + if isinstance(data, dict): + budget_exceeded_value = cast(Dict[str, Any], data).get( + "budgetExceeded", False + ) + budget_exceeded = bool(budget_exceeded_value) + + if budget_exceeded: + logger.warning( + f"Budget threshold exceeded for connection_id={connection_id_int}. " + f"Connection may have been deactivated." + ) + + return { + "success": True, + "data": data, + "budget_exceeded": budget_exceeded, + } + else: + logger.error( + f"Failed to update budget for connection_id={connection_id_int}. " + f"Status: {response.status_code}, Response: {response.text}" + ) + return { + "success": False, + "reason": "api_error", + "status_code": response.status_code, + "error_message": response.text, + } + + def update_budget( + self, connection_id: Optional[str], usage_cost: float + ) -> Dict[str, Any]: + """ + Update the used budget for an LLM connection. 
+ + Args: + connection_id: The LLM connection ID (can be numeric ID or string identifier) + usage_cost: The cost to add to the used budget + + Returns: + Dictionary containing the response from the update endpoint + or an error indicator if the update failed + """ + # If no connection ID provided, try to fetch production connection ID + if not connection_id: + logger.debug( + "No connection_id provided, attempting to fetch production connection ID" + ) + try: + fetched_id = self.connection_fetcher.fetch_connection_id_sync( + "production" + ) + if fetched_id is not None: + connection_id = str(fetched_id) + logger.debug( + f"Using fetched production connection_id: {connection_id}" + ) + except Exception as e: + logger.warning(f"Failed to fetch production connection ID: {str(e)}") + + # Validate connection_id + connection_id_int = self._validate_connection_id(connection_id) + if connection_id_int is None: + return { + "success": False, + "reason": "no_connection_id" + if not connection_id + else "non_numeric_connection_id", + "connection_id": connection_id, + } + + # Skip if usage cost is 0 or negative + if usage_cost <= 0: + logger.debug(f"Usage cost is {usage_cost}, skipping budget update") + return {"success": False, "reason": "zero_or_negative_cost"} + + try: + return self._make_budget_update_request(connection_id_int, usage_cost) + + except requests.exceptions.Timeout: + logger.error( + f"Timeout while updating budget for connection_id={connection_id}" + ) + return {"success": False, "reason": "timeout"} + + except requests.exceptions.RequestException as e: + logger.error( + f"Request error while updating budget for connection_id={connection_id}: {str(e)}" + ) + return {"success": False, "reason": "request_error", "error": str(e)} + + except Exception as e: + logger.error( + f"Unexpected error while updating budget for connection_id={connection_id}: {str(e)}" + ) + return {"success": False, "reason": "unexpected_error", "error": str(e)} + + def update_budget_from_costs( + self, connection_id: Optional[str], costs_dict: Dict[str, Dict[str, Any]] + ) -> Dict[str, Any]: + """ + Update budget from a costs dictionary containing component costs. + + Args: + connection_id: The LLM connection ID (optional) + costs_dict: Dictionary of component costs with total_cost values + + Returns: + Dictionary containing the response from the update endpoint + """ + # Calculate total cost from all components + total_cost = 0.0 + for component_costs in costs_dict.values(): + total_cost += component_costs.get("total_cost", 0.0) + + logger.debug( + f"Total cost calculated from components: ${total_cost:.6f} " + f"(components: {list(costs_dict.keys())})" + ) + + return self.update_budget(connection_id, total_cost) + + +# Singleton instance +_budget_tracker_instance: Optional[BudgetTracker] = None + + +def get_budget_tracker() -> BudgetTracker: + """Get or create the singleton budget tracker instance.""" + global _budget_tracker_instance + if _budget_tracker_instance is None: + _budget_tracker_instance = BudgetTracker() + return _budget_tracker_instance diff --git a/src/utils/connection_id_fetcher.py b/src/utils/connection_id_fetcher.py new file mode 100644 index 0000000..903ad0b --- /dev/null +++ b/src/utils/connection_id_fetcher.py @@ -0,0 +1,235 @@ +""" +Connection ID utility for fetching LLM connection IDs by environment. + +This module provides functionality to fetch LLM connection IDs for different +environments (production, testing) that can be reused across services. 
+""" + +import asyncio +import threading +from typing import Optional, Dict, Any +from loguru import logger +import requests +import aiohttp + +from src.llm_orchestrator_config.llm_ochestrator_constants import RAG_SEARCH_RESQL + + +class ConnectionIdFetcher: + """ + Service for fetching LLM connection IDs by environment. + + This is a reusable utility that can be used by both budget tracker + and production store services. + """ + + def __init__(self): + """Initialize the connection ID fetcher with endpoints.""" + # Use Resql directly for consistent performance + self.resql_base = RAG_SEARCH_RESQL + self.timeout = 5 # seconds + + # Cache connection IDs to avoid repeated requests + self._connection_cache: Dict[str, int] = {} + # Thread-safe lock for cache access + self._cache_lock = threading.Lock() + + def _extract_connection_id_from_response(self, data: Any) -> Optional[int]: + """ + Extract connection ID from API response data. + + Args: + data: The JSON response data + + Returns: + The connection ID as integer, or None if not found + """ + # Handle different response formats + if isinstance(data, dict): + # Check if it's wrapped in response key + response_data: Any = data.get("response", data) + else: + response_data = data + + connection_id: Any = None + if isinstance(response_data, list): + # Array format: [{"id": 1, ...}] + if len(response_data) > 0 and isinstance(response_data[0], dict): + connection_id = response_data[0].get("id") + elif isinstance(response_data, dict): + # Object format: {"id": 1, ...} + connection_id = response_data.get("id") + + if connection_id is not None: + try: + return int(connection_id) + except (ValueError, TypeError): + logger.warning(f"Invalid connection ID format: {connection_id}") + return None + + return None + + def fetch_connection_id_sync(self, environment: str) -> Optional[int]: + """ + Synchronously fetch the LLM connection ID for specified environment. + + Args: + environment: The deployment environment ("production" or "testing") + + Returns: + The connection ID (integer) or None if unavailable + """ + # Return cached value if available + cache_key = f"{environment}_connection_id" + + # Thread-safe cache check + with self._cache_lock: + if cache_key in self._connection_cache: + cached_value = self._connection_cache[cache_key] + logger.debug( + f"Using cached connection_id for {environment}: {cached_value}" + ) + return cached_value + + try: + logger.debug(f"Fetching {environment} connection ID from Resql (sync)...") + + # Use Resql endpoint for getting connection by environment + endpoint = f"{self.resql_base}/get-{environment}-connection" + + response = requests.post(endpoint, json={}, timeout=self.timeout) + + if response.status_code == 200: + data = response.json() + connection_id = self._extract_connection_id_from_response(data) + + if connection_id is not None: + # Cache the connection ID (thread-safe) + with self._cache_lock: + self._connection_cache[cache_key] = connection_id + logger.info( + f"{environment.capitalize()} connection_id fetched: {connection_id}" + ) + return connection_id + else: + logger.warning(f"No {environment} connection ID found in response") + return None + else: + logger.error( + f"Failed to fetch {environment} connection. 
" + f"Status: {response.status_code}, Response: {response.text}" + ) + return None + + except requests.exceptions.Timeout: + logger.error(f"Timeout while fetching {environment} connection ID") + return None + + except Exception as e: + logger.error(f"Error fetching {environment} connection ID: {str(e)}") + return None + + async def fetch_connection_id_async(self, environment: str) -> Optional[int]: + """ + Asynchronously fetch the LLM connection ID for specified environment. + + Args: + environment: The deployment environment ("production" or "testing") + + Returns: + The connection ID (integer) or None if unavailable + """ + # Return cached value if available + cache_key = f"{environment}_connection_id" + + # Thread-safe cache check + with self._cache_lock: + if cache_key in self._connection_cache: + cached_value = self._connection_cache[cache_key] + logger.debug( + f"Using cached connection_id for {environment}: {cached_value}" + ) + return cached_value + + try: + logger.debug(f"Fetching {environment} connection ID from Resql (async)...") + + # Use Resql endpoint for getting connection by environment + endpoint = f"{self.resql_base}/get-{environment}-connection" + + async with aiohttp.ClientSession() as session: + async with session.post( + endpoint, + json={}, + timeout=aiohttp.ClientTimeout(total=self.timeout), + ) as response: + if response.status == 200: + data = await response.json() + connection_id = self._extract_connection_id_from_response(data) + + if connection_id is not None: + # Cache the connection ID (thread-safe) + with self._cache_lock: + self._connection_cache[cache_key] = connection_id + logger.info( + f"{environment.capitalize()} connection_id fetched: {connection_id}" + ) + return connection_id + else: + logger.warning( + f"No {environment} connection ID found in response" + ) + return None + else: + error_text = await response.text() + logger.error( + f"Failed to fetch {environment} connection. " + f"Status: {response.status}, Response: {error_text}" + ) + return None + + except asyncio.TimeoutError: + logger.error(f"Timeout while fetching {environment} connection ID") + return None + except aiohttp.ClientError as e: + logger.error( + f"Client error while fetching {environment} connection ID: {str(e)}" + ) + return None + except Exception as e: + logger.error(f"Error fetching {environment} connection ID: {str(e)}") + return None + + def clear_cache(self, environment: Optional[str] = None): + """ + Clear the connection ID cache. + + Args: + environment: Specific environment to clear, or None to clear all + """ + with self._cache_lock: + if environment: + cache_key = f"{environment}_connection_id" + if cache_key in self._connection_cache: + del self._connection_cache[cache_key] + logger.debug(f"Cleared cache for {environment} connection_id") + else: + self._connection_cache.clear() + logger.debug("Cleared all connection_id cache") + + +# Singleton instance for reuse across modules +_connection_id_fetcher: Optional[ConnectionIdFetcher] = None + + +def get_connection_id_fetcher() -> ConnectionIdFetcher: + """ + Get the singleton connection ID fetcher instance. 
+ + Returns: + ConnectionIdFetcher instance + """ + global _connection_id_fetcher + if _connection_id_fetcher is None: + _connection_id_fetcher = ConnectionIdFetcher() + return _connection_id_fetcher diff --git a/src/utils/production_store.py b/src/utils/production_store.py new file mode 100644 index 0000000..4d15f21 --- /dev/null +++ b/src/utils/production_store.py @@ -0,0 +1,322 @@ +""" +Production Inference Data Storage Utility + +This module provides functionality to store production inference results +to the Ruuter endpoint for analytics and monitoring purposes. +""" + +from typing import Dict, List, Any, Optional +from datetime import datetime +import json +from loguru import logger +import requests +import aiohttp +from src.utils.connection_id_fetcher import get_connection_id_fetcher +from ..llm_orchestrator_config.llm_ochestrator_constants import RAG_SEARCH_RUUTER_PUBLIC + + +class ProductionInferenceStore: + """ + Service for storing production inference results via Ruuter endpoint. + """ + + def __init__(self): + """Initialize the production inference store with Ruuter configuration.""" + self.store_endpoint = f"{RAG_SEARCH_RUUTER_PUBLIC}/inference/results/store" + self.timeout = 10 # seconds + self.connection_fetcher = get_connection_id_fetcher() + + def _create_payload( + self, + chat_id: str, + user_question: str, + refined_questions: List[str], + conversation_history: List[Dict[str, str]], + ranked_chunks: List[Dict[str, Any]], + embedding_scores: List[float], + final_answer: str, + environment: str, + connection_id: Optional[int], + ) -> Dict[str, Any]: + """Create the payload for storing inference results.""" + return { + "chat_id": chat_id, + "user_question": user_question, + "refined_questions": json.dumps(refined_questions), + "conversation_history": json.dumps(conversation_history), + "ranked_chunks": json.dumps(ranked_chunks), + "embedding_scores": json.dumps(embedding_scores), + "final_answer": final_answer, + "environment": environment, + "llm_connection_id": connection_id, + "created_at": datetime.now().isoformat(), + } + + def _handle_response_data( + self, response_data: Any, chat_id: str, environment: str + ) -> Dict[str, Any]: + """Handle and validate response data from the API.""" + # Handle nested response structure from Ruuter: {"response": {"data": {...}}} + if isinstance(response_data, dict) and "response" in response_data: + nested_data = response_data.get("response", {}) + if isinstance(nested_data, dict) and "data" in nested_data: + actual_data = nested_data.get("data") + if actual_data: + logger.info( + f"Successfully stored inference result for chat_id: {chat_id}, environment: {environment}" + ) + return { + "success": True, + "data": actual_data, + "error": None, + } + + # Fallback: handle simple list format for backward compatibility + if isinstance(response_data, list) and len(response_data) > 0: + logger.info( + f"Successfully stored inference result for chat_id: {chat_id}, environment: {environment}" + ) + return { + "success": True, + "data": response_data[0], # Return first item + "error": None, + } + + # Neither format matched - log warning + logger.warning( + f"Failed to store inference result for chat_id: {chat_id}, environment: {environment} - " + f"Empty or invalid response: {response_data}" + ) + return { + "success": False, + "data": None, + "error": "Empty or invalid response from server", + } + + def store_inference_result( + self, + chat_id: str, + user_question: str, + refined_questions: List[str], + conversation_history: 
List[Dict[str, str]], + ranked_chunks: List[Dict[str, Any]], + embedding_scores: List[float], + final_answer: str, + environment: str, + connection_id: Optional[int] = None, + ) -> Dict[str, Any]: + """ + Store production inference result with comprehensive data. + + Args: + chat_id: Chat ID for this conversation + user_question: User's raw question/input + refined_questions: List of refined questions (LLM-generated) + conversation_history: Prior messages array of {role, content} + ranked_chunks: Retrieved chunks ranked with metadata + embedding_scores: Distance scores for each chunk + final_answer: LLM's final generated answer + environment: Deployment environment (production/testing) + connection_id: LLM connection ID (optional, will be fetched if not provided) + + Returns: + Dict containing: + - success (bool): Whether storage was successful + - data (Optional[Dict]): Response data from server + - error (Optional[str]): Error message if failed + """ + try: + # Fetch connection ID if not provided + if connection_id is None: + logger.debug(f"Fetching {environment} connection ID...") + connection_id = self.connection_fetcher.fetch_connection_id_sync( + environment + ) + if connection_id is None: + logger.warning( + f"Could not fetch {environment} connection ID, storing without it" + ) + + # Prepare the request payload + payload = self._create_payload( + chat_id, + user_question, + refined_questions, + conversation_history, + ranked_chunks, + embedding_scores, + final_answer, + environment, + connection_id, + ) + + logger.debug( + f"Storing inference result for chat_id: {chat_id}, environment: {environment}" + ) + + # Make the HTTP POST request to Ruuter endpoint + response = requests.post( + self.store_endpoint, + json=payload, + timeout=self.timeout, + ) + + # Check if the request was successful + if response.status_code == 200: + response_data = response.json() + return self._handle_response_data(response_data, chat_id, environment) + else: + error_msg = ( + f"Failed to store production inference result. " + f"Status: {response.status_code}, Response: {response.text}" + ) + logger.error(error_msg) + return { + "success": False, + "data": None, + "error": error_msg, + } + + except requests.exceptions.Timeout: + error_msg = f"Timeout while storing production inference result for chat_id: {chat_id}" + logger.error(error_msg) + return { + "success": False, + "data": None, + "error": error_msg, + } + except requests.exceptions.RequestException as e: + error_msg = ( + f"Request error while storing production inference result: {str(e)}" + ) + logger.error(error_msg) + return { + "success": False, + "data": None, + "error": error_msg, + } + except Exception as e: + error_msg = ( + f"Unexpected error while storing production inference result: {str(e)}" + ) + logger.error(error_msg) + return { + "success": False, + "data": None, + "error": error_msg, + } + + async def store_inference_result_async( + self, + chat_id: str, + user_question: str, + refined_questions: List[str], + conversation_history: List[Dict[str, str]], + ranked_chunks: List[Dict[str, Any]], + embedding_scores: List[float], + final_answer: str, + environment: str = "production", + connection_id: Optional[int] = None, + ) -> Dict[str, Any]: + """ + Async version of store_inference_result for streaming pipelines. 
+ + Args: + chat_id: Chat ID for this conversation + user_question: User's raw question/input + refined_questions: List of refined questions (LLM-generated) + conversation_history: Prior messages array of {role, content} + ranked_chunks: Retrieved chunks ranked with metadata + embedding_scores: Distance scores for each chunk + final_answer: LLM's final generated answer + environment: Deployment environment (production/testing) + connection_id: LLM connection ID (optional, will be fetched if not provided) + + Returns: + Dict containing: + - success (bool): Whether storage was successful + - data (Optional[Dict]): Response data from server + - error (Optional[str]): Error message if failed + """ + try: + # Fetch connection ID if not provided + if connection_id is None: + logger.debug(f"Fetching {environment} connection ID (async)...") + connection_id = await self.connection_fetcher.fetch_connection_id_async( + environment + ) + if connection_id is None: + logger.warning( + f"Could not fetch {environment} connection ID, storing without it" + ) + + # Prepare the request payload + payload = self._create_payload( + chat_id, + user_question, + refined_questions, + conversation_history, + ranked_chunks, + embedding_scores, + final_answer, + environment, + connection_id, + ) + + logger.debug( + f"Storing inference result (async) for chat_id: {chat_id}, environment: {environment}" + ) + + # Make the async HTTP POST request to Ruuter endpoint + async with aiohttp.ClientSession() as session: + async with session.post( + self.store_endpoint, + json=payload, + timeout=aiohttp.ClientTimeout(total=self.timeout), + ) as response: + # Check if the request was successful + if response.status == 200: + response_data = await response.json() + return self._handle_response_data( + response_data, chat_id, environment + ) + else: + response_text = await response.text() + error_msg = ( + f"Failed to store production inference result (async). " + f"Status: {response.status}, Response: {response_text}" + ) + logger.error(error_msg) + return { + "success": False, + "data": None, + "error": error_msg, + } + + except Exception as e: + error_msg = ( + f"Error while storing production inference result (async): {str(e)}" + ) + logger.error(error_msg) + return { + "success": False, + "data": None, + "error": error_msg, + } + + +# Singleton instance for reuse across the application +_production_store_instance: Optional[ProductionInferenceStore] = None + + +def get_production_store() -> ProductionInferenceStore: + """ + Get or create the singleton ProductionInferenceStore instance. + + Returns: + ProductionInferenceStore: The singleton instance + """ + global _production_store_instance + if _production_store_instance is None: + _production_store_instance = ProductionInferenceStore() + return _production_store_instance
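Usage note (illustrative, not part of the patch): the sketch below exercises the two new singletons from src/utils/budget_tracker.py and src/utils/production_store.py in isolation. It assumes the Resql and Ruuter endpoints hardcoded in llm_ochestrator_constants.py are reachable from the caller; the connection id "7", the chat id, and the chunk/URL values are made-up placeholders.

from src.utils.budget_tracker import get_budget_tracker
from src.utils.production_store import get_production_store

# Report usage cost for a connection. Component keys mirror the orchestration
# costs_dict (e.g. "streaming_generation"); only "total_cost" is read per component.
costs_dict = {"streaming_generation": {"total_cost": 0.0042}}
budget_result = get_budget_tracker().update_budget_from_costs("7", costs_dict)
print(budget_result.get("success"), budget_result.get("budget_exceeded"))

# Persist an inference result. connection_id is optional; when omitted,
# ConnectionIdFetcher looks up (and caches) the ID for the given environment.
store_result = get_production_store().store_inference_result(
    chat_id="chat-123",
    user_question="What are the office opening hours?",
    refined_questions=["office opening hours"],
    conversation_history=[{"role": "user", "content": "Hi"}],
    ranked_chunks=[
        {
            "text": "Open 9-17 on weekdays.",
            "fused_score": 0.82,
            "document_url": "https://example.org/opening-hours",
        }
    ],
    embedding_scores=[0.82],
    final_answer="The office is open 9-17 on weekdays.",
    environment="testing",
)
print(store_result["success"], store_result.get("error"))

Both helpers return plain dicts with a "success" flag instead of raising, which is why the orchestration service can call them from error and disconnect paths without masking the original failure.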
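Reviewer reference (not shipped with the patch): the budget_status CASE expression in DSL/Resql/rag-search/POST/get-testing-connection.sql reduces to the threshold rules below. This Python mirror is a hedged sketch for reasoning about, or unit-testing against, the same thresholds; like the SQL, it assumes monthly_budget is non-zero and reads the warn/stop thresholds as percentages.

from typing import Optional

def budget_status(
    used_budget: Optional[float],
    monthly_budget: float,
    warn_budget_threshold: float,
    stop_budget_threshold: float,
) -> str:
    """Mirror of the budget_status CASE expression in get-testing-connection.sql."""
    if not used_budget:  # NULL or zero spend
        return "within_budget"
    ratio = used_budget / monthly_budget  # assumes monthly_budget is non-zero
    if ratio < warn_budget_threshold / 100.0:
        return "within_budget"
    if stop_budget_threshold != 0 and ratio >= stop_budget_threshold / 100.0:
        return "over_budget"
    if stop_budget_threshold == 0 and ratio >= 1:  # no stop threshold: hard stop at 100%
        return "over_budget"
    if ratio >= warn_budget_threshold / 100.0:
        return "close_to_exceed"
    return "within_budget"

assert budget_status(50, 100, 80, 0) == "within_budget"
assert budget_status(85, 100, 80, 95) == "close_to_exceed"
assert budget_status(96, 100, 80, 95) == "over_budget"

A stop_budget_threshold of 0 is treated as "no early stop": the connection only reports over_budget once spend reaches 100% of the monthly budget, matching the third WHEN branch of the CASE.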