diff --git a/.gitignore b/.gitignore index c2b00cf..85dd180 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,4 @@ -# reccmp-user.yml -# reccmp-build.yml - - -# OpenSHC project specific: +### OpenSHC build-*/ build/ _build/ @@ -14,6 +10,10 @@ dist/ # Softlink to game installation _original +# Allow dependency libs +!dependencies/**/*.lib + +### C++ (https://github.com/github/gitignore/blob/main/C%2B%2B.gitignore) # Prerequisites *.d @@ -50,5 +50,167 @@ _original # Visual Studio Configuration .vs -# Allow dependency libs -!dependencies/**/*.lib +### Python (https://github.com/github/gitignore/blob/main/Python.gitignore) + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/tools/mcp/README.md b/tools/mcp/README.md new file mode 100644 index 0000000..63ff060 --- /dev/null +++ b/tools/mcp/README.md @@ -0,0 +1,11 @@ +# MCP server + +AI can be of help with increasing accuracy. + +## Setup + +In Ghidra, add the `ghidra_scripts` directory as a directory in the Script Manager window. Then launch `_OpenSHC/TOOLS/ghidra_server.py` from the Script Manager to start the server to expose Ghidra functionality to an MCP client such as Claude. + +## Claude specific setup + +Copy the file `claude/claude_desktop_config.json` to your Claude directory (`%APPDATA%\Claude`) and set the empty `--directory` argument to the root of this repository on your file system: the server is started as `./tools/mcp/decomphelper.py` relative to that directory and expects `cmake/openshc-sources.txt` to exist there. Restart Claude completely. 
#!/usr/bin/env python3
"""
MCP Server for C++ decompilation and assembly comparison tasks.

This server provides tools for:
- Compiling C++ files with MSVC
- Comparing assembly output with target assembly
- Generating assembly diffs

All paths used below are relative, so the server must be started with the
repository root as its working directory.

Usage:
    Claude:
        add a mcpServers entry to the Claude Desktop App config file
    Testing:
        python decomphelper.py
"""

import asyncio
import json
import os
import subprocess
import tempfile
from pathlib import Path
from typing import Any, Sequence
import difflib
import re
import sys
import requests
#sys.stdout.reconfigure(line_buffering=True)
#sys.stderr.reconfigure(line_buffering=True)
import logging
# The server is run with transport="stdio" (see the bottom of this file);
# logging is sent to stderr, presumably so it does not mix with that stream.
logging.basicConfig(stream=sys.stderr)

from mcp.server.fastmcp import FastMCP
from mcp.types import (
    Resource,
    Tool,
    TextContent,
    ImageContent,
    EmbeddedResource,
    LoggingLevel
)
import mcp.server.stdio

# List of source files that are part of the build; one 'src/...' entry per line.
PATH_CMAKE_OPENSHC_SOURCES = Path("cmake/openshc-sources.txt")
if not PATH_CMAKE_OPENSHC_SOURCES.exists():
    raise Exception(f"could not find cmake core sources txt file: {str(PATH_CMAKE_OPENSHC_SOURCES)}")

# Characters that must never appear in a single '::'-separated component of a
# function name, because the component is used verbatim as a path segment.
# '.' is listed first so its error message stays the same as before.
_ILLEGAL_PATH_CHARS = "./\\:"

# Initialize MCP server
mcp = FastMCP("decomp-helper")


def compile_project() -> tuple[bool, str, str]:
    """
    Compile the C++ project using MSVC. Must be executed after writing new cpp file contents.

    Runs ``build.bat RelWithDebInfo OpenSHC.dll`` in the current working
    directory with stdin detached.

    Returns:
        Tuple of (success, stdout, stderr). ``success`` is True only when the
        build script exits with code 0. If the process cannot be started at
        all, stdout is empty and stderr holds the exception text.
    """
    cmd = ["build.bat", "RelWithDebInfo", "OpenSHC.dll"]

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd=".",
            stdin=subprocess.DEVNULL,
        )
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        return False, "", str(e)
    return result.returncode == 0, result.stdout, result.stderr


