Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
grpcio>=1.76.0
grpcio-tools>=1.76.0
betterproto2==0.9.1
#mcp-contextforge-gateway==0.8.0
git+https://github.com/IBM/mcp-context-forge@v0.9.0
mcp-contextforge-gateway==0.9.0
nemoguardrails==0.19.0
3 changes: 2 additions & 1 deletion resources/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
plugins:
# Self-contained Search Replace Plugin - depends on plugin availability
- name: "ReplaceBadWordsPlugin"
kind: "plugins.regex_filter.search_replace.SearchReplacePlugin"
# From https://github.com/contextforge-org/contextforge-plugins-python
kind: "contextforge-plugins-python.regex_filter.search_replace.SearchReplacePlugin"
description: "A plugin for finding and replacing words."
version: "0.1.0"
author: "Teryl Taylor"
Expand Down
191 changes: 125 additions & 66 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
import os
import grpc

# Third-Party
from envoy.service.ext_proc.v3 import external_processor_pb2 as ep
from envoy.service.ext_proc.v3 import external_processor_pb2_grpc as ep_grpc
from envoy.config.core.v3 import base_pb2 as core
from envoy.type.v3 import http_status_pb2 as http_status_pb2

# plugin manager
# First-Party
# from apex.mcp.entities.models import HookType, Message, PromptResult, Role, TextContent, PromptPosthookPayload, PromptPrehookPayload
# import apex.mcp.entities.models as apex
# import mcpgateway.plugins.tools.models as apex
from mcpgateway.plugins.framework import (
ToolHookType,
PromptPrehookPayload,
Expand All @@ -24,66 +21,39 @@
)
from mcpgateway.plugins.framework import PluginManager
from mcpgateway.plugins.framework.models import GlobalContext
# from apex.framework.manager import PluginManager
# from apex.framework.models import GlobalContext
# from plugins.regex_filter.search_replace import SearchReplaceConfig

# ============================================================================
# LOGGING CONFIGURATION
# ============================================================================

log_level = os.environ.get("LOGLEVEL", "INFO").upper()

logging.basicConfig(level=log_level)
logger = logging.getLogger("ext-proc-PM")
logger.setLevel(log_level)

# handler = logging.StreamHandler()
# handler.setLevel(log_level)

# # Add the handler to the logger
# logger.addHandler(handler)


async def getToolPostInvokeResponse(body):
# FIXME: size of content array is expected to be 1
# for content in body["result"]["content"]:

logger.debug("**** Tool Post Invoke ****")
payload = ToolPostInvokePayload(name="replaceme", result=body["result"])
# TODO: hard-coded ids
logger.debug("**** Tool Post Invoke payload ****")
logger.debug(payload)
global_context = GlobalContext(request_id="1", server_id="2")
result, _ = await manager.invoke_hook(
ToolHookType.TOOL_POST_INVOKE, payload, global_context=global_context
)
logger.info(result)
if not result.continue_processing:
body_resp = ep.ProcessingResponse(
immediate_response=ep.ImmediateResponse(
# TODO: hard-coded error reason
status=http_status_pb2.HttpStatus(code=http_status_pb2.Forbidden),
details="No go",
)
)
else:
result_payload = result.modified_payload
if result_payload is not None:
body["result"] = result_payload.result
else:
body = None
body_resp = ep.ProcessingResponse(
request_body=ep.BodyResponse(
response=ep.CommonResponse(
body_mutation=ep.BodyMutation(body=json.dumps(body).encode("utf-8"))
)
)
)
return body_resp
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================


def set_result_in_body(body, result_args):
"""Set the result arguments in the request body."""
body["params"]["arguments"] = result_args


# ============================================================================
# MCP HOOK HANDLERS
# ============================================================================


async def getToolPreInvokeResponse(body):
"""
Handle tool pre-invoke hook processing.

Invokes plugins before a tool is called, allowing for argument validation,
modification, or blocking of the tool invocation.
"""
logger.debug(body)
payload_args = {
"tool_name": body["params"]["name"],
Expand All @@ -93,13 +63,11 @@ async def getToolPreInvokeResponse(body):
payload = ToolPreInvokePayload(name=body["params"]["name"], args=payload_args)
# TODO: hard-coded ids
global_context = GlobalContext(request_id="1", server_id="2")
logger.debug("**** Invoking Tool Pre Invoke with payload ****")
logger.debug(payload)
logger.debug(f"**** Invoking Tool Pre Invoke with payload: {payload} ****")
result, _ = await manager.invoke_hook(
ToolHookType.TOOL_PRE_INVOKE, payload, global_context=global_context
)
logger.debug("**** Tool Pre Invoke Result ****")
logger.debug(result)
logger.debug(f"**** Tool Pre Invoke Result: {result} ****")
if not result.continue_processing:
error_body = {
"jsonrpc": body["jsonrpc"],
Expand Down Expand Up @@ -143,18 +111,66 @@ async def getToolPreInvokeResponse(body):
)
)
)
logger.info("****Tool Pre Invoke Return body****")
logger.info(body_resp)
logger.info(f"****Tool Pre Invoke Return body: {body_resp}****")
return body_resp


async def getToolPostInvokeResponse(body):
"""
Handle tool post-invoke hook processing.

Invokes plugins after a tool has been called, allowing for result validation,
modification, or filtering of the tool output.
"""
# FIXME: size of content array is expected to be 1
# for content in body["result"]["content"]:

