From d29153eeb7c2985a75bf0779edbf4bd38572877f Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 17 May 2024 18:50:05 +0300 Subject: [PATCH 01/18] Add request logging options --- .../cvat/exchange-oracle/pyproject.toml | 1 + .../cvat/exchange-oracle/src/core/config.py | 6 + .../exchange-oracle/src/endpoints/__init__.py | 3 + .../src/endpoints/middleware.py | 166 ++++++++++++++++++ 4 files changed, 176 insertions(+) create mode 100644 packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py diff --git a/packages/examples/cvat/exchange-oracle/pyproject.toml b/packages/examples/cvat/exchange-oracle/pyproject.toml index 3548499851..1238365404 100644 --- a/packages/examples/cvat/exchange-oracle/pyproject.toml +++ b/packages/examples/cvat/exchange-oracle/pyproject.toml @@ -24,6 +24,7 @@ xmltodict = "^0.13.0" datumaro = {git = "https://github.com/cvat-ai/datumaro.git", rev = "ff83c00c2c1bc4b8fdfcc55067fcab0a9b5b6b11"} boto3 = "^1.28.33" google-cloud-storage = "^2.14.0" +pyinstrument = "^4.6.2" [tool.poetry.group.dev.dependencies] diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index a6b4991b38..9d909b46d9 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -158,6 +158,12 @@ class FeaturesConfig: default_export_timeout = int(os.environ.get("DEFAULT_EXPORT_TIMEOUT", 60)) "Timeout, in seconds, for annotations or dataset export waiting" + request_logging_enabled = to_bool(os.getenv("REQUEST_LOGGING_ENABLED", False)) + "Allow to log request details for each request" + + profiling_enabled = to_bool(os.getenv("PROFILING_ENABLED", False)) + "Allow to profile specific requests by specifying profile=1" + class CoreConfig: default_assignment_time = int(os.environ.get("DEFAULT_ASSIGNMENT_TIME", 1800)) diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/__init__.py 
b/packages/examples/cvat/exchange-oracle/src/endpoints/__init__.py index 610ab10009..b4466bee8b 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/__init__.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/__init__.py @@ -4,6 +4,7 @@ from src.core.config import Config from src.endpoints.cvat import router as cvat_router from src.endpoints.exchange import router as service_router +from src.endpoints.middleware import setup_middleware from src.endpoints.webhook import router as webhook_router from src.schemas import MetaResponse, ResponseError, ValidationErrorResponse @@ -46,4 +47,6 @@ def init_api(app: FastAPI) -> FastAPI: app.include_router(webhook_router, responses=default_responses) app.include_router(service_router, responses=default_responses) + setup_middleware(app) + return app diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py new file mode 100644 index 0000000000..1c4948d903 --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py @@ -0,0 +1,166 @@ +import json +import time +from typing import Any, Callable + +from fastapi import FastAPI, Request, Response +from fastapi.responses import HTMLResponse, StreamingResponse +from pyinstrument import Profiler +from pyinstrument.renderers.html import HTMLRenderer +from pyinstrument.renderers.speedscope import SpeedscopeRenderer +from starlette.middleware.base import BaseHTTPMiddleware + +from src.core.config import Config +from src.log import get_root_logger + + +async def profile_request(request: Request, call_next: Callable): + """ + Profile the current request + + Adapted from + https://pyinstrument.readthedocs.io/en/latest/guide.html#profile-a-web-request-in-fastapi + + """ + profile_format = "html" + check_interval = 0.001 + + profile_type_to_renderer = { + "html": HTMLRenderer, + "speedscope": SpeedscopeRenderer, + } + + if 
request.query_params.get("profile", False): + # The default profile format is html + profile_type = request.query_params.get("profile_format", profile_format) + + # we profile the request along with all additional middlewares, by interrupting + # the program every 1ms and recording the entire stack at that point + with Profiler(interval=check_interval, async_mode="enabled") as profiler: + await call_next(request) + + renderer = profile_type_to_renderer[profile_type]() + response_data = profiler.output(renderer=renderer) + if profile_type == "html": + return HTMLResponse(response_data) + else: + return StreamingResponse(iter(response_data)) + + # Proceed without profiling + return await call_next(request) + + +class RequestLoggingMiddleware(BaseHTTPMiddleware): + """ + Middleware in charge of logging the HTTP request and response + + Adapted from + https://medium.com/@dhavalsavalia/fastapi-logging-middleware-logging-requests-and-responses-with-ease-and-style-201b9aa4001a + + """ + + def __init__(self, app: FastAPI) -> None: + super().__init__(app) + self.logger = get_root_logger() + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + logging_dict: dict[str, Any] = {} + + await request.body() + response, response_dict = await self._log_response(call_next, request) + request_dict = await self._log_request(request) + logging_dict["request"] = request_dict + logging_dict["response"] = response_dict + + request_id = request.headers.get("X-Request-ID") + if request_id: + logging_dict["correlation_id"] = request_id + + self.logger.info(json.dumps(logging_dict)) + return response + + async def _log_request(self, request: Request) -> dict[str, Any]: + """ + Logs request part + + Arguments: + - request: Request + """ + + path = request.url.path + if request.query_params: + path += f"?{request.query_params}" + + request_logging = { + "method": request.method, + "path": path, + "ip": request.client.host if request.client is not None else None, 
+ } + + try: + body = await request.json() + except Exception: + body = None + else: + request_logging["body"] = body + + return request_logging + + async def _log_response( + self, call_next: Callable, request: Request + ) -> tuple[Response, dict[str, Any]]: + """ + Logs response part + + Arguments: + - call_next: Callable (To execute the actual path function and get response back) + - request: Request + - request_id: str (uuid) + + Returns: + - response: Response + - response_logging: str + """ + + start_time = time.perf_counter() + response = await self._execute_request(call_next, request) + finish_time = time.perf_counter() + execution_time = finish_time - start_time + + overall_status = "successful" if response.status_code < 400 else "failed" + + response_logging = { + "status": overall_status, + "status_code": response.status_code, + "time_taken": f"{execution_time:0.4f}s", + } + return response, response_logging + + async def _execute_request(self, call_next: Callable, request: Request) -> Response: + """ + Executes the actual path function using call_next. 
+ + Arguments: + - call_next: Callable (To execute the actual path function + and get response back) + - request: Request + - request_id: str (uuid) + Returns: + - response: Response + """ + try: + response: Response = await call_next(request) + + except Exception as e: + self.logger.exception({"path": request.url.path, "method": request.method, "reason": e}) + raise e + + else: + return response + + +def setup_middleware(app: FastAPI): + if Config.features.request_logging_enabled: + app.add_middleware(RequestLoggingMiddleware) + + if Config.features.profiling_enabled: + app.add_middleware(BaseHTTPMiddleware, dispatch=profile_request) From 58d44fe788f50ca7bcf20df68e2fa77b1ff278b9 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 17 May 2024 18:54:16 +0300 Subject: [PATCH 02/18] Update poetry lock file --- .../examples/cvat/exchange-oracle/poetry.lock | 78 ++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/packages/examples/cvat/exchange-oracle/poetry.lock b/packages/examples/cvat/exchange-oracle/poetry.lock index aea0b61fb1..6b210e0fcd 100644 --- a/packages/examples/cvat/exchange-oracle/poetry.lock +++ b/packages/examples/cvat/exchange-oracle/poetry.lock @@ -3126,6 +3126,82 @@ typing-extensions = ">=4.2.0" dotenv = ["python-dotenv (>=0.10.4)"] email = ["email-validator (>=1.0.3)"] +[[package]] +name = "pyinstrument" +version = "4.6.2" +description = "Call stack profiler for Python. Shows you why your code is slow!" 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "pyinstrument-4.6.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7a1b1cd768ea7ea9ab6f5490f7e74431321bcc463e9441dbc2f769617252d9e2"}, + {file = "pyinstrument-4.6.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8a386b9d09d167451fb2111eaf86aabf6e094fed42c15f62ec51d6980bce7d96"}, + {file = "pyinstrument-4.6.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23c3e3ca8553b9aac09bd978c73d21b9032c707ac6d803bae6a20ecc048df4a8"}, + {file = "pyinstrument-4.6.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5f329f5534ca069420246f5ce57270d975229bcb92a3a3fd6b2ca086527d9764"}, + {file = "pyinstrument-4.6.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4dcdcc7ba224a0c5edfbd00b0f530f5aed2b26da5aaa2f9af5519d4aa8c7e41"}, + {file = "pyinstrument-4.6.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:73db0c2c99119c65b075feee76e903b4ed82e59440fe8b5724acf5c7cb24721f"}, + {file = "pyinstrument-4.6.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:da58f265326f3cf3975366ccb8b39014f1e69ff8327958a089858d71c633d654"}, + {file = "pyinstrument-4.6.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:feebcf860f955401df30d029ec8de7a0c5515d24ea809736430fd1219686fe14"}, + {file = "pyinstrument-4.6.2-cp310-cp310-win32.whl", hash = "sha256:b2b66ff0b16c8ecf1ec22de001cfff46872b2c163c62429055105564eef50b2e"}, + {file = "pyinstrument-4.6.2-cp310-cp310-win_amd64.whl", hash = "sha256:8d104b7a7899d5fa4c5bf1ceb0c1a070615a72c5dc17bc321b612467ad5c5d88"}, + {file = "pyinstrument-4.6.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:62f6014d2b928b181a52483e7c7b82f2c27e22c577417d1681153e5518f03317"}, + {file = "pyinstrument-4.6.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dcb5c8d763c5df55131670ba2a01a8aebd0d490a789904a55eb6a8b8d497f110"}, + {file 
= "pyinstrument-4.6.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ed4e8c6c84e0e6429ba7008a66e435ede2d8cb027794c20923c55669d9c5633"}, + {file = "pyinstrument-4.6.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c0f0e1d8f8c70faa90ff57f78ac0dda774b52ea0bfb2d9f0f41ce6f3e7c869e"}, + {file = "pyinstrument-4.6.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b3c44cb037ad0d6e9d9a48c14d856254ada641fbd0ae9de40da045fc2226a2a"}, + {file = "pyinstrument-4.6.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:be9901f17ac2f527c352f2fdca3d717c1d7f2ce8a70bad5a490fc8cc5d2a6007"}, + {file = "pyinstrument-4.6.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8a9791bf8916c1cf439c202fded32de93354b0f57328f303d71950b0027c7811"}, + {file = "pyinstrument-4.6.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d6162615e783c59e36f2d7caf903a7e3ecb6b32d4a4ae8907f2760b2ef395bf6"}, + {file = "pyinstrument-4.6.2-cp311-cp311-win32.whl", hash = "sha256:28af084aa84bbfd3620ebe71d5f9a0deca4451267f363738ca824f733de55056"}, + {file = "pyinstrument-4.6.2-cp311-cp311-win_amd64.whl", hash = "sha256:dd6007d3c2e318e09e582435dd8d111cccf30d342af66886b783208813caf3d7"}, + {file = "pyinstrument-4.6.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:e3813c8ecfab9d7d855c5f0f71f11793cf1507f40401aa33575c7fd613577c23"}, + {file = "pyinstrument-4.6.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6c761372945e60fc1396b7a49f30592e8474e70a558f1a87346d27c8c4ce50f7"}, + {file = "pyinstrument-4.6.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4fba3244e94c117bf4d9b30b8852bbdcd510e7329fdd5c7c8b3799e00a9215a8"}, + {file = "pyinstrument-4.6.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:803ac64e526473d64283f504df3b0d5c2c203ea9603cab428641538ffdc753a7"}, + {file = "pyinstrument-4.6.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2e554b1bb0df78f5ce8a92df75b664912ca93aa94208386102af454ec31b647"}, + {file = "pyinstrument-4.6.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7c671057fad22ee3ded897a6a361204ea2538e44c1233cad0e8e30f6d27f33db"}, + {file = "pyinstrument-4.6.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:d02f31fa13a9e8dc702a113878419deba859563a32474c9f68e04619d43d6f01"}, + {file = "pyinstrument-4.6.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b55983a884f083f93f0fc6d12ff8df0acd1e2fb0580d2f4c7bfe6def33a84b58"}, + {file = "pyinstrument-4.6.2-cp312-cp312-win32.whl", hash = "sha256:fdc0a53b27e5d8e47147489c7dab596ddd1756b1e053217ef5bc6718567099ff"}, + {file = "pyinstrument-4.6.2-cp312-cp312-win_amd64.whl", hash = "sha256:dd5c53a0159126b5ce7cbc4994433c9c671e057c85297ff32645166a06ad2c50"}, + {file = "pyinstrument-4.6.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b082df0bbf71251a7f4880a12ed28421dba84ea7110bb376e0533067a4eaff40"}, + {file = "pyinstrument-4.6.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90350533396071cb2543affe01e40bf534c35cb0d4b8fa9fdb0f052f9ca2cfe3"}, + {file = "pyinstrument-4.6.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:67268bb0d579330cff40fd1c90b8510363ca1a0e7204225840614068658dab77"}, + {file = "pyinstrument-4.6.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20e15b4e1d29ba0b7fc81aac50351e0dc0d7e911e93771ebc3f408e864a2c93b"}, + {file = "pyinstrument-4.6.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2e625fc6ffcd4fd420493edd8276179c3f784df207bef4c2192725c1b310534c"}, + {file = "pyinstrument-4.6.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = 
"sha256:113d2fc534c9ca7b6b5661d6ada05515bf318f6eb34e8d05860fe49eb7cfe17e"}, + {file = "pyinstrument-4.6.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3098cd72b71a322a72dafeb4ba5c566465e193d2030adad4c09566bd2f89bf4f"}, + {file = "pyinstrument-4.6.2-cp37-cp37m-win32.whl", hash = "sha256:08fdc7f88c989316fa47805234c37a40fafe7b614afd8ae863f0afa9d1707b37"}, + {file = "pyinstrument-4.6.2-cp37-cp37m-win_amd64.whl", hash = "sha256:5ebeba952c0056dcc9b9355328c78c4b5c2a33b4b4276a9157a3ab589f3d1bac"}, + {file = "pyinstrument-4.6.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:34e59e91c88ec9ad5630c0964eca823949005e97736bfa838beb4789e94912a2"}, + {file = "pyinstrument-4.6.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cd0320c39e99e3c0a3129d1ed010ac41e5a7eb96fb79900d270080a97962e995"}, + {file = "pyinstrument-4.6.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46992e855d630575ec635eeca0068a8ddf423d4fd32ea0875a94e9f8688f0b95"}, + {file = "pyinstrument-4.6.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e474c56da636253dfdca7cd1998b240d6b39f7ed34777362db69224fcf053b1"}, + {file = "pyinstrument-4.6.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4b559322f30509ad8f082561792352d0805b3edfa508e492a36041fdc009259"}, + {file = "pyinstrument-4.6.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:06a8578b2943eb1dbbf281e1e59e44246acfefd79e1b06d4950f01b693de12af"}, + {file = "pyinstrument-4.6.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7bd3da31c46f1c1cb7ae89031725f6a1d1015c2041d9c753fe23980f5f9fd86c"}, + {file = "pyinstrument-4.6.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e63f4916001aa9c625976a50779282e0a5b5e9b17c52a50ef4c651e468ed5b88"}, + {file = "pyinstrument-4.6.2-cp38-cp38-win32.whl", hash = "sha256:32ec8db6896b94af790a530e1e0edad4d0f941a0ab8dd9073e5993e7ea46af7d"}, + {file = 
"pyinstrument-4.6.2-cp38-cp38-win_amd64.whl", hash = "sha256:a59fc4f7db738a094823afe6422509fa5816a7bf74e768ce5a7a2ddd91af40ac"}, + {file = "pyinstrument-4.6.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3a165e0d2deb212d4cf439383982a831682009e1b08733c568cac88c89784e62"}, + {file = "pyinstrument-4.6.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7ba858b3d6f6e5597c641edcc0e7e464f85aba86d71bc3b3592cb89897bf43f6"}, + {file = "pyinstrument-4.6.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fd8e547cf3df5f0ec6e4dffbe2e857f6b28eda51b71c3c0b5a2fc0646527835"}, + {file = "pyinstrument-4.6.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0de2c1714a37a820033b19cf134ead43299a02662f1379140974a9ab733c5f3a"}, + {file = "pyinstrument-4.6.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01fc45dedceec3df81668d702bca6d400d956c8b8494abc206638c167c78dfd9"}, + {file = "pyinstrument-4.6.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5b6e161ef268d43ee6bbfae7fd2cdd0a52c099ddd21001c126ca1805dc906539"}, + {file = "pyinstrument-4.6.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6ba8e368d0421f15ba6366dfd60ec131c1b46505d021477e0f865d26cf35a605"}, + {file = "pyinstrument-4.6.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:edca46f04a573ac2fb11a84b937844e6a109f38f80f4b422222fb5be8ecad8cb"}, + {file = "pyinstrument-4.6.2-cp39-cp39-win32.whl", hash = "sha256:baf375953b02fe94d00e716f060e60211ede73f49512b96687335f7071adb153"}, + {file = "pyinstrument-4.6.2-cp39-cp39-win_amd64.whl", hash = "sha256:af1a953bce9fd530040895d01ff3de485e25e1576dccb014f76ba9131376fcad"}, + {file = "pyinstrument-4.6.2.tar.gz", hash = "sha256:0002ee517ed8502bbda6eb2bb1ba8f95a55492fcdf03811ba13d4806e50dd7f6"}, +] + +[package.extras] +bin = ["click", "nox"] +docs = ["furo (==2021.6.18b36)", "myst-parser (==0.15.1)", "sphinx (==4.2.0)", 
"sphinxcontrib-programoutput (==0.17)"] +examples = ["django", "numpy"] +test = ["flaky", "greenlet (>=3.0.0a1)", "ipython", "pytest", "pytest-asyncio (==0.12.0)", "sphinx-autobuild (==2021.3.14)", "trio"] +types = ["typing-extensions"] + [[package]] name = "pyparsing" version = "3.1.1" @@ -4293,4 +4369,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10,<3.13" -content-hash = "c562d2ef096f70f422239ef6d9ed9b8202d25e5cfe3942f096da3804b0026d3f" +content-hash = "4a68d513fb44080ec5079247580a934360c7ae68a35ae6203e94d62173c5e933" From 57d33bbf4c71cb4ff2c6b712670e6ead380ff7db Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 17 May 2024 18:55:07 +0300 Subject: [PATCH 03/18] Update comment --- packages/examples/cvat/exchange-oracle/src/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index 9d909b46d9..524279aace 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -162,7 +162,7 @@ class FeaturesConfig: "Allow to log request details for each request" profiling_enabled = to_bool(os.getenv("PROFILING_ENABLED", False)) - "Allow to profile specific requests by specifying profile=1" + "Allow to profile specific requests" class CoreConfig: From d6edd06a424260e5aa71ebdfe7d156020b2d6339 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 17 May 2024 22:17:16 +0300 Subject: [PATCH 04/18] Fix invalid logging for requests with body --- .../examples/cvat/exchange-oracle/src/endpoints/middleware.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py index 1c4948d903..3fdbec2929 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py +++ 
b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py @@ -65,7 +65,6 @@ def __init__(self, app: FastAPI) -> None: async def dispatch(self, request: Request, call_next: Callable) -> Response: logging_dict: dict[str, Any] = {} - await request.body() response, response_dict = await self._log_response(call_next, request) request_dict = await self._log_request(request) logging_dict["request"] = request_dict From 41e400ebdf88b536c6024450195b8ceb12c09b1f Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 17 May 2024 22:49:57 +0300 Subject: [PATCH 05/18] Remove disallowed parameter combination from get_available_projects - distinct cant be used with for update --- packages/examples/cvat/exchange-oracle/src/services/cvat.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index 58fc471ef4..57a38f1abd 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -142,11 +142,9 @@ def get_projects_by_status( return projects -def get_available_projects( - session: Session, *, limit: int = 10, for_update: Union[bool, ForUpdateParams] = False -) -> List[Project]: +def get_available_projects(session: Session, *, limit: int = 10) -> List[Project]: return ( - _maybe_for_update(session.query(Project), enable=for_update) + session.query(Project) .where( (Project.status == ProjectStatuses.annotation.value) & Project.jobs.any( From 6d9e96eea54b6f0bff3cfa006273836d79c890fd Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 17 May 2024 23:17:04 +0300 Subject: [PATCH 06/18] Improve performance for project and task status trackers --- .../cvat/exchange-oracle/src/core/config.py | 3 +- .../src/crons/state_trackers.py | 8 ++++- .../cvat/exchange-oracle/src/services/cvat.py | 29 +++++++++++++++---- 3 files changed, 33 insertions(+), 7 
deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index 524279aace..8e96cbb06a 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -68,8 +68,9 @@ class CronConfig: track_completed_projects_int = int(os.environ.get("TRACK_COMPLETED_PROJECTS_INT", 30)) track_completed_projects_chunk_size = os.environ.get("TRACK_COMPLETED_PROJECTS_CHUNK_SIZE", 5) track_completed_tasks_int = int(os.environ.get("TRACK_COMPLETED_TASKS_INT", 30)) - track_creating_tasks_chunk_size = os.environ.get("TRACK_CREATING_TASKS_CHUNK_SIZE", 5) + track_completed_tasks_chunk_size = os.environ.get("TRACK_COMPLETED_TASKS_CHUNK_SIZE", 20) track_creating_tasks_int = int(os.environ.get("TRACK_CREATING_TASKS_INT", 300)) + track_creating_tasks_chunk_size = os.environ.get("TRACK_CREATING_TASKS_CHUNK_SIZE", 5) track_assignments_int = int(os.environ.get("TRACK_ASSIGNMENTS_INT", 5)) track_assignments_chunk_size = os.environ.get("TRACK_ASSIGNMENTS_CHUNK_SIZE", 10) diff --git a/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py b/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py index fed32a2840..8996173e52 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py @@ -34,6 +34,7 @@ def track_completed_projects() -> None: projects = cvat_service.get_projects_by_status( session, ProjectStatuses.annotation, + task_status=TaskStatuses.completed, limit=CronConfig.track_completed_projects_chunk_size, for_update=ForUpdateParams(skip_locked=True), ) @@ -74,7 +75,12 @@ def track_completed_tasks() -> None: logger.debug("Starting cron job") with SessionLocal.begin() as session: tasks = cvat_service.get_tasks_by_status( - session, TaskStatuses.annotation, for_update=ForUpdateParams(skip_locked=True) + session, + 
TaskStatuses.annotation, + job_status=JobStatuses.completed, + project_status=ProjectStatuses.annotation, + limit=CronConfig.track_completed_tasks_chunk_size, + for_update=ForUpdateParams(skip_locked=True), ) completed_task_ids = [] diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index 57a38f1abd..d2b4ce3bce 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -127,6 +127,7 @@ def get_projects_by_status( status: ProjectStatuses, *, included_types: Optional[Sequence[TaskTypes]] = None, + task_status: Optional[TaskStatuses] = None, limit: int = 5, for_update: Union[bool, ForUpdateParams] = False, ) -> List[Project]: @@ -134,6 +135,9 @@ def get_projects_by_status( Project.status == status.value ) + if task_status: + projects = projects.where(Project.tasks.any(Task.status == task_status.value)) + if included_types is not None: projects = projects.where(Project.job_type.in_([t.value for t in included_types])) @@ -341,14 +345,29 @@ def get_tasks_by_cvat_id( def get_tasks_by_status( - session: Session, status: TaskStatuses, *, for_update: Union[bool, ForUpdateParams] = False + session: Session, + status: TaskStatuses, + *, + job_status: Optional[JobStatuses] = None, + project_status: Optional[ProjectStatuses] = None, + for_update: Union[bool, ForUpdateParams] = False, + limit: Optional[int] = 20, ) -> List[Task]: - return ( - _maybe_for_update(session.query(Task), enable=for_update) - .where(Task.status == status.value) - .all() + query = _maybe_for_update(session.query(Task), enable=for_update).where( + Task.status == status.value ) + if job_status: + query = query.where(Task.jobs.any(Job.status == job_status.value)) + + if project_status: + query = query.where(Task.project.has(Project.status == project_status.value)) + + if limit: + query = query.limit(limit) + + return query.all() + def 
update_task_status(session: Session, task_id: int, status: TaskStatuses) -> None: upd = update(Task).where(Task.id == task_id).values(status=status.value) From c1e89a45b8904dd01044b213d0097bf3793257ac Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 17 May 2024 23:17:24 +0300 Subject: [PATCH 07/18] Fix error handling in escrow status checks --- .../cvat/exchange-oracle/src/handlers/completed_escrows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py b/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py index be73b0b800..d78f88fae4 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py @@ -266,7 +266,7 @@ def _process_skeletons_from_boxes_escrows(self): except Exception as e: logger.error( "Failed to handle completed projects for escrow {}: {}".format( - escrow_address, e + completed_project.escrow_address, e ) ) continue From 347b7f677d9385394c9c190a59ec13f90c4cae43 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Sat, 18 May 2024 01:26:01 +0300 Subject: [PATCH 08/18] Optimize db calls for getting new assignment --- .../cvat/exchange-oracle/src/services/cvat.py | 41 +++++++++++++++++++ .../exchange-oracle/src/services/exchange.py | 33 +++++---------- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index d2b4ce3bce..9c6ff82d76 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -421,6 +421,8 @@ def finish_data_uploads(session: Session, uploads: list[DataUpload]) -> None: # Job + + def create_job( session: Session, cvat_id: int, @@ -501,6 +503,27 @@ def count_jobs_by_escrow_address( ) +def get_free_job( + session: 
Session, + cvat_projects: List[int], + *, + for_update: Union[bool, ForUpdateParams] = False, +) -> Optional[Job]: + return ( + _maybe_for_update(session.query(Job), enable=for_update) + .where( + Job.cvat_project_id.in_(cvat_projects), + ~Job.assignments.any( + (Assignment.status == AssignmentStatuses.completed.value) + | (Assignment.status == AssignmentStatuses.created.value) + & (Assignment.completed_at == None) + & (utcnow() < Assignment.expires_at) + ), + ) + .first() + ) + + # Users @@ -661,6 +684,24 @@ def get_user_assignments_in_cvat_projects( ) +def count_active_user_assignments( + session: Session, + wallet_address: int, + cvat_projects: List[int], +) -> int: + return ( + session.query(Assignment) + .where( + Assignment.job.has(Job.cvat_project_id.in_(cvat_projects)), + Assignment.user_wallet_address == wallet_address, + Assignment.status == AssignmentStatuses.created.value, + Assignment.completed_at == None, + utcnow() < Assignment.expires_at, + ) + .count() + ) + + # Image def add_project_images(session: Session, cvat_project_id: int, filenames: List[str]) -> None: session.execute( diff --git a/packages/examples/cvat/exchange-oracle/src/services/exchange.py b/packages/examples/cvat/exchange-oracle/src/services/exchange.py index 9795cb5db9..bec07263a5 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/services/exchange.py @@ -134,31 +134,20 @@ def create_assignment(project_id: int, wallet_address: str) -> Optional[str]: manifest = parse_manifest(get_escrow_manifest(project.chain_id, project.escrow_address)) - unassigned_job: Optional[models.Job] = None - unfinished_assignments: list[models.Assignment] = [] - for job in project.jobs: - job_assignment = job.latest_assignment - if job_assignment and not job_assignment.is_finished: - unfinished_assignments.append(job_assignment) - - if ( - not unassigned_job - and job.status == JobStatuses.new - and (not job_assignment or 
job_assignment.is_finished) - ): - unassigned_job = job - - now = utcnow() - unfinished_user_assignments = [ - assignment - for assignment in unfinished_assignments - if assignment.user_wallet_address == wallet_address and now < assignment.expires_at - ] - if unfinished_user_assignments: + has_active_assignments = ( + cvat_service.count_active_user_assignments( + session, wallet_address=wallet_address, cvat_projects=[project.cvat_id] + ) + > 0 + ) + if has_active_assignments: raise UserHasUnfinishedAssignmentError( "The user already has an unfinished assignment in this project" ) + unassigned_job = cvat_service.get_free_job( + session, cvat_projects=[project.cvat_id], for_update=True + ) if not unassigned_job: return None @@ -166,7 +155,7 @@ def create_assignment(project_id: int, wallet_address: str) -> Optional[str]: session, wallet_address=user.wallet_address, cvat_job_id=unassigned_job.cvat_id, - expires_at=now + expires_at=utcnow() + timedelta( seconds=manifest.annotation.max_time or get_default_assignment_timeout(manifest.annotation.type) From e7dab17a314cadf1052ef7c80fa140006e3e70fa Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Sat, 18 May 2024 01:26:39 +0300 Subject: [PATCH 09/18] Optimize cvat calls for getting new assignment --- packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py | 4 ++-- .../examples/cvat/exchange-oracle/src/services/exchange.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py index f022e72d5c..49cd3dac95 100644 --- a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py @@ -559,7 +559,7 @@ def update_job_assignee(id: str, assignee_id: Optional[int]): raise -def restart_job(id: str): +def restart_job(id: str, *, assignee_id: Optional[int] = None): logger = logging.getLogger("app") with get_api_client() as 
api_client: @@ -567,7 +567,7 @@ def restart_job(id: str): api_client.jobs_api.partial_update( id=id, patched_job_write_request=models.PatchedJobWriteRequest( - stage="annotation", state="new" + stage="annotation", state="new", assignee_id=assignee_id ), ) except exceptions.ApiException as e: diff --git a/packages/examples/cvat/exchange-oracle/src/services/exchange.py b/packages/examples/cvat/exchange-oracle/src/services/exchange.py index bec07263a5..ccaa4715eb 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/services/exchange.py @@ -163,8 +163,7 @@ def create_assignment(project_id: int, wallet_address: str) -> Optional[str]: ) cvat_api.clear_job_annotations(unassigned_job.cvat_id) - cvat_api.restart_job(unassigned_job.cvat_id) - cvat_api.update_job_assignee(unassigned_job.cvat_id, assignee_id=user.cvat_id) + cvat_api.restart_job(unassigned_job.cvat_id, assignee_id=user.cvat_id) # rollback is automatic within the transaction return assignment_id From 9385bec9eb2caa0876c85e3b4e412ee6f0c264b9 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Sat, 18 May 2024 01:27:06 +0300 Subject: [PATCH 10/18] Add and update tests --- .../tests/api/test_exchange_api.py | 3 +- .../integration/services/test_exchange.py | 139 +++++++++++++++++- .../exchange-oracle/tests/utils/db_helper.py | 4 +- 3 files changed, 141 insertions(+), 5 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py b/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py index 809e523c9e..9eba05e767 100644 --- a/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py +++ b/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py @@ -309,8 +309,7 @@ def test_create_assignment_200(client: TestClient) -> None: json={"wallet_address": user_address}, ) cvat_api.clear_job_annotations.assert_called_once() - cvat_api.restart_job.assert_called_once() - 
cvat_api.update_job_assignee.assert_called_once() + cvat_api.restart_job.assert_called_once_with(cvat_job_1.cvat_id, assignee_id=user.cvat_id) assert response.status_code == 200 db_assignment = session.query(Assignment).filter_by(user_wallet_address=user_address).first() diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py index b339ceef34..738af890a6 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py @@ -19,7 +19,12 @@ serialize_task, ) -from tests.utils.db_helper import create_project, create_project_task_and_job +from tests.utils.db_helper import ( + create_job, + create_project, + create_project_task_and_job, + create_task, +) class ServiceIntegrationTest(unittest.TestCase): @@ -227,6 +232,50 @@ def test_create_assignment(self): self.assertEqual(assignment.user_wallet_address, user_address) self.assertEqual(assignment.status, AssignmentStatuses.created) + def test_create_assignment_many_jobs(self): + cvat_project, _, cvat_job_1 = create_project_task_and_job( + self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 + ) + cvat_task_2 = create_task(self.session, 2, cvat_project.cvat_id) + cvat_job_2 = create_job(self.session, 2, cvat_task_2.cvat_id, cvat_project.cvat_id) + + user_address = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user = User( + wallet_address=user_address, + cvat_email="test@hmt.ai", + cvat_id=1, + ) + self.session.add(user) + + now = datetime.now() + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=user_address, + cvat_job_id=cvat_job_1.cvat_id, + created_at=now - timedelta(hours=1), + completed_at=now - timedelta(minutes=40), + expires_at=datetime.now() + timedelta(days=1), + status=AssignmentStatuses.completed.value, + ) + self.session.add(assignment) + 
+ self.session.commit() + + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + assingment_id = create_assignment(cvat_project.id, user_address) + + assignment = self.session.query(Assignment).filter_by(id=assingment_id).first() + + self.assertEqual(assignment.cvat_job_id, cvat_job_2.cvat_id) + self.assertEqual(assignment.user_wallet_address, user_address) + self.assertEqual(assignment.status, AssignmentStatuses.created) + def test_create_assignment_invalid_user_address(self): cvat_project_1, _, _ = create_project_task_and_job( self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 @@ -280,3 +329,91 @@ def test_create_assignment_unfinished_assignment(self): with self.assertRaises(HTTPException): create_assignment("1", user_address) + + def test_create_assignment_no_available_jobs_completed_assignment(self): + cvat_project, _, cvat_job_1 = create_project_task_and_job( + self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 + ) + + user_address1 = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user1 = User( + wallet_address=user_address1, + cvat_email="test1@hmt.ai", + cvat_id=1, + ) + self.session.add(user1) + + user_address2 = "0x86e83d346041E8806e352681f3F14549C0d2BC70" + user2 = User( + wallet_address=user_address2, + cvat_email="test2@hmt.ai", + cvat_id=2, + ) + self.session.add(user2) + + now = datetime.now() + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=user_address1, + cvat_job_id=cvat_job_1.cvat_id, + created_at=now - timedelta(days=1), + completed_at=now - timedelta(hours=22), + expires_at=now + timedelta(hours=2), + status=AssignmentStatuses.completed.value, + ) + self.session.add(assignment) + + self.session.commit() + + with ( + open("tests/utils/manifest.json") as data, + 
patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + assignment_id = create_assignment(cvat_project.id, user_address2) + + self.assertEqual(assignment_id, None) + + def test_create_assignment_no_available_jobs_active_foreign_assignment(self): + cvat_project, _, cvat_job_1 = create_project_task_and_job( + self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 + ) + + user_address1 = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user1 = User( + wallet_address=user_address1, + cvat_email="test1@hmt.ai", + cvat_id=1, + ) + self.session.add(user1) + + user_address2 = "0x86e83d346041E8806e352681f3F14549C0d2BC70" + user2 = User( + wallet_address=user_address2, + cvat_email="test2@hmt.ai", + cvat_id=2, + ) + self.session.add(user2) + + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=user_address1, + cvat_job_id=cvat_job_1.cvat_id, + expires_at=datetime.now() + timedelta(days=1), + ) + self.session.add(assignment) + + self.session.commit() + + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + assignment_id = create_assignment(cvat_project.id, user_address2) + + self.assertEqual(assignment_id, None) diff --git a/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py b/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py index 8a5a81e9cc..2a332a1780 100644 --- a/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py +++ b/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py @@ -61,8 +61,8 @@ def create_job(session: Session, cvat_id: int, cvat_task_id: int, cvat_project_i cvat_job = Job( id=str(uuid.uuid4()), cvat_id=cvat_id, - cvat_project_id=cvat_id, - 
cvat_task_id=cvat_id, + cvat_project_id=cvat_project_id, + cvat_task_id=cvat_task_id, status=JobStatuses.new, ) session.add(cvat_job) From 984415615fde0d883bc2578d38317880ee2d38ea Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Sat, 18 May 2024 02:14:05 +0300 Subject: [PATCH 11/18] Fix logging for requests with body --- .../src/endpoints/middleware.py | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py index 3fdbec2929..a346eddcc6 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py @@ -2,6 +2,8 @@ import time from typing import Any, Callable +import fastapi +import packaging.version as pv from fastapi import FastAPI, Request, Response from fastapi.responses import HTMLResponse, StreamingResponse from pyinstrument import Profiler @@ -58,13 +60,32 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): """ + @staticmethod + async def _set_body(request: Request, body: bytes): + # Before FastAPI 0.108.0 infinite hang is expected, + # if request body is awaited more than once. + # It's not needed when using FastAPI >= 0.108.0. 
+ # https://github.com/tiangolo/fastapi/discussions/8187#discussioncomment-7962889 + if pv.parse(fastapi.__version__) >= pv.Version("0.108.0"): + return + + async def receive(): + return {"type": "http.request", "body": body} + + request._receive = receive + def __init__(self, app: FastAPI) -> None: super().__init__(app) self.logger = get_root_logger() + self.max_displayed_body_size = 200 + async def dispatch(self, request: Request, call_next: Callable) -> Response: logging_dict: dict[str, Any] = {} + body = await request.body() + await self._set_body(request, body) + response, response_dict = await self._log_response(call_next, request) request_dict = await self._log_request(request) logging_dict["request"] = request_dict @@ -96,10 +117,26 @@ async def _log_request(self, request: Request) -> dict[str, Any]: } try: - body = await request.json() + body = await request.body() + await self._set_body(request, body) except Exception: body = None else: + if body is not None: + raw_body = False + + if len(body) < self.max_displayed_body_size: + try: + body = json.loads(body) + except (json.JSONDecodeError, TypeError): + raw_body = True + else: + raw_body = True + + if raw_body: + body = body.decode(errors="ignore") + body = body[: self.max_displayed_body_size] + request_logging["body"] = body return request_logging From 27004ddea81a181137e6df700c84758f43f71552 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Sat, 18 May 2024 03:43:31 +0300 Subject: [PATCH 12/18] Remove manifest dependency from assignment creation --- .../cvat/exchange-oracle/src/services/exchange.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/services/exchange.py b/packages/examples/cvat/exchange-oracle/src/services/exchange.py index ccaa4715eb..9813022c25 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/services/exchange.py @@ -2,10 +2,9 @@ from 
typing import Optional import src.cvat.api_calls as cvat_api -import src.models.cvat as models import src.services.cvat as cvat_service from src.chain.escrow import get_escrow_manifest -from src.core.types import AssignmentStatuses, JobStatuses, PlatformTypes, ProjectStatuses +from src.core.types import AssignmentStatuses, PlatformTypes, ProjectStatuses, TaskTypes from src.db import SessionLocal from src.schemas import exchange as service_api from src.utils.assignments import ( @@ -48,8 +47,7 @@ def serialize_task( title=f"Task {project.escrow_address[:10]}", description=manifest.annotation.description, job_bounty=manifest.job_bounty, - job_time_limit=manifest.annotation.max_time - or get_default_assignment_timeout(manifest.annotation.type), + job_time_limit=get_default_assignment_timeout(manifest.annotation.type), job_size=get_default_assignment_size(manifest), job_type=project.job_type, platform=PlatformTypes.CVAT, @@ -132,8 +130,6 @@ def create_assignment(project_id: int, wallet_address: str) -> Optional[str]: ) return None - manifest = parse_manifest(get_escrow_manifest(project.chain_id, project.escrow_address)) - has_active_assignments = ( cvat_service.count_active_user_assignments( session, wallet_address=wallet_address, cvat_projects=[project.cvat_id] @@ -156,10 +152,7 @@ def create_assignment(project_id: int, wallet_address: str) -> Optional[str]: wallet_address=user.wallet_address, cvat_job_id=unassigned_job.cvat_id, expires_at=utcnow() - + timedelta( - seconds=manifest.annotation.max_time - or get_default_assignment_timeout(manifest.annotation.type) - ), + + timedelta(seconds=get_default_assignment_timeout(TaskTypes(project.job_type))), ) cvat_api.clear_job_annotations(unassigned_job.cvat_id) From c638290e9a5824f796c785ba277422c61401d5f9 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Sat, 18 May 2024 03:44:04 +0300 Subject: [PATCH 13/18] Fix job assignment in cvat --- packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py | 2 +- 1 file changed, 
1 insertion(+), 1 deletion(-) diff --git a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py index 49cd3dac95..2f13315071 100644 --- a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py @@ -567,7 +567,7 @@ def restart_job(id: str, *, assignee_id: Optional[int] = None): api_client.jobs_api.partial_update( id=id, patched_job_write_request=models.PatchedJobWriteRequest( - stage="annotation", state="new", assignee_id=assignee_id + stage="annotation", state="new", assignee=assignee_id ), ) except exceptions.ApiException as e: From efb33673787e143d804585a58d403900e9c043c0 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Sat, 18 May 2024 03:46:46 +0300 Subject: [PATCH 14/18] Add a cvat api client context --- .../exchange-oracle/src/cvat/api_calls.py | 20 ++++++++++++++++++- .../exchange-oracle/src/services/exchange.py | 6 ++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py index 2f13315071..3b5a7b4a8a 100644 --- a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py @@ -2,12 +2,14 @@ import json import logging import zipfile +from contextlib import contextmanager +from contextvars import ContextVar from datetime import timedelta from enum import Enum from http import HTTPStatus from io import BytesIO from time import sleep -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Generator, List, Optional, Tuple from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models from cvat_sdk.api_client.api_client import Endpoint @@ -90,7 +92,23 @@ def _get_annotations( return file_buffer +_api_client_context: ContextVar[ApiClient] = ContextVar("api_client", default=None) + 
+ +@contextmanager +def api_client_context(api_client: ApiClient) -> Generator[ApiClient, None, None]: + old = _api_client_context.set(api_client) + try: + yield api_client + finally: + _api_client_context.reset(old) + + def get_api_client() -> ApiClient: + current_api_client = _api_client_context.get() + if current_api_client: + return current_api_client + configuration = Configuration( host=Config.cvat_config.cvat_url, username=Config.cvat_config.cvat_admin, diff --git a/packages/examples/cvat/exchange-oracle/src/services/exchange.py b/packages/examples/cvat/exchange-oracle/src/services/exchange.py index 9813022c25..b651180654 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/services/exchange.py @@ -155,8 +155,10 @@ def create_assignment(project_id: int, wallet_address: str) -> Optional[str]: + timedelta(seconds=get_default_assignment_timeout(TaskTypes(project.job_type))), ) - cvat_api.clear_job_annotations(unassigned_job.cvat_id) - cvat_api.restart_job(unassigned_job.cvat_id, assignee_id=user.cvat_id) + with cvat_api.api_client_context(cvat_api.get_api_client()): + cvat_api.clear_job_annotations(unassigned_job.cvat_id) + cvat_api.restart_job(unassigned_job.cvat_id, assignee_id=user.cvat_id) + # rollback is automatic within the transaction return assignment_id From 49b5878ad98bb37eda5e9fd98139d4e58253ee87 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Sat, 18 May 2024 03:47:09 +0300 Subject: [PATCH 15/18] Fix exception message handling --- .../cvat/exchange-oracle/src/handlers/error_handlers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py b/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py index b8d1c424b1..86d410aabb 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py +++ 
b/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py @@ -39,7 +39,9 @@ async def http_exception_handler(_, exc): @app.exception_handler(Exception) async def generic_exception_handler(_, exc: Exception): message = ( - "Something went wrong" if Config.environment != "development" else ".".join(exc.args) + "Something went wrong" + if Config.environment != "development" + else ".".join(map(str, exc.args)) ) return JSONResponse(content={"message": message}, status_code=500) From 4734178ce092bca65d63877d09fdcdac1c5753d7 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Sat, 18 May 2024 03:51:55 +0300 Subject: [PATCH 16/18] Update .env template --- packages/examples/cvat/exchange-oracle/src/.env.template | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/.env.template b/packages/examples/cvat/exchange-oracle/src/.env.template index 510d41284a..d654e5624e 100644 --- a/packages/examples/cvat/exchange-oracle/src/.env.template +++ b/packages/examples/cvat/exchange-oracle/src/.env.template @@ -35,14 +35,17 @@ PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE= TRACK_COMPLETED_PROJECTS_INT= TRACK_COMPLETED_PROJECTS_CHUNK_SIZE= TRACK_COMPLETED_TASKS_INT= +TRACK_COMPLETED_TASKS_CHUNK_SIZE= TRACK_COMPLETED_ESCROWS_INT= TRACK_COMPLETED_ESCROWS_CHUNK_SIZE= -PROCESS_JOB_LAUNCHER_WEBHOOKS_INT= TRACK_CREATING_TASKS_INT= +TRACK_CREATING_TASKS_CHUNK_SIZE= +TRACK_ASSIGNMENTS_INT= +TRACK_ASSIGNMENTS_CHUNK_SIZE= REJECTED_PROJECTS_CHUNK_SIZE= ACCEPTED_PROJECTS_CHUNK_SIZE= -TRACK_ESCROW_CREATION_CHUNK_SIZE= TRACK_ESCROW_CREATION_INT= +TRACK_ESCROW_CREATION_CHUNK_SIZE= TRACK_COMPLETED_ESCROWS_MAX_DOWNLOADING_RETRIES= TRACK_COMPLETED_ESCROWS_JOBS_DOWNLOADING_BATCH_SIZE= From 96c6f6d0fc123762d94969954bf09d62282bdb53 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 20 May 2024 01:27:30 +0300 Subject: [PATCH 17/18] Fix assignment in validated rejected jobs --- .../cvat/exchange-oracle/src/services/cvat.py | 4 
+- .../integration/services/test_exchange.py | 58 +++++++++++++++++-- .../exchange-oracle/tests/utils/db_helper.py | 14 +++-- 3 files changed, 63 insertions(+), 13 deletions(-) diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index 9c6ff82d76..f4d3b986fc 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -513,9 +513,9 @@ def get_free_job( _maybe_for_update(session.query(Job), enable=for_update) .where( Job.cvat_project_id.in_(cvat_projects), + Job.status == JobStatuses.new, ~Job.assignments.any( - (Assignment.status == AssignmentStatuses.completed.value) - | (Assignment.status == AssignmentStatuses.created.value) + (Assignment.status == AssignmentStatuses.created.value) & (Assignment.completed_at == None) & (utcnow() < Assignment.expires_at) ), diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py index 738af890a6..7741790418 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py @@ -8,7 +8,7 @@ from pydantic import ValidationError import src.services.cvat as cvat_service -from src.core.types import AssignmentStatuses, PlatformTypes, ProjectStatuses +from src.core.types import AssignmentStatuses, JobStatuses, PlatformTypes, ProjectStatuses from src.db import SessionLocal from src.models.cvat import Assignment, User from src.schemas import exchange as service_api @@ -224,18 +224,20 @@ def test_create_assignment(self): ): manifest = json.load(data) mock_get_manifest.return_value = manifest - assingment_id = create_assignment(cvat_project_1.id, user_address) + assignment_id = create_assignment(cvat_project_1.id, user_address) - assignment = 
self.session.query(Assignment).filter_by(id=assingment_id).first() + assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() self.assertEqual(assignment.cvat_job_id, cvat_job_1.cvat_id) self.assertEqual(assignment.user_wallet_address, user_address) self.assertEqual(assignment.status, AssignmentStatuses.created) - def test_create_assignment_many_jobs(self): + def test_create_assignment_many_jobs_1_completed(self): cvat_project, _, cvat_job_1 = create_project_task_and_job( self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 ) + cvat_job_1.status = JobStatuses.completed.value + cvat_task_2 = create_task(self.session, 2, cvat_project.cvat_id) cvat_job_2 = create_job(self.session, 2, cvat_task_2.cvat_id, cvat_project.cvat_id) @@ -268,9 +270,9 @@ def test_create_assignment_many_jobs(self): ): manifest = json.load(data) mock_get_manifest.return_value = manifest - assingment_id = create_assignment(cvat_project.id, user_address) + assignment_id = create_assignment(cvat_project.id, user_address) - assignment = self.session.query(Assignment).filter_by(id=assingment_id).first() + assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() self.assertEqual(assignment.cvat_job_id, cvat_job_2.cvat_id) self.assertEqual(assignment.user_wallet_address, user_address) @@ -334,6 +336,7 @@ def test_create_assignment_no_available_jobs_completed_assignment(self): cvat_project, _, cvat_job_1 = create_project_task_and_job( self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 ) + cvat_job_1.status = JobStatuses.completed.value user_address1 = "0x86e83d346041E8806e352681f3F14549C0d2BC69" user1 = User( @@ -417,3 +420,46 @@ def test_create_assignment_no_available_jobs_active_foreign_assignment(self): assignment_id = create_assignment(cvat_project.id, user_address2) self.assertEqual(assignment_id, None) + + def test_create_assignment_in_validated_and_rejected_job(self): + cvat_project_1, _, cvat_job_1 = create_project_task_and_job( + 
self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 + ) + cvat_job_1.status = JobStatuses.new.value # validated and rejected return to 'new' + + user_address = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user = User( + wallet_address=user_address, + cvat_email="test@hmt.ai", + cvat_id=1, + ) + self.session.add(user) + + now = datetime.now() + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=user_address, + cvat_job_id=cvat_job_1.cvat_id, + created_at=now - timedelta(hours=1), + completed_at=now - timedelta(minutes=40), + expires_at=datetime.now() + timedelta(days=1), + status=AssignmentStatuses.completed.value, + ) + self.session.add(assignment) + + self.session.commit() + + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + assignment_id = create_assignment(cvat_project_1.id, user_address) + + assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() + + self.assertEqual(assignment.cvat_job_id, cvat_job_1.cvat_id) + self.assertEqual(assignment.user_wallet_address, user_address) + self.assertEqual(assignment.status, AssignmentStatuses.created) diff --git a/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py b/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py index 2a332a1780..79103a0064 100644 --- a/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py +++ b/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py @@ -12,7 +12,7 @@ def create_project( cvat_id: int, *, status: ProjectStatuses = ProjectStatuses.annotation, -) -> tuple: +) -> Project: cvat_project = Project( id=str(uuid.uuid4()), cvat_id=cvat_id, @@ -28,13 +28,15 @@ def create_project( return cvat_project -def create_project_and_task(session: Session, escrow_address: str, cvat_id: int) -> tuple: +def 
create_project_and_task( + session: Session, escrow_address: str, cvat_id: int +) -> tuple[Project, Task]: cvat_project = create_project(session, escrow_address, cvat_id) cvat_task = create_task(session, cvat_project_id=cvat_project.cvat_id, cvat_id=cvat_id) return cvat_project, cvat_task -def create_task(session: Session, cvat_id: int, cvat_project_id: str) -> tuple: +def create_task(session: Session, cvat_id: int, cvat_project_id: str) -> Task: cvat_task = Task( id=str(uuid.uuid4()), cvat_id=cvat_id, @@ -46,7 +48,9 @@ def create_task(session: Session, cvat_id: int, cvat_project_id: str) -> tuple: return cvat_task -def create_project_task_and_job(session: Session, escrow_address: str, cvat_id: int) -> tuple: +def create_project_task_and_job( + session: Session, escrow_address: str, cvat_id: int +) -> tuple[Project, Task, Job]: cvat_project, cvat_task = create_project_and_task(session, escrow_address, cvat_id) cvat_job = create_job( session, @@ -57,7 +61,7 @@ def create_project_task_and_job(session: Session, escrow_address: str, cvat_id: return cvat_project, cvat_task, cvat_job -def create_job(session: Session, cvat_id: int, cvat_task_id: int, cvat_project_id: int) -> tuple: +def create_job(session: Session, cvat_id: int, cvat_task_id: int, cvat_project_id: int) -> Job: cvat_job = Job( id=str(uuid.uuid4()), cvat_id=cvat_id, From 3df142a6c66d2e11fbb7d64df877ce733e363842 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 20 May 2024 19:32:52 +0300 Subject: [PATCH 18/18] Add threading for exchange oracle blocking requests --- .../cvat/exchange-oracle/src/.env.template | 6 +++ .../cvat/exchange-oracle/src/__init__.py | 15 +++++++ .../cvat/exchange-oracle/src/core/config.py | 15 +++++++ .../cvat/exchange-oracle/src/db/__init__.py | 2 + .../exchange-oracle/src/endpoints/exchange.py | 42 ++++++++++++++----- .../exchange-oracle/src/services/exchange.py | 20 +++++++-- .../exchange-oracle/src/utils/concurrency.py | 36 ++++++++++++++++ 7 files changed, 121 
insertions(+), 15 deletions(-) create mode 100644 packages/examples/cvat/exchange-oracle/src/utils/concurrency.py diff --git a/packages/examples/cvat/exchange-oracle/src/.env.template b/packages/examples/cvat/exchange-oracle/src/.env.template index d654e5624e..82b7f302b1 100644 --- a/packages/examples/cvat/exchange-oracle/src/.env.template +++ b/packages/examples/cvat/exchange-oracle/src/.env.template @@ -5,6 +5,12 @@ ENVIRONMENT= WORKERS_AMOUNT= WEBHOOK_MAX_RETRIES= WEBHOOK_DELAY_IF_FAILED= +MAX_WORKER_THREADS= + +# DB + +MAX_DB_CONNECTIONS= +DB_CONNECTION_RECYCLE_TIMEOUT= # Postgres_config diff --git a/packages/examples/cvat/exchange-oracle/src/__init__.py b/packages/examples/cvat/exchange-oracle/src/__init__.py index db507cee5e..0be62821e6 100644 --- a/packages/examples/cvat/exchange-oracle/src/__init__.py +++ b/packages/examples/cvat/exchange-oracle/src/__init__.py @@ -7,6 +7,7 @@ from src.endpoints import init_api from src.handlers.error_handlers import setup_error_handlers from src.log import setup_logging +from src.utils.concurrency import fastapi_set_max_threads setup_logging() @@ -31,6 +32,20 @@ async def startup_event(): logger = logging.getLogger("app") logger.info("Exchange Oracle is up and running!") + if Config.features.db_connection_limit < Config.features.thread_limit: + logger.warn( + "The DB connection limit {} is less than maximum number of working threads {}. " + "This configuration can cause runtime errors on long blocking DB calls. 
" + "Consider changing values of the {} and {} environment variables.".format( + Config.features.db_connection_limit, + Config.features.thread_limit, + Config.features.DB_CONNECTION_LIMIT_ENV_VAR, + Config.features.THREAD_LIMIT_ENV_VAR, + ) + ) + + await fastapi_set_max_threads(Config.features.thread_limit) + is_test = Config.environment == "test" if not is_test: diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index 8e96cbb06a..6976dcb7e5 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -153,6 +153,9 @@ def bucket_url(cls): class FeaturesConfig: + THREAD_LIMIT_ENV_VAR = "MAX_WORKER_THREADS" + DB_CONNECTION_LIMIT_ENV_VAR = "MAX_DB_CONNECTIONS" + enable_custom_cloud_host = to_bool(os.environ.get("ENABLE_CUSTOM_CLOUD_HOST", "no")) "Allows using a custom host in manifest bucket urls" @@ -165,6 +168,18 @@ class FeaturesConfig: profiling_enabled = to_bool(os.getenv("PROFILING_ENABLED", False)) "Allow to profile specific requests" + thread_limit = int(os.getenv(THREAD_LIMIT_ENV_VAR, 5)) + "Maximum number of threads for blocking requests" + + db_connection_limit = int(os.getenv(DB_CONNECTION_LIMIT_ENV_VAR, 15)) + """ + Maximum number of active parallel DB connections. 
+ The recommended value is >= thread_limit + cron jobs count + """ + + db_connection_recycle_timeout = int(os.getenv("DB_CONNECTION_RECYCLE_TIMEOUT", 600)) + "DB connection lifetime after the last action on the connection, in seconds" + class CoreConfig: default_assignment_time = int(os.environ.get("DEFAULT_ASSIGNMENT_TIME", 1800)) diff --git a/packages/examples/cvat/exchange-oracle/src/db/__init__.py b/packages/examples/cvat/exchange-oracle/src/db/__init__.py index 6e9c85cded..e79f7d0d9e 100644 --- a/packages/examples/cvat/exchange-oracle/src/db/__init__.py +++ b/packages/examples/cvat/exchange-oracle/src/db/__init__.py @@ -9,6 +9,8 @@ DATABASE_URL, echo="debug" if Config.loglevel <= src.utils.logging.TRACE else False, connect_args={"options": "-c lock_timeout={:d}".format(Config.postgres_config.lock_timeout)}, + pool_size=Config.features.db_connection_limit, + pool_recycle=Config.features.db_connection_recycle_timeout, ) SessionLocal = sessionmaker(autocommit=False, bind=engine) diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py b/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py index fabeb3543f..c8402d6f6d 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py @@ -8,18 +8,22 @@ import src.services.cvat as cvat_service import src.services.exchange as oracle_service from src.db import SessionLocal +from src.db import errors as db_errors from src.schemas.exchange import AssignmentRequest, TaskResponse, UserRequest, UserResponse +from src.utils.concurrency import run_as_sync from src.validators.signature import validate_human_app_signature router = APIRouter() @router.get("/tasks", description="Lists available tasks") -async def list_tasks( +def list_tasks( wallet_address: Optional[str] = Query(default=None), signature: str = Header(description="Calling service signature"), ) -> list[TaskResponse]: - await 
validate_human_app_signature(signature) + # Declare this endpoint as sync as it uses a lot of blocking IO (db, manifest, escrow) + + run_as_sync(validate_human_app_signature, signature) if not wallet_address: return oracle_service.get_available_tasks() @@ -28,11 +32,13 @@ async def list_tasks( @router.put("/register", description="Binds a CVAT user a to HUMAN App user") -async def register( +def register( user: UserRequest, signature: str = Header(description="Calling service signature"), ) -> UserResponse: - await validate_human_app_signature(signature) + # Declare this endpoint as sync as it uses a lot of blocking IO (db, manifest, escrow, CVAT) + + run_as_sync(validate_human_app_signature, signature) with SessionLocal.begin() as session: email_db_user = cvat_service.get_user_by_email(session, user.cvat_email, for_update=True) @@ -97,19 +103,33 @@ async def register( "/tasks/{id}/assignment", description="Start an assignment within the task for the annotator", ) -async def create_assignment( +def create_assignment( data: AssignmentRequest, project_id: str = Path(alias="id"), signature: str = Header(description="Calling service signature"), ) -> TaskResponse: - await validate_human_app_signature(signature) + # Declare this endpoint as sync as it uses a lot of blocking IO (db, manifest, escrow, CVAT) + + run_as_sync(validate_human_app_signature, signature) - try: - assignment_id = oracle_service.create_assignment( - project_id=project_id, wallet_address=data.wallet_address + attempt = 0 + max_attempts = 10 + while attempt < max_attempts: + try: + assignment_id = oracle_service.create_assignment( + project_id=project_id, wallet_address=data.wallet_address + ) + break + except oracle_service.UserHasUnfinishedAssignmentError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) from e + except db_errors.LockNotAvailable: + attempt += 1 + + if attempt >= max_attempts: + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + 
detail="Too many requests at the moment, please try again later", ) - except oracle_service.UserHasUnfinishedAssignmentError as e: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) from e if not assignment_id: raise HTTPException( diff --git a/packages/examples/cvat/exchange-oracle/src/services/exchange.py b/packages/examples/cvat/exchange-oracle/src/services/exchange.py index b651180654..d342f3e555 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/services/exchange.py @@ -155,10 +155,22 @@ def create_assignment(project_id: int, wallet_address: str) -> Optional[str]: + timedelta(seconds=get_default_assignment_timeout(TaskTypes(project.job_type))), ) - with cvat_api.api_client_context(cvat_api.get_api_client()): - cvat_api.clear_job_annotations(unassigned_job.cvat_id) - cvat_api.restart_job(unassigned_job.cvat_id, assignee_id=user.cvat_id) + # Need to save the values to use outside the transaction + unassigned_job_cvat_id = unassigned_job.cvat_id + user_cvat_id = user.cvat_id + + # Finish the transaction ASAP to release the locks acquired and unblock other clients. - # rollback is automatic within the transaction + # It's possible that the following part is never completed. In this case the assignment + # will expire as usual after the assignment lifetime, even if not canceled here. 
+ try: + with cvat_api.api_client_context(cvat_api.get_api_client()): + cvat_api.clear_job_annotations(unassigned_job_cvat_id) + cvat_api.restart_job(unassigned_job_cvat_id, assignee_id=user_cvat_id) + except Exception: + with SessionLocal.begin() as session: + cvat_service.update_assignment( + session, assignment_id, status=AssignmentStatuses.canceled + ) return assignment_id diff --git a/packages/examples/cvat/exchange-oracle/src/utils/concurrency.py b/packages/examples/cvat/exchange-oracle/src/utils/concurrency.py new file mode 100644 index 0000000000..4d2676c7d7 --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/src/utils/concurrency.py @@ -0,0 +1,36 @@ +from functools import partial + +from anyio import from_thread, to_thread + + +def _check_backend(): + import fastapi.concurrency + + assert hasattr(fastapi.concurrency, "anyio") + + +async def fastapi_set_max_threads(max_threads: int): + """ + Sets the maximum number of active threads in the sync worker pool of FastAPI. + This affects the maximum number of active blocking requests + (the endpoints defined as non-async def ...) in each process. + + """ + _check_backend() + + # https://anyio.readthedocs.io/en/stable/threads.html#adjusting-the-default-maximum-worker-thread-count + to_thread.current_default_thread_limiter().total_tokens = max_threads + + +def run_as_sync(async_fn, *args, **kwargs): + """ + Runs an async function synchronously. + Supposed to be called in blocking endpoints (defined as def ...) + """ + _check_backend() + + if args or kwargs: + async_fn = partial(async_fn, *args, **kwargs) + + with from_thread.start_blocking_portal() as portal: + return portal.call(async_fn)