@mcp.tool()
def extract_function_assembly_diff(function_name: str) -> tuple[bool, Any, str, str]:
    """
    Extract assembly diff for a specific function, comparing the original binary to the reimplementation source code. Should be called after writing and compiling a cpp file, see 'compile_cpp_code_for_function'.

    Args:
        function_name: Name of the function to extract, fully namespaced using '::'

    Returns:
        Tuple of (success, diff, stdout, stderr). On success ``diff`` is the
        first entry of the report's 'data' list whose 'name' equals
        ``function_name``; on failure it is an empty string and ``stderr``
        describes the problem.
    """
    cmd = [str(Path("reccmp") / "run.bat"), "reccmp-reccmp", "--target", "STRONGHOLDCRUSADER", "--json", "diff.json"]

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd=".",
            stdin=subprocess.DEVNULL,
        )
    except Exception as e:
        return False, "", "", f"could not execute reccmp/run: {str(e)}"
    if result.returncode != 0:
        # Same message the caller received before, built without the
        # raise-and-immediately-catch round trip.
        return (
            False,
            "",
            "",
            "could not execute reccmp/run: "
            f"could not create diff: {result.stderr}, command: {' '.join(cmd)}",
        )

    try:
        diff = json.loads(Path("reccmp/diff.json").read_text())
    except Exception as e:
        return False, "", "", f"could not load reccmp/diff.json: {str(e)}"

    try:
        match = next(
            (entry for entry in diff["data"] if entry["name"] == function_name),
            None,
        )
    except (KeyError, TypeError) as e:
        # A report without 'data' / 'name' used to escape as an uncaught error.
        return False, "", "", f"unexpected structure in reccmp/diff.json: {e!r}"
    if match is None:
        return False, "", "", f"no function with name '{function_name}' in diff.json"
    return True, match, "", ""


def function_name_to_cpp_path(function_name: str, base_path = Path("src")) -> tuple[bool, str, str]:
    """
    Translate a '::'-separated function name into the path of its cpp file.

    Every component becomes one path segment below ``base_path`` and ``.cpp``
    is appended to the last one, e.g. ``A::B::f`` -> ``src/A/B/f.cpp``.

    Args:
        function_name: Name of the function, fully namespaced using '::'
        base_path: Directory the resulting path is rooted at

    Returns:
        Tuple of (success, path, error). A component containing '.', a path
        separator or ':' is rejected, because joining it onto ``base_path``
        could climb out of it or replace it altogether.
    """
    path = base_path
    for part in function_name.split("::"):
        for char in _ILLEGAL_PATH_CHARS:
            if char in part:
                return False, "", f"illegal character in cpp file path: {char}"
        path = path / part
    return True, f"{str(path)}.cpp", ""


@mcp.tool()
def compile_cpp_code_for_function(function_name: str, contents: str) -> tuple[bool, str, str]:
    """
    Write and compile cpp code for function identified by fully namespaced function name.

    The cpp file must already exist; this tool only overwrites existing files.
    If the file is not yet listed in cmake/openshc-sources.txt it is appended
    there before the build is started.

    Args:
        function_name: Name of the function to extract, fully namespaced using '::'
        contents: New contents of the file

    Returns:
        Tuple of (success, stdout, stderr)
    """
    # Translate the function name into a path
    rstate, rresult, rerr = function_name_to_cpp_path(function_name=function_name)
    if not rstate:
        return rstate, "", f"could not resolve function name to file path: {rerr}"
    path = Path(rresult)

    # Validate the sources-list entry BEFORE touching the disk, so nothing is
    # ever written to a path that is not under src/.
    csentry = str(path).replace("\\", "/")
    if not csentry.startswith("src/"):
        return False, "", f"invalid cmake/openshc-sources.txt entry: {csentry}"
    if not path.exists():
        return False, "", f"cpp file path does not exist: {str(path)}"
    path.write_text(contents)

    # Ensure the cpp file is included in the build
    lines = PATH_CMAKE_OPENSHC_SOURCES.read_text().splitlines(False)
    if csentry not in lines:
        lines.append(csentry)
        PATH_CMAKE_OPENSHC_SOURCES.write_text('\n'.join(lines) + '\n', newline='\n')

    # Compile the project and return the resulting state
    return compile_project()


@mcp.tool()
def read_cpp_code_for_function(function_name: str) -> tuple[bool, str, str]:
    """
    Read the cpp code for a function

    Args:
        function_name: Name of the function to extract, fully namespaced using '::'

    Returns:
        Tuple of (success, contents, stderr)
    """
    rstate, rresult, rerr = function_name_to_cpp_path(function_name=function_name)
    if not rstate:
        return rstate, "", f"could not resolve function name to file path: {rerr}"
    path = Path(rresult)
    if not path.exists():
        return False, "", f"cpp file path does not exist: {str(path)}"
    try:
        return True, path.read_text(), ""
    except Exception as e:
        return False, "", f"{e}"


@mcp.tool()
def read_source_file(relative_path: str) -> tuple[bool, str, str]:
    """
    Read the C++ file contents of a file

    Args:
        relative_path: Name of the file, usually starts with 'OpenSHC/'

    Returns:
        Tuple of (success, contents, stderr)
    """
    root = Path("src").resolve()
    path = (root / relative_path).resolve()
    # Compare path components, not string prefixes: a prefix test would also
    # accept sibling directories whose name merely starts with 'src'.
    if not path.is_relative_to(root):
        return False, "", "Cannot escape src/ directory"
    if not path.exists():
        return False, "", f" cpp file path does not exist: {str(path)}"
    try:
        return True, path.read_text(), ""
    except Exception as e:
        return False, "", f"{e}"


@mcp.tool()
def fetch_ghidra_function_decompilation(function_name: str) -> tuple[bool, str, str]:
    """
    Fetches decompilation of a function (json with additional information). Contains ghidra special functions.

    Sends a GET request to http://127.0.0.1:11337/functions/decompile with
    the function name as the 'name' query parameter and a 5 second timeout.

    Args:
        function_name: Name of the function to extract, fully namespaced using '::'

    Returns:
        Tuple of (success, contents, stderr). ``contents`` is the JSON
        response re-serialized as a string; any request, HTTP status or
        decoding error is reported through ``stderr``.
    """
    try:
        resp = requests.get("http://127.0.0.1:11337/functions/decompile", params={
            "name": function_name,
        }, timeout=5)
        resp.raise_for_status()
        contents = resp.json()
        return True, json.dumps(contents), ""
    except Exception as e:
        return False, "", f"{e}"


if __name__ == "__main__":
    mcp.run(transport="stdio")
DecompileOptions # type: ignore +from ghidra.util.task import ConsoleTaskMonitor # type: ignore +from ghidra.program.model.listing import Function # type: ignore +from ghidra.program.model.pcode import HighSymbol # type: ignore +import re + +def decompile_function(func: Function, NAMESPACE = "EXE"): + """ + Decompile a function using Ghidra's decompiler. + + Args: + func: Ghidra Function object + + Returns: + Dictionary with decompilation results + """ + decompiler = None + try: + # Initialize decompiler + decompiler = DecompInterface() + decompiler.openProgram(getCurrentProgram()) + + # Set decompiler options + options = DecompileOptions() + decompiler.setOptions(options) + + # Set a reasonable timeout + decompiler.setSimplificationStyle("decompile") + + # Decompile + monitor = ConsoleTaskMonitor() + results = decompiler.decompileFunction(func, 30, monitor) # 30 second timeout + + if results and results.decompileCompleted(): + decompiled_code = results.getDecompiledFunction().getC() + + highF = results.getHighFunction() + + global_symbols = list[HighSymbol](highF.getGlobalSymbolMap().getSymbols()) + + clean_instructions: list[tuple[str, str]] = [] + includes: list[str] = [] + serialized_global_symbols = [] + for sym in global_symbols: + dt = sym.getDataType() + dtp = dt.getDataTypePath().getPath() + dtn = dt.getName() + if dtp[0] == '/': + dtp = dtp[1:] + if dtp.startswith("_HoldStrong/"): + dtp = NAMESPACE + "/" + dtp[1+len("HoldStrong/"):] + includes.append(dtp.replace("/", "::") + ".hpp") + is_this = dtn in func.getName(True) # type: ignore + if hasattr(dt, "getNumElements"): + clean_instructions.append(("(?")) + clean_instructions.append(("(?")) + clean_instructions.append(("&" + sym.getName(), "this")) + else: + clean_instructions.append(("(?")) + clean_instructions.append(("(? + + + Ghidra Decompilation Server + + + +

Ghidra Decompilation Server

+

Server is running on port 11337

+ +

Available Endpoints:

+ +
+

POST /functions/decompile

+

Decompile a function by address or name

+

Request body:

+
{
+    "address": "0x004F8160",  // hex address
+    // OR
+    "name": "functionName"
+}
+
+ +
+

GET /health

+

Check server health status