logger.debug("**** Tool Post Invoke ****")
payload = ToolPostInvokePayload(name="replaceme", result=body["result"])
# TODO: hard-coded ids
logger.debug(f"**** Tool Post Invoke payload: {payload} ****")
global_context = GlobalContext(request_id="1", server_id="2")
result, _ = await manager.invoke_hook(
ToolHookType.TOOL_POST_INVOKE, payload, global_context=global_context
)
logger.info(result)
if not result.continue_processing:
body_resp = ep.ProcessingResponse(
immediate_response=ep.ImmediateResponse(
# TODO: hard-coded error reason
status=http_status_pb2.HttpStatus(code=http_status_pb2.Forbidden),
details="No go",
)
)
else:
result_payload = result.modified_payload
if result_payload is not None:
body["result"] = result_payload.result
else:
body = None
body_resp = ep.ProcessingResponse(
request_body=ep.BodyResponse(
response=ep.CommonResponse(
body_mutation=ep.BodyMutation(body=json.dumps(body).encode("utf-8"))
)
)
)
return body_resp


async def getPromptPreFetchResponse(body):
"""
Handle prompt pre-fetch hook processing.

Invokes plugins before a prompt is fetched, allowing for argument validation,
modification, or blocking of the prompt request.
"""
prompt = PromptPrehookPayload(
name=body["params"]["name"], args=body["params"]["arguments"]
)
# TODO: hard-coded ids
global_context = GlobalContext(request_id="1", server_id="2")
result, contexts = await manager.invoke_hook(
result, _ = await manager.invoke_hook(
ToolHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context
)
logger.info(result)
Expand All @@ -174,22 +190,43 @@ async def getPromptPreFetchResponse(body):
)
)
)
logger.info("****body ")
logger.info(body_resp)
logger.info(f"****Prompt Pre-fetch Return body: {body_resp}")
return body_resp


# ============================================================================
# ENVOY EXTERNAL PROCESSOR SERVICER
# ============================================================================


class ExtProcServicer(ep_grpc.ExternalProcessorServicer):
"""
Envoy External Processor implementation for MCP Gateway.

Processes HTTP requests and responses, intercepting MCP protocol messages
to apply plugin hooks at various stages of the request/response lifecycle.
"""

async def Process(
self, request_iterator: AsyncIterator[ep.ProcessingRequest], context
) -> AsyncIterator[ep.ProcessingResponse]:
"""
Main processing loop for handling Envoy external processor requests.

Processes different types of requests:
- Request headers: Add custom headers to incoming requests
- Response headers: Add custom headers to outgoing responses
- Request body: Process MCP tool/prompt invocations
- Response body: Process MCP tool results
"""
req_body_buf = bytearray()
resp_body_buf = bytearray()

async for request in request_iterator:
# logger.info(request)
# ----------------------------------------------------------------
# Request Headers Processing
# ----------------------------------------------------------------
if request.HasField("request_headers"):
# Modify request headers
_headers = request.request_headers.headers
yield ep.ProcessingResponse(
request_headers=ep.HeadersResponse(
Expand All @@ -210,8 +247,10 @@ async def Process(
)
)
)
# ----------------------------------------------------------------
# Response Headers Processing
# ----------------------------------------------------------------
elif request.HasField("response_headers"):
# Modify response headers
_headers = request.response_headers.headers
yield ep.ProcessingResponse(
response_headers=ep.HeadersResponse(
Expand All @@ -233,6 +272,9 @@ async def Process(
)
)

# ----------------------------------------------------------------
# Request Body Processing (MCP Tool/Prompt Invocations)
# ----------------------------------------------------------------
elif request.HasField("request_body") and request.request_body.body:
chunk = request.request_body.body
req_body_buf.extend(chunk)
Expand All @@ -247,7 +289,6 @@ async def Process(
body = json.loads(text)
if "method" in body and body["method"] == "tools/call":
body_resp = await getToolPreInvokeResponse(body)

elif "method" in body and body["method"] == "prompts/get":
body_resp = await getPromptPreFetchResponse(body)
else:
Expand All @@ -260,7 +301,9 @@ async def Process(

req_body_buf.clear()

# ---- Response body chunks ----
# ----------------------------------------------------------------
# Response Body Processing (MCP Tool Results)
# ----------------------------------------------------------------
elif request.HasField("response_body") and request.response_body.body:
chunk = request.response_body.body
resp_body_buf.extend(chunk)
Expand Down Expand Up @@ -289,14 +332,26 @@ async def Process(
yield body_resp
resp_body_buf.clear()

# Handle other message types (request_body, response_body, etc.) as needed
else:
# Unhandled request types
logger.warn("Not processed")


# ============================================================================
# SERVER INITIALIZATION
# ============================================================================


async def serve(host: str = "0.0.0.0", port: int = 50052):
"""
Initialize and start the gRPC external processor server.

Args:
host: Host address to bind to (default: 0.0.0.0)
port: Port number to listen on (default: 50052)
"""
await manager.initialize()
logger.info(manager.config)
logger.info(f"Manager config: {manager.config}")
logger.debug(f"Loaded {manager.plugin_count} plugins")

server = grpc.aio.server()
Expand All @@ -310,6 +365,10 @@ async def serve(host: str = "0.0.0.0", port: int = 50052):
await server.wait_for_termination()


# ============================================================================
# MAIN ENTRY POINT
# ============================================================================

if __name__ == "__main__":
try:
logging.getLogger("mcpgateway.config").setLevel(logging.DEBUG)
Expand Down