diff --git a/requirements.txt b/requirements.txt index 7821bba..faa6730 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ grpcio>=1.76.0 grpcio-tools>=1.76.0 betterproto2==0.9.1 -#mcp-contextforge-gateway==0.8.0 -git+https://github.com/IBM/mcp-context-forge@v0.9.0 +mcp-contextforge-gateway==0.9.0 nemoguardrails==0.19.0 diff --git a/resources/config/config.yaml b/resources/config/config.yaml index 03ff86d..ef3be7d 100644 --- a/resources/config/config.yaml +++ b/resources/config/config.yaml @@ -2,7 +2,8 @@ plugins: # Self-contained Search Replace Plugin - depends on plugin availability - name: "ReplaceBadWordsPlugin" - kind: "plugins.regex_filter.search_replace.SearchReplacePlugin" + # From https://github.com/contextforge-org/contextforge-plugins-python + kind: "contextforge-plugins-python.regex_filter.search_replace.SearchReplacePlugin" description: "A plugin for finding and replacing words." version: "0.1.0" author: "Teryl Taylor" diff --git a/src/server.py b/src/server.py index 576d225..04d2967 100644 --- a/src/server.py +++ b/src/server.py @@ -6,16 +6,13 @@ import os import grpc +# Third-Party from envoy.service.ext_proc.v3 import external_processor_pb2 as ep from envoy.service.ext_proc.v3 import external_processor_pb2_grpc as ep_grpc from envoy.config.core.v3 import base_pb2 as core from envoy.type.v3 import http_status_pb2 as http_status_pb2 -# plugin manager # First-Party -# from apex.mcp.entities.models import HookType, Message, PromptResult, Role, TextContent, PromptPosthookPayload, PromptPrehookPayload -# import apex.mcp.entities.models as apex -# import mcpgateway.plugins.tools.models as apex from mcpgateway.plugins.framework import ( ToolHookType, PromptPrehookPayload, @@ -24,9 +21,10 @@ ) from mcpgateway.plugins.framework import PluginManager from mcpgateway.plugins.framework.models import GlobalContext -# from apex.framework.manager import PluginManager -# from apex.framework.models import GlobalContext -# from plugins.regex_filter.search_replace import SearchReplaceConfig + +# ============================================================================ +# LOGGING CONFIGURATION +# ============================================================================ log_level = os.environ.get("LOGLEVEL", "INFO").upper() @@ -34,56 +32,28 @@ logger = logging.getLogger("ext-proc-PM") logger.setLevel(log_level) -# handler = logging.StreamHandler() -# handler.setLevel(log_level) - -# # Add the handler to the logger -# logger.addHandler(handler) - - -async def getToolPostInvokeResponse(body): - # FIXME: size of content array is expected to be 1 - # for content in body["result"]["content"]: - - logger.debug("**** Tool Post Invoke ****") - payload = ToolPostInvokePayload(name="replaceme", result=body["result"]) - # TODO: hard-coded ids - logger.debug("**** Tool Post Invoke payload ****") - logger.debug(payload) - global_context = GlobalContext(request_id="1", server_id="2") - result, _ = await manager.invoke_hook( - ToolHookType.TOOL_POST_INVOKE, payload, global_context=global_context - ) - logger.info(result) - if not result.continue_processing: - body_resp = ep.ProcessingResponse( - immediate_response=ep.ImmediateResponse( - # TODO: hard-coded error reason - status=http_status_pb2.HttpStatus(code=http_status_pb2.Forbidden), - details="No go", - ) - ) - else: - result_payload = result.modified_payload - if result_payload is not None: - body["result"] = result_payload.result - else: - body = None - body_resp = ep.ProcessingResponse( - request_body=ep.BodyResponse( - response=ep.CommonResponse( - body_mutation=ep.BodyMutation(body=json.dumps(body).encode("utf-8")) - ) - ) - ) - return body_resp +# ============================================================================ +# HELPER FUNCTIONS +# ============================================================================ def set_result_in_body(body, result_args): + """Set the result arguments in the request body.""" body["params"]["arguments"] = result_args +# ============================================================================ +# MCP HOOK HANDLERS +# ============================================================================ + + async def getToolPreInvokeResponse(body): + """ + Handle tool pre-invoke hook processing. + + Invokes plugins before a tool is called, allowing for argument validation, + modification, or blocking of the tool invocation. + """ logger.debug(body) payload_args = { "tool_name": body["params"]["name"], @@ -93,13 +63,11 @@ async def getToolPreInvokeResponse(body): payload = ToolPreInvokePayload(name=body["params"]["name"], args=payload_args) # TODO: hard-coded ids global_context = GlobalContext(request_id="1", server_id="2") - logger.debug("**** Invoking Tool Pre Invoke with payload ****") - logger.debug(payload) + logger.debug(f"**** Invoking Tool Pre Invoke with payload: {payload} ****") result, _ = await manager.invoke_hook( ToolHookType.TOOL_PRE_INVOKE, payload, global_context=global_context ) - logger.debug("**** Tool Pre Invoke Result ****") - logger.debug(result) + logger.debug(f"**** Tool Pre Invoke Result: {result} ****") if not result.continue_processing: error_body = { "jsonrpc": body["jsonrpc"], @@ -143,18 +111,66 @@ async def getToolPreInvokeResponse(body): ) ) ) - logger.info("****Tool Pre Invoke Return body****") - logger.info(body_resp) + logger.info(f"****Tool Pre Invoke Return body: {body_resp}****") + return body_resp + + +async def getToolPostInvokeResponse(body): + """ + Handle tool post-invoke hook processing. + + Invokes plugins after a tool has been called, allowing for result validation, + modification, or filtering of the tool output. + """ + # FIXME: size of content array is expected to be 1 + # for content in body["result"]["content"]: + + logger.debug("**** Tool Post Invoke ****") + payload = ToolPostInvokePayload(name="replaceme", result=body["result"]) + # TODO: hard-coded ids + logger.debug(f"**** Tool Post Invoke payload: {payload} ****") + global_context = GlobalContext(request_id="1", server_id="2") + result, _ = await manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, payload, global_context=global_context + ) + logger.info(result) + if not result.continue_processing: + body_resp = ep.ProcessingResponse( + immediate_response=ep.ImmediateResponse( + # TODO: hard-coded error reason + status=http_status_pb2.HttpStatus(code=http_status_pb2.Forbidden), + details="No go", + ) + ) + else: + result_payload = result.modified_payload + if result_payload is not None: + body["result"] = result_payload.result + else: + body = None + body_resp = ep.ProcessingResponse( + request_body=ep.BodyResponse( + response=ep.CommonResponse( + body_mutation=ep.BodyMutation(body=json.dumps(body).encode("utf-8")) + ) + ) + ) return body_resp async def getPromptPreFetchResponse(body): + """ + Handle prompt pre-fetch hook processing. + + Invokes plugins before a prompt is fetched, allowing for argument validation, + modification, or blocking of the prompt request. + """ prompt = PromptPrehookPayload( name=body["params"]["name"], args=body["params"]["arguments"] ) # TODO: hard-coded ids global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook( + result, _ = await manager.invoke_hook( ToolHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context ) logger.info(result) @@ -174,22 +190,43 @@ async def getPromptPreFetchResponse(body): ) ) ) - logger.info("****body ") - logger.info(body_resp) + logger.info(f"****Prompt Pre-fetch Return body: {body_resp}") return body_resp +# ============================================================================ +# ENVOY EXTERNAL PROCESSOR SERVICER +# ============================================================================ + + class ExtProcServicer(ep_grpc.ExternalProcessorServicer): + """ + Envoy External Processor implementation for MCP Gateway. + + Processes HTTP requests and responses, intercepting MCP protocol messages + to apply plugin hooks at various stages of the request/response lifecycle. + """ + async def Process( self, request_iterator: AsyncIterator[ep.ProcessingRequest], context ) -> AsyncIterator[ep.ProcessingResponse]: + """ + Main processing loop for handling Envoy external processor requests. + + Processes different types of requests: + - Request headers: Add custom headers to incoming requests + - Response headers: Add custom headers to outgoing responses + - Request body: Process MCP tool/prompt invocations + - Response body: Process MCP tool results + """ req_body_buf = bytearray() resp_body_buf = bytearray() async for request in request_iterator: - # logger.info(request) + # ---------------------------------------------------------------- + # Request Headers Processing + # ---------------------------------------------------------------- if request.HasField("request_headers"): - # Modify request headers _headers = request.request_headers.headers yield ep.ProcessingResponse( request_headers=ep.HeadersResponse( @@ -210,8 +247,10 @@ async def Process( ) ) ) + # ---------------------------------------------------------------- + # Response Headers Processing + # ---------------------------------------------------------------- elif request.HasField("response_headers"): - # Modify response headers _headers = request.response_headers.headers yield ep.ProcessingResponse( response_headers=ep.HeadersResponse( @@ -233,6 +272,9 @@ async def Process( ) ) + # ---------------------------------------------------------------- + # Request Body Processing (MCP Tool/Prompt Invocations) + # ---------------------------------------------------------------- elif request.HasField("request_body") and request.request_body.body: chunk = request.request_body.body req_body_buf.extend(chunk) @@ -247,7 +289,6 @@ async def Process( body = json.loads(text) if "method" in body and body["method"] == "tools/call": body_resp = await getToolPreInvokeResponse(body) - elif "method" in body and body["method"] == "prompts/get": body_resp = await getPromptPreFetchResponse(body) else: @@ -260,7 +301,9 @@ async def Process( req_body_buf.clear() - # ---- Response body chunks ---- + # ---------------------------------------------------------------- + # Response Body Processing (MCP Tool Results) + # ---------------------------------------------------------------- elif request.HasField("response_body") and request.response_body.body: chunk = request.response_body.body resp_body_buf.extend(chunk) @@ -289,14 +332,26 @@ async def Process( yield body_resp resp_body_buf.clear() - # Handle other message types (request_body, response_body, etc.) as needed else: + # Unhandled request types logger.warn("Not processed") +# ============================================================================ +# SERVER INITIALIZATION +# ============================================================================ + + async def serve(host: str = "0.0.0.0", port: int = 50052): + """ + Initialize and start the gRPC external processor server. + + Args: + host: Host address to bind to (default: 0.0.0.0) + port: Port number to listen on (default: 50052) + """ await manager.initialize() - logger.info(manager.config) + logger.info(f"Manager config: {manager.config}") logger.debug(f"Loaded {manager.plugin_count} plugins") server = grpc.aio.server() @@ -310,6 +365,10 @@ async def serve(host: str = "0.0.0.0", port: int = 50052): await server.wait_for_termination() +# ============================================================================ +# MAIN ENTRY POINT +# ============================================================================ + if __name__ == "__main__": try: logging.getLogger("mcpgateway.config").setLevel(logging.DEBUG)