diff --git a/packages/examples/cvat/exchange-oracle/src/.env.template b/packages/examples/cvat/exchange-oracle/src/.env.template index 510d41284a..82b7f302b1 100644 --- a/packages/examples/cvat/exchange-oracle/src/.env.template +++ b/packages/examples/cvat/exchange-oracle/src/.env.template @@ -5,6 +5,12 @@ ENVIRONMENT= WORKERS_AMOUNT= WEBHOOK_MAX_RETRIES= WEBHOOK_DELAY_IF_FAILED= +MAX_WORKER_THREADS= + +# DB + +MAX_DB_CONNECTIONS= +DB_CONNECTION_RECYCLE_TIMEOUT= # Postgres_config @@ -35,14 +41,17 @@ PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE= TRACK_COMPLETED_PROJECTS_INT= TRACK_COMPLETED_PROJECTS_CHUNK_SIZE= TRACK_COMPLETED_TASKS_INT= +TRACK_COMPLETED_TASKS_CHUNK_SIZE= TRACK_COMPLETED_ESCROWS_INT= TRACK_COMPLETED_ESCROWS_CHUNK_SIZE= -PROCESS_JOB_LAUNCHER_WEBHOOKS_INT= TRACK_CREATING_TASKS_INT= +TRACK_CREATING_TASKS_CHUNK_SIZE= +TRACK_ASSIGNMENTS_INT= +TRACK_ASSIGNMENTS_CHUNK_SIZE= REJECTED_PROJECTS_CHUNK_SIZE= ACCEPTED_PROJECTS_CHUNK_SIZE= -TRACK_ESCROW_CREATION_CHUNK_SIZE= TRACK_ESCROW_CREATION_INT= +TRACK_ESCROW_CREATION_CHUNK_SIZE= TRACK_COMPLETED_ESCROWS_MAX_DOWNLOADING_RETRIES= TRACK_COMPLETED_ESCROWS_JOBS_DOWNLOADING_BATCH_SIZE= diff --git a/packages/examples/cvat/exchange-oracle/src/__init__.py b/packages/examples/cvat/exchange-oracle/src/__init__.py index db507cee5e..0be62821e6 100644 --- a/packages/examples/cvat/exchange-oracle/src/__init__.py +++ b/packages/examples/cvat/exchange-oracle/src/__init__.py @@ -7,6 +7,7 @@ from src.endpoints import init_api from src.handlers.error_handlers import setup_error_handlers from src.log import setup_logging +from src.utils.concurrency import fastapi_set_max_threads setup_logging() @@ -31,6 +32,20 @@ async def startup_event(): logger = logging.getLogger("app") logger.info("Exchange Oracle is up and running!") + if Config.features.db_connection_limit < Config.features.thread_limit: + logger.warn( + "The DB connection limit {} is less than maximum number of working threads {}. " + "This configuration can cause runtime errors on long blocking DB calls. " + "Consider changing values of the {} and {} environment variables.".format( + Config.features.db_connection_limit, + Config.features.thread_limit, + Config.features.DB_CONNECTION_LIMIT_ENV_VAR, + Config.features.THREAD_LIMIT_ENV_VAR, + ) + ) + + await fastapi_set_max_threads(Config.features.thread_limit) + is_test = Config.environment == "test" if not is_test: diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index 524279aace..6976dcb7e5 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -68,8 +68,9 @@ class CronConfig: track_completed_projects_int = int(os.environ.get("TRACK_COMPLETED_PROJECTS_INT", 30)) track_completed_projects_chunk_size = os.environ.get("TRACK_COMPLETED_PROJECTS_CHUNK_SIZE", 5) track_completed_tasks_int = int(os.environ.get("TRACK_COMPLETED_TASKS_INT", 30)) - track_creating_tasks_chunk_size = os.environ.get("TRACK_CREATING_TASKS_CHUNK_SIZE", 5) + track_completed_tasks_chunk_size = os.environ.get("TRACK_COMPLETED_TASKS_CHUNK_SIZE", 20) track_creating_tasks_int = int(os.environ.get("TRACK_CREATING_TASKS_INT", 300)) + track_creating_tasks_chunk_size = os.environ.get("TRACK_CREATING_TASKS_CHUNK_SIZE", 5) track_assignments_int = int(os.environ.get("TRACK_ASSIGNMENTS_INT", 5)) track_assignments_chunk_size = os.environ.get("TRACK_ASSIGNMENTS_CHUNK_SIZE", 10) @@ -152,6 +153,9 @@ def bucket_url(cls): class FeaturesConfig: + THREAD_LIMIT_ENV_VAR = "MAX_WORKER_THREADS" + DB_CONNECTION_LIMIT_ENV_VAR = "MAX_DB_CONNECTIONS" + enable_custom_cloud_host = to_bool(os.environ.get("ENABLE_CUSTOM_CLOUD_HOST", "no")) "Allows using a custom host in manifest bucket urls" @@ -164,6 +168,18 @@ class FeaturesConfig: profiling_enabled = to_bool(os.getenv("PROFILING_ENABLED", False)) "Allow to profile specific requests" + thread_limit = int(os.getenv(THREAD_LIMIT_ENV_VAR, 5)) + "Maximum number of threads for blocking requests" + + db_connection_limit = int(os.getenv(DB_CONNECTION_LIMIT_ENV_VAR, 15)) + """ + Maximum number of active parallel DB connections. + The recommended value is >= thread_limit + cron jobs count + """ + + db_connection_recycle_timeout = int(os.getenv("DB_CONNECTION_RECYCLE_TIMEOUT", 600)) + "DB connection lifetime after the last action on the connection, in seconds" + class CoreConfig: default_assignment_time = int(os.environ.get("DEFAULT_ASSIGNMENT_TIME", 1800)) diff --git a/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py b/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py index fed32a2840..8996173e52 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py @@ -34,6 +34,7 @@ def track_completed_projects() -> None: projects = cvat_service.get_projects_by_status( session, ProjectStatuses.annotation, + task_status=TaskStatuses.completed, limit=CronConfig.track_completed_projects_chunk_size, for_update=ForUpdateParams(skip_locked=True), ) @@ -74,7 +75,12 @@ def track_completed_tasks() -> None: logger.debug("Starting cron job") with SessionLocal.begin() as session: tasks = cvat_service.get_tasks_by_status( - session, TaskStatuses.annotation, for_update=ForUpdateParams(skip_locked=True) + session, + TaskStatuses.annotation, + job_status=JobStatuses.completed, + project_status=ProjectStatuses.annotation, + limit=CronConfig.track_completed_tasks_chunk_size, + for_update=ForUpdateParams(skip_locked=True), ) completed_task_ids = [] diff --git a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py index f022e72d5c..3b5a7b4a8a 100644 --- a/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py +++ b/packages/examples/cvat/exchange-oracle/src/cvat/api_calls.py @@ -2,12 +2,14 @@ import json import logging import zipfile +from contextlib import contextmanager +from contextvars import ContextVar from datetime import timedelta from enum import Enum from http import HTTPStatus from io import BytesIO from time import sleep -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Generator, List, Optional, Tuple from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models from cvat_sdk.api_client.api_client import Endpoint @@ -90,7 +92,23 @@ def _get_annotations( return file_buffer +_api_client_context: ContextVar[ApiClient] = ContextVar("api_client", default=None) + + +@contextmanager +def api_client_context(api_client: ApiClient) -> Generator[ApiClient, None, None]: + old = _api_client_context.set(api_client) + try: + yield api_client + finally: + _api_client_context.reset(old) + + def get_api_client() -> ApiClient: + current_api_client = _api_client_context.get() + if current_api_client: + return current_api_client + configuration = Configuration( host=Config.cvat_config.cvat_url, username=Config.cvat_config.cvat_admin, @@ -559,7 +577,7 @@ def update_job_assignee(id: str, assignee_id: Optional[int]): raise -def restart_job(id: str): +def restart_job(id: str, *, assignee_id: Optional[int] = None): logger = logging.getLogger("app") with get_api_client() as api_client: @@ -567,7 +585,7 @@ def restart_job(id: str): api_client.jobs_api.partial_update( id=id, patched_job_write_request=models.PatchedJobWriteRequest( - stage="annotation", state="new" + stage="annotation", state="new", assignee=assignee_id ), ) except exceptions.ApiException as e: diff --git a/packages/examples/cvat/exchange-oracle/src/db/__init__.py b/packages/examples/cvat/exchange-oracle/src/db/__init__.py index 6e9c85cded..e79f7d0d9e 100644 --- a/packages/examples/cvat/exchange-oracle/src/db/__init__.py +++ b/packages/examples/cvat/exchange-oracle/src/db/__init__.py @@ -9,6 +9,8 @@ DATABASE_URL, echo="debug" if Config.loglevel <= src.utils.logging.TRACE else False, connect_args={"options": "-c lock_timeout={:d}".format(Config.postgres_config.lock_timeout)}, + pool_size=Config.features.db_connection_limit, + pool_recycle=Config.features.db_connection_recycle_timeout, ) SessionLocal = sessionmaker(autocommit=False, bind=engine) diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py b/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py index fabeb3543f..c8402d6f6d 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/exchange.py @@ -8,18 +8,22 @@ import src.services.cvat as cvat_service import src.services.exchange as oracle_service from src.db import SessionLocal +from src.db import errors as db_errors from src.schemas.exchange import AssignmentRequest, TaskResponse, UserRequest, UserResponse +from src.utils.concurrency import run_as_sync from src.validators.signature import validate_human_app_signature router = APIRouter() @router.get("/tasks", description="Lists available tasks") -async def list_tasks( +def list_tasks( wallet_address: Optional[str] = Query(default=None), signature: str = Header(description="Calling service signature"), ) -> list[TaskResponse]: - await validate_human_app_signature(signature) + # Declare this endpoint as sync as it uses a lot of blocking IO (db, manifest, escrow) + + run_as_sync(validate_human_app_signature, signature) if not wallet_address: return oracle_service.get_available_tasks() @@ -28,11 +32,13 @@ async def list_tasks( @router.put("/register", description="Binds a CVAT user a to HUMAN App user") -async def register( +def register( user: UserRequest, signature: str = Header(description="Calling service signature"), ) -> UserResponse: - await validate_human_app_signature(signature) + # Declare this endpoint as sync as it uses a lot of blocking IO (db, manifest, escrow, CVAT) + + run_as_sync(validate_human_app_signature, signature) with SessionLocal.begin() as session: email_db_user = cvat_service.get_user_by_email(session, user.cvat_email, for_update=True) @@ -97,19 +103,33 @@ async def register( "/tasks/{id}/assignment", description="Start an assignment within the task for the annotator", ) -async def create_assignment( +def create_assignment( data: AssignmentRequest, project_id: str = Path(alias="id"), signature: str = Header(description="Calling service signature"), ) -> TaskResponse: - await validate_human_app_signature(signature) + # Declare this endpoint as sync as it uses a lot of blocking IO (db, manifest, escrow, CVAT) + + run_as_sync(validate_human_app_signature, signature) - try: - assignment_id = oracle_service.create_assignment( - project_id=project_id, wallet_address=data.wallet_address + attempt = 0 + max_attempts = 10 + while attempt < max_attempts: + try: + assignment_id = oracle_service.create_assignment( + project_id=project_id, wallet_address=data.wallet_address + ) + break + except oracle_service.UserHasUnfinishedAssignmentError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) from e + except db_errors.LockNotAvailable: + attempt += 1 + + if attempt >= max_attempts: + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + detail="Too many requests at the moment, please try again later", ) - except oracle_service.UserHasUnfinishedAssignmentError as e: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) from e if not assignment_id: raise HTTPException( diff --git a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py index 1c4948d903..a346eddcc6 100644 --- a/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py +++ b/packages/examples/cvat/exchange-oracle/src/endpoints/middleware.py @@ -2,6 +2,8 @@ import time from typing import Any, Callable +import fastapi +import packaging.version as pv from fastapi import FastAPI, Request, Response from fastapi.responses import HTMLResponse, StreamingResponse from pyinstrument import Profiler @@ -58,14 +60,32 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): """ + @staticmethod + async def _set_body(request: Request, body: bytes): + # Before FastAPI 0.108.0 infinite hang is expected, + # if request body is awaited more than once. + # It's not needed when using FastAPI >= 0.108.0. + # https://github.com/tiangolo/fastapi/discussions/8187#discussioncomment-7962889 + if pv.parse(fastapi.__version__) >= pv.Version("0.108.0"): + return + + async def receive(): + return {"type": "http.request", "body": body} + + request._receive = receive + def __init__(self, app: FastAPI) -> None: super().__init__(app) self.logger = get_root_logger() + self.max_displayed_body_size = 200 + async def dispatch(self, request: Request, call_next: Callable) -> Response: logging_dict: dict[str, Any] = {} - await request.body() + body = await request.body() + await self._set_body(request, body) + response, response_dict = await self._log_response(call_next, request) request_dict = await self._log_request(request) logging_dict["request"] = request_dict @@ -97,10 +117,26 @@ async def _log_request(self, request: Request) -> dict[str, Any]: } try: - body = await request.json() + body = await request.body() + await self._set_body(request, body) except Exception: body = None else: + if body is not None: + raw_body = False + + if len(body) < self.max_displayed_body_size: + try: + body = json.loads(body) + except (json.JSONDecodeError, TypeError): + raw_body = True + else: + raw_body = True + + if raw_body: + body = body.decode(errors="ignore") + body = body[: self.max_displayed_body_size] + request_logging["body"] = body return request_logging diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py b/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py index be73b0b800..d78f88fae4 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/completed_escrows.py @@ -266,7 +266,7 @@ def _process_skeletons_from_boxes_escrows(self): except Exception as e: logger.error( "Failed to handle completed projects for escrow {}: {}".format( - escrow_address, e + completed_project.escrow_address, e ) ) continue diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py b/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py index b8d1c424b1..86d410aabb 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/error_handlers.py @@ -39,7 +39,9 @@ async def http_exception_handler(_, exc): @app.exception_handler(Exception) async def generic_exception_handler(_, exc: Exception): message = ( - "Something went wrong" if Config.environment != "development" else ".".join(exc.args) + "Something went wrong" + if Config.environment != "development" + else ".".join(map(str, exc.args)) ) return JSONResponse(content={"message": message}, status_code=500) diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index 58fc471ef4..f4d3b986fc 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -127,6 +127,7 @@ def get_projects_by_status( status: ProjectStatuses, *, included_types: Optional[Sequence[TaskTypes]] = None, + task_status: Optional[TaskStatuses] = None, limit: int = 5, for_update: Union[bool, ForUpdateParams] = False, ) -> List[Project]: @@ -134,6 +135,9 @@ def get_projects_by_status( Project.status == status.value ) + if task_status: + projects = projects.where(Project.tasks.any(Task.status == task_status.value)) + if included_types is not None: projects = projects.where(Project.job_type.in_([t.value for t in included_types])) @@ -142,11 +146,9 @@ def get_projects_by_status( return projects -def get_available_projects( - session: Session, *, limit: int = 10, for_update: Union[bool, ForUpdateParams] = False -) -> List[Project]: +def get_available_projects(session: Session, *, limit: int = 10) -> List[Project]: return ( - _maybe_for_update(session.query(Project), enable=for_update) + session.query(Project) .where( (Project.status == ProjectStatuses.annotation.value) & Project.jobs.any( @@ -343,14 +345,29 @@ def get_tasks_by_cvat_id( def get_tasks_by_status( - session: Session, status: TaskStatuses, *, for_update: Union[bool, ForUpdateParams] = False + session: Session, + status: TaskStatuses, + *, + job_status: Optional[JobStatuses] = None, + project_status: Optional[ProjectStatuses] = None, + for_update: Union[bool, ForUpdateParams] = False, + limit: Optional[int] = 20, ) -> List[Task]: - return ( - _maybe_for_update(session.query(Task), enable=for_update) - .where(Task.status == status.value) - .all() + query = _maybe_for_update(session.query(Task), enable=for_update).where( + Task.status == status.value ) + if job_status: + query = query.where(Task.jobs.any(Job.status == job_status.value)) + + if project_status: + query = query.where(Task.project.has(Project.status == project_status.value)) + + if limit: + query = query.limit(limit) + + return query.all() + def update_task_status(session: Session, task_id: int, status: TaskStatuses) -> None: upd = update(Task).where(Task.id == task_id).values(status=status.value) @@ -404,6 +421,8 @@ def finish_data_uploads(session: Session, uploads: list[DataUpload]) -> None: # Job + + def create_job( session: Session, cvat_id: int, @@ -484,6 +503,27 @@ def count_jobs_by_escrow_address( ) +def get_free_job( + session: Session, + cvat_projects: List[int], + *, + for_update: Union[bool, ForUpdateParams] = False, +) -> Optional[Job]: + return ( + _maybe_for_update(session.query(Job), enable=for_update) + .where( + Job.cvat_project_id.in_(cvat_projects), + Job.status == JobStatuses.new, + ~Job.assignments.any( + (Assignment.status == AssignmentStatuses.created.value) + & (Assignment.completed_at == None) + & (utcnow() < Assignment.expires_at) + ), + ) + .first() + ) + + # Users @@ -644,6 +684,24 @@ def get_user_assignments_in_cvat_projects( ) +def count_active_user_assignments( + session: Session, + wallet_address: int, + cvat_projects: List[int], +) -> int: + return ( + session.query(Assignment) + .where( + Assignment.job.has(Job.cvat_project_id.in_(cvat_projects)), + Assignment.user_wallet_address == wallet_address, + Assignment.status == AssignmentStatuses.created.value, + Assignment.completed_at == None, + utcnow() < Assignment.expires_at, + ) + .count() + ) + + # Image def add_project_images(session: Session, cvat_project_id: int, filenames: List[str]) -> None: session.execute( diff --git a/packages/examples/cvat/exchange-oracle/src/services/exchange.py b/packages/examples/cvat/exchange-oracle/src/services/exchange.py index 9795cb5db9..d342f3e555 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/exchange.py +++ b/packages/examples/cvat/exchange-oracle/src/services/exchange.py @@ -2,10 +2,9 @@ from typing import Optional import src.cvat.api_calls as cvat_api -import src.models.cvat as models import src.services.cvat as cvat_service from src.chain.escrow import get_escrow_manifest -from src.core.types import AssignmentStatuses, JobStatuses, PlatformTypes, ProjectStatuses +from src.core.types import AssignmentStatuses, PlatformTypes, ProjectStatuses, TaskTypes from src.db import SessionLocal from src.schemas import exchange as service_api from src.utils.assignments import ( @@ -48,8 +47,7 @@ def serialize_task( title=f"Task {project.escrow_address[:10]}", description=manifest.annotation.description, job_bounty=manifest.job_bounty, - job_time_limit=manifest.annotation.max_time - or get_default_assignment_timeout(manifest.annotation.type), + job_time_limit=get_default_assignment_timeout(manifest.annotation.type), job_size=get_default_assignment_size(manifest), job_type=project.job_type, platform=PlatformTypes.CVAT, @@ -132,33 +130,20 @@ def create_assignment(project_id: int, wallet_address: str) -> Optional[str]: ) return None - manifest = parse_manifest(get_escrow_manifest(project.chain_id, project.escrow_address)) - - unassigned_job: Optional[models.Job] = None - unfinished_assignments: list[models.Assignment] = [] - for job in project.jobs: - job_assignment = job.latest_assignment - if job_assignment and not job_assignment.is_finished: - unfinished_assignments.append(job_assignment) - - if ( - not unassigned_job - and job.status == JobStatuses.new - and (not job_assignment or job_assignment.is_finished) - ): - unassigned_job = job - - now = utcnow() - unfinished_user_assignments = [ - assignment - for assignment in unfinished_assignments - if assignment.user_wallet_address == wallet_address and now < assignment.expires_at - ] - if unfinished_user_assignments: + has_active_assignments = ( + cvat_service.count_active_user_assignments( + session, wallet_address=wallet_address, cvat_projects=[project.cvat_id] + ) + > 0 + ) + if has_active_assignments: raise UserHasUnfinishedAssignmentError( "The user already has an unfinished assignment in this project" ) + unassigned_job = cvat_service.get_free_job( + session, cvat_projects=[project.cvat_id], for_update=True + ) if not unassigned_job: return None @@ -166,16 +151,26 @@ def create_assignment(project_id: int, wallet_address: str) -> Optional[str]: session, wallet_address=user.wallet_address, cvat_job_id=unassigned_job.cvat_id, - expires_at=now - + timedelta( - seconds=manifest.annotation.max_time - or get_default_assignment_timeout(manifest.annotation.type) - ), + expires_at=utcnow() + + timedelta(seconds=get_default_assignment_timeout(TaskTypes(project.job_type))), ) - cvat_api.clear_job_annotations(unassigned_job.cvat_id) - cvat_api.restart_job(unassigned_job.cvat_id) - cvat_api.update_job_assignee(unassigned_job.cvat_id, assignee_id=user.cvat_id) - # rollback is automatic within the transaction + # Need to save the values to use outside the transaction + unassigned_job_cvat_id = unassigned_job.cvat_id + user_cvat_id = user.cvat_id + + # Finish the transaction ASAP to release the locks acquired and unblock other clients. + + # It's possible that the following part is never completed. In this case the assignment + # will expire as usual after the assignment lifetime, even if not canceled here. + try: + with cvat_api.api_client_context(cvat_api.get_api_client()): + cvat_api.clear_job_annotations(unassigned_job_cvat_id) + cvat_api.restart_job(unassigned_job_cvat_id, assignee_id=user_cvat_id) + except Exception: + with SessionLocal.begin() as session: + cvat_service.update_assignment( + session, assignment_id, status=AssignmentStatuses.canceled + ) return assignment_id diff --git a/packages/examples/cvat/exchange-oracle/src/utils/concurrency.py b/packages/examples/cvat/exchange-oracle/src/utils/concurrency.py new file mode 100644 index 0000000000..4d2676c7d7 --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/src/utils/concurrency.py @@ -0,0 +1,36 @@ +from functools import partial + +from anyio import from_thread, to_thread + + +def _check_backend(): + import fastapi.concurrency + + assert hasattr(fastapi.concurrency, "anyio") + + +async def fastapi_set_max_threads(max_threads: int): + """ + Sets the maximum number of active threads in the sync worker pool of FastAPI. + This affects the maximum number of active blocking requests + (the endpoints defined as non-async def ...) in each process. + + """ + _check_backend() + + # https://anyio.readthedocs.io/en/stable/threads.html#adjusting-the-default-maximum-worker-thread-count + to_thread.current_default_thread_limiter().total_tokens = max_threads + + +def run_as_sync(async_fn, *args, **kwargs): + """ + Runs an async function synchronously. + Supposed to be called in blocking endpoints (defined as def ...) + """ + _check_backend() + + if args or kwargs: + async_fn = partial(async_fn, *args, **kwargs) + + with from_thread.start_blocking_portal() as portal: + return portal.call(async_fn) diff --git a/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py b/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py index 809e523c9e..9eba05e767 100644 --- a/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py +++ b/packages/examples/cvat/exchange-oracle/tests/api/test_exchange_api.py @@ -309,8 +309,7 @@ def test_create_assignment_200(client: TestClient) -> None: json={"wallet_address": user_address}, ) cvat_api.clear_job_annotations.assert_called_once() - cvat_api.restart_job.assert_called_once() - cvat_api.update_job_assignee.assert_called_once() + cvat_api.restart_job.assert_called_once_with(cvat_job_1.cvat_id, assignee_id=user.cvat_id) assert response.status_code == 200 db_assignment = session.query(Assignment).filter_by(user_wallet_address=user_address).first() diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py index b339ceef34..7741790418 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/services/test_exchange.py @@ -8,7 +8,7 @@ from pydantic import ValidationError import src.services.cvat as cvat_service -from src.core.types import AssignmentStatuses, PlatformTypes, ProjectStatuses +from src.core.types import AssignmentStatuses, JobStatuses, PlatformTypes, ProjectStatuses from src.db import SessionLocal from src.models.cvat import Assignment, User from src.schemas import exchange as service_api @@ -19,7 +19,12 @@ serialize_task, ) -from tests.utils.db_helper import create_project, create_project_task_and_job +from tests.utils.db_helper import ( + create_job, + create_project, + create_project_task_and_job, + create_task, +) class ServiceIntegrationTest(unittest.TestCase): @@ -219,14 +224,60 @@ def test_create_assignment(self): ): manifest = json.load(data) mock_get_manifest.return_value = manifest - assingment_id = create_assignment(cvat_project_1.id, user_address) + assignment_id = create_assignment(cvat_project_1.id, user_address) - assignment = self.session.query(Assignment).filter_by(id=assingment_id).first() + assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() self.assertEqual(assignment.cvat_job_id, cvat_job_1.cvat_id) self.assertEqual(assignment.user_wallet_address, user_address) self.assertEqual(assignment.status, AssignmentStatuses.created) + def test_create_assignment_many_jobs_1_completed(self): + cvat_project, _, cvat_job_1 = create_project_task_and_job( + self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 + ) + cvat_job_1.status = JobStatuses.completed.value + + cvat_task_2 = create_task(self.session, 2, cvat_project.cvat_id) + cvat_job_2 = create_job(self.session, 2, cvat_task_2.cvat_id, cvat_project.cvat_id) + + user_address = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user = User( + wallet_address=user_address, + cvat_email="test@hmt.ai", + cvat_id=1, + ) + self.session.add(user) + + now = datetime.now() + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=user_address, + cvat_job_id=cvat_job_1.cvat_id, + created_at=now - timedelta(hours=1), + completed_at=now - timedelta(minutes=40), + expires_at=datetime.now() + timedelta(days=1), + status=AssignmentStatuses.completed.value, + ) + self.session.add(assignment) + + self.session.commit() + + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + assignment_id = create_assignment(cvat_project.id, user_address) + + assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() + + self.assertEqual(assignment.cvat_job_id, cvat_job_2.cvat_id) + self.assertEqual(assignment.user_wallet_address, user_address) + self.assertEqual(assignment.status, AssignmentStatuses.created) + def test_create_assignment_invalid_user_address(self): cvat_project_1, _, _ = create_project_task_and_job( self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 @@ -280,3 +331,135 @@ def test_create_assignment_unfinished_assignment(self): with self.assertRaises(HTTPException): create_assignment("1", user_address) + + def test_create_assignment_no_available_jobs_completed_assignment(self): + cvat_project, _, cvat_job_1 = create_project_task_and_job( + self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 + ) + cvat_job_1.status = JobStatuses.completed.value + + user_address1 = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user1 = User( + wallet_address=user_address1, + cvat_email="test1@hmt.ai", + cvat_id=1, + ) + self.session.add(user1) + + user_address2 = "0x86e83d346041E8806e352681f3F14549C0d2BC70" + user2 = User( + wallet_address=user_address2, + cvat_email="test2@hmt.ai", + cvat_id=2, + ) + self.session.add(user2) + + now = datetime.now() + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=user_address1, + cvat_job_id=cvat_job_1.cvat_id, + created_at=now - timedelta(days=1), + completed_at=now - timedelta(hours=22), + expires_at=now + timedelta(hours=2), + status=AssignmentStatuses.completed.value, + ) + self.session.add(assignment) + + self.session.commit() + + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + assignment_id = create_assignment(cvat_project.id, user_address2) + + self.assertEqual(assignment_id, None) + + def test_create_assignment_no_available_jobs_active_foreign_assignment(self): + cvat_project, _, cvat_job_1 = create_project_task_and_job( + self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 + ) + + user_address1 = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user1 = User( + wallet_address=user_address1, + cvat_email="test1@hmt.ai", + cvat_id=1, + ) + self.session.add(user1) + + user_address2 = "0x86e83d346041E8806e352681f3F14549C0d2BC70" + user2 = User( + wallet_address=user_address2, + cvat_email="test2@hmt.ai", + cvat_id=2, + ) + self.session.add(user2) + + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=user_address1, + cvat_job_id=cvat_job_1.cvat_id, + expires_at=datetime.now() + timedelta(days=1), + ) + self.session.add(assignment) + + self.session.commit() + + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + assignment_id = create_assignment(cvat_project.id, user_address2) + + self.assertEqual(assignment_id, None) + + def test_create_assignment_in_validated_and_rejected_job(self): + cvat_project_1, _, cvat_job_1 = create_project_task_and_job( + self.session, "0x86e83d346041E8806e352681f3F14549C0d2BC67", 1 + ) + cvat_job_1.status = JobStatuses.new.value # validated and rejected return to 'new' + + user_address = "0x86e83d346041E8806e352681f3F14549C0d2BC69" + user = User( + wallet_address=user_address, + cvat_email="test@hmt.ai", + cvat_id=1, + ) + self.session.add(user) + + now = datetime.now() + assignment = Assignment( + id=str(uuid.uuid4()), + user_wallet_address=user_address, + cvat_job_id=cvat_job_1.cvat_id, + created_at=now - timedelta(hours=1), + completed_at=now - timedelta(minutes=40), + expires_at=datetime.now() + timedelta(days=1), + status=AssignmentStatuses.completed.value, + ) + self.session.add(assignment) + + self.session.commit() + + with ( + open("tests/utils/manifest.json") as data, + patch("src.services.exchange.get_escrow_manifest") as mock_get_manifest, + patch("src.services.exchange.cvat_api"), + ): + manifest = json.load(data) + mock_get_manifest.return_value = manifest + assignment_id = create_assignment(cvat_project_1.id, user_address) + + assignment = self.session.query(Assignment).filter_by(id=assignment_id).first() + + self.assertEqual(assignment.cvat_job_id, cvat_job_1.cvat_id) + self.assertEqual(assignment.user_wallet_address, user_address) + self.assertEqual(assignment.status, AssignmentStatuses.created) diff --git a/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py b/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py index 8a5a81e9cc..79103a0064 100644 --- a/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py +++ b/packages/examples/cvat/exchange-oracle/tests/utils/db_helper.py @@ -12,7 +12,7 @@ def create_project( cvat_id: int, *, status: ProjectStatuses = ProjectStatuses.annotation, -) -> tuple: +) -> Project: cvat_project = Project( id=str(uuid.uuid4()), cvat_id=cvat_id, @@ -28,13 +28,15 @@ def create_project( return cvat_project -def create_project_and_task(session: Session, escrow_address: str, cvat_id: int) -> tuple: +def create_project_and_task( + session: Session, escrow_address: str, cvat_id: int +) -> tuple[Project, Task]: cvat_project = create_project(session, escrow_address, cvat_id) cvat_task = create_task(session, cvat_project_id=cvat_project.cvat_id, cvat_id=cvat_id) return cvat_project, cvat_task -def create_task(session: Session, cvat_id: int, cvat_project_id: str) -> tuple: +def create_task(session: Session, cvat_id: int, cvat_project_id: str) -> Task: cvat_task = Task( id=str(uuid.uuid4()), cvat_id=cvat_id, @@ -46,7 +48,9 @@ def create_task(session: Session, cvat_id: int, cvat_project_id: str) -> tuple: return cvat_task -def create_project_task_and_job(session: Session, escrow_address: str, cvat_id: int) -> tuple: +def create_project_task_and_job( + session: Session, escrow_address: str, cvat_id: int +) -> tuple[Project, Task, Job]: cvat_project, cvat_task = create_project_and_task(session, escrow_address, cvat_id) cvat_job = create_job( session, @@ -57,12 +61,12 @@ def create_project_task_and_job(session: Session, escrow_address: str, cvat_id: return cvat_project, cvat_task, cvat_job -def create_job(session: Session, cvat_id: int, cvat_task_id: int, cvat_project_id: int) -> tuple: +def create_job(session: Session, cvat_id: int, cvat_task_id: int, cvat_project_id: int) -> Job: cvat_job = Job( id=str(uuid.uuid4()), cvat_id=cvat_id, - cvat_project_id=cvat_id, - cvat_task_id=cvat_id, + cvat_project_id=cvat_project_id, + cvat_task_id=cvat_task_id, status=JobStatuses.new, ) session.add(cvat_job)