+
+ + + """ + self.wfile.write(html.encode("utf-8")) + + elif path == '/health': + self._send_json({ + "status": "healthy", + "program": str(currentProgram.getName()) if currentProgram else None + }) + elif path == '/exit': + self._send_json({"status": "shutdown", }) + KEEP_RUNNING = False + monitor.cancel() + elif path == '/functions/decompile': + print(url) + if not url.query: + return self._send_error_json("no parameters", 400) + qc = parse_qs(url.query) + print(qc) + if not "name" in qc: + return self._send_error_json("parameter 'name' missing", 400) + funcName = qc["name"][0] + print(funcName) + if funcName not in fdb: + return self._send_error_json("not found: " + funcName, 404) + # Decompile the function + decompiled = self._decompile_function(fdb[funcName]) + + if decompiled["success"]: + self._send_json(decompiled) + else: + self._send_error_json(decompiled.get("error", "Decompilation failed")) + else: + self._send_error_json("Not found", 404) + + def do_POST(self): + """Handle POST requests.""" + if self.path == '/functions/decompile': + try: + # Read request body + content_length = int(self.headers.get('Content-Length', 0)) + post_data = self.rfile.read(content_length) + request_data = json.loads(post_data) + + # Get function address or name + address_str = request_data.get('address') + function_name = request_data.get('name') + + if not address_str and not function_name: + self._send_error_json("Either 'address' or 'name' must be provided") + return + + # Find function + func = None + if address_str: + # Parse address + if address_str.startswith('0x') or address_str.startswith('0X'): + address_str = address_str[2:] + + try: + addr = currentProgram.getAddressFactory().getAddress(address_str) + func = currentProgram.getFunctionManager().getFunctionAt(addr) + except: + self._send_error_json("Invalid address format: " + address_str) + return + + elif function_name: + # Search by name + function_manager = currentProgram.getFunctionManager() + functions = 
function_manager.getFunctions(True) # Get all functions + for f in functions: + if f.getName() == function_name: + func = f + break + + if not func: + self._send_error_json("Function not found") + return + + # Decompile the function + decompiled = self._decompile_function(func) + + if decompiled["success"]: + self._send_json(decompiled) + else: + self._send_error_json(decompiled.get("error", "Decompilation failed")) + + except json.JSONDecodeError: + self._send_error_json("Invalid JSON in request body") + except Exception as e: + self._send_error_json("Internal server error: " + str(e), 500) + + else: + self._send_error_json("Not found", 404) + + def _decompile_function(self, func: Function): + return decompile_function(func, NAMESPACE=NAMESPACE) + + +def serve(threaded = False): + while not monitor.cancelled: + if not httpd: + break + httpd.handle_request() + if not threaded: + if not KEEP_RUNNING: + break + stop_server() + +def start_server(threaded=False): + """Start the HTTP server in a background thread.""" + global server_thread, httpd + + if server_thread and server_thread.is_alive(): + println("[Server] Server is already running on port 11337") + return + + try: + server_address = ('127.0.0.1', 11337) + httpd = HTTPServer(server_address, DecompilationHandler) + httpd.timeout = 1 + + println("[Server] Starting Ghidra Decompilation Server on http://127.0.0.1:11337") + println("[Server] Endpoint: POST http://127.0.0.1:11337/functions/decompile") + println("[Server] Press Ctrl+C in console or run 'Stop Decompilation Server' to stop") + println("[Server] Visit http://127.0.0.1:11337 in your browser for API documentation") + + if threaded: + # Run server in background thread + server_thread = threading.Thread(target=serve) + server_thread.daemon = True + server_thread.start() + println("[Server] Server started successfully!") + else: + println("[Server] Serving...") + serve() + + except Exception as e: + println("[Server] Failed to start server: " + str(e)) + if 
"Address already in use" in str(e): + println("[Server] Port 11337 is already in use. Stop the existing server first.") + + +def stop_server(threaded = False): + """Stop the HTTP server.""" + global httpd, server_thread + + if httpd: + println("[Server] Stopping server...") + if threaded: + httpd.shutdown() + httpd.server_close() + httpd = None + server_thread = None + println("[Server] Server stopped.") + else: + println("[Server] No server is currently running.") + + +# Main execution +if __name__ == '__main__': + if not currentProgram: + println("[Server] Error: No program is currently open in Ghidra") + println("[Server] Please open a program before starting the server") + else: + start_server() diff --git a/tools/mcp/ghidra_scripts/requirements.txt b/tools/mcp/ghidra_scripts/requirements.txt new file mode 100644 index 0000000..98f7a62 --- /dev/null +++ b/tools/mcp/ghidra_scripts/requirements.txt @@ -0,0 +1 @@ +ghidra-stubs \ No newline at end of file diff --git a/tools/mcp/requirements.txt b/tools/mcp/requirements.txt new file mode 100644 index 0000000..fd4697a --- /dev/null +++ b/tools/mcp/requirements.txt @@ -0,0 +1 @@ +mcp \ No newline at end of file