diff --git a/src/tetra_rp/core/api/runpod.py b/src/tetra_rp/core/api/runpod.py
index af1754d0..2b1ea5e6 100644
--- a/src/tetra_rp/core/api/runpod.py
+++ b/src/tetra_rp/core/api/runpod.py
@@ -110,6 +110,7 @@ async def create_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
                     instanceIds
                     activeBuildid
                     idePodId
+                    templateId
                 }
             }
         """
@@ -132,6 +133,91 @@ async def create_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
 
         return endpoint_data
 
+    async def update_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+        mutation = """
+            mutation saveEndpoint($input: EndpointInput!) {
+                saveEndpoint(input: $input) {
+                    aiKey
+                    gpuIds
+                    id
+                    idleTimeout
+                    locations
+                    name
+                    networkVolumeId
+                    scalerType
+                    scalerValue
+                    templateId
+                    type
+                    userId
+                    version
+                    workersMax
+                    workersMin
+                    workersStandby
+                    workersPFBTarget
+                    gpuCount
+                    allowedCudaVersions
+                    executionTimeoutMs
+                    instanceIds
+                    activeBuildid
+                    idePodId
+                }
+            }
+        """
+
+        variables = {"input": input_data}
+
+        log.debug(
+            f"Updating endpoint with GraphQL: {input_data.get('name', 'unnamed')}"
+        )
+
+        result = await self._execute_graphql(mutation, variables)
+
+        if "saveEndpoint" not in result:
+            raise Exception("Unexpected GraphQL response structure")
+
+        endpoint_data = result["saveEndpoint"]
+        log.info(
+            f"Updated endpoint: {endpoint_data.get('id', 'unknown')} - {endpoint_data.get('name', 'unnamed')}"
+        )
+
+        return endpoint_data
+
+    async def update_template(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+        mutation = """
+            mutation saveTemplate($input: SaveTemplateInput) {
+                saveTemplate(input: $input) {
+                    id
+                    containerDiskInGb
+                    dockerArgs
+                    env {
+                        key
+                        value
+                    }
+                    imageName
+                    name
+                    readme
+                }
+            }
+        """
+
+        variables = {"input": input_data}
+
+        log.debug(
+            f"Updating template with GraphQL: {input_data.get('name', 'unnamed')}"
+        )
+
+        result = await self._execute_graphql(mutation, variables)
+
+        if "saveTemplate" not in result:
+            raise Exception("Unexpected GraphQL response structure")
+
+        template_data = result["saveTemplate"]
+        log.info(
+            f"Updated template: {template_data.get('id', 'unknown')} - {template_data.get('name', 'unnamed')}"
+        )
+
+        return template_data
+
     async def get_cpu_types(self) -> Dict[str, Any]:
         """Get available CPU types."""
         query = """
diff --git a/src/tetra_rp/core/resources/base.py b/src/tetra_rp/core/resources/base.py
index 2e2e28c6..c5fe36c0 100644
--- a/src/tetra_rp/core/resources/base.py
+++ b/src/tetra_rp/core/resources/base.py
@@ -1,8 +1,7 @@
 import hashlib
 from abc import ABC, abstractmethod
-from typing import Optional
-from pydantic import BaseModel, ConfigDict
-
+from typing import Optional, ClassVar
+from pydantic import BaseModel, ConfigDict, computed_field
 
 class BaseResource(BaseModel):
     """Base class for all resources."""
@@ -14,15 +13,35 @@ class BaseResource(BaseModel):
     )
 
     id: Optional[str] = None
+    _hashed_fields: ClassVar[set] = set()
+
+    # diffed fields is a temporary holder for fields that are "out of sync" -
+    # where a local instance representation of an endpoint is not up to date with the remote resource.
+    # it's needed for determining how updates are applied (eg, if we need to update a pod template)
+    fields_to_update: set[str] = set()
+
+    @computed_field
     @property
-    def resource_id(self) -> str:
+    def resource_hash(self) -> str:
         """Unique resource ID based on configuration."""
         resource_type = self.__class__.__name__
-        config_str = self.model_dump_json(exclude_none=True)
+        # don't self reference and exclude any deployment state (eg id)
+        config_str = self.model_dump_json(include=self.__class__._hashed_fields)
         hash_obj = hashlib.md5(f"{resource_type}:{config_str}".encode())
         return f"{resource_type}_{hash_obj.hexdigest()}"
 
+    @property
+    def resource_id(self) -> str:
+        """Logical Tetra resource id defined by resource type and name.
+        Distinct from a server-side Runpod id.
+        """
+        resource_type = self.__class__.__name__
+        # TODO: eventually we could namespace this to user ids or team ids
+        if not self.name:
+            self.name = "unnamed"
+        return f"{resource_type}_{self.name}"
+
 
 class DeployableResource(BaseResource, ABC):
     """Base class for deployable resources."""
 
@@ -45,3 +64,8 @@ def is_deployed(self) -> bool:
     async def deploy(self) -> "DeployableResource":
         """Deploy the resource."""
         raise NotImplementedError("Subclasses should implement this method.")
+
+    @abstractmethod
+    async def update(self) -> "DeployableResource":
+        """Update the resource."""
+        raise NotImplementedError("Subclasses should implement this method.")
diff --git a/src/tetra_rp/core/resources/network_volume.py b/src/tetra_rp/core/resources/network_volume.py
index 240fa584..84b0dc13 100644
--- a/src/tetra_rp/core/resources/network_volume.py
+++ b/src/tetra_rp/core/resources/network_volume.py
@@ -124,6 +124,10 @@ async def _create_new_volume(self, client) -> "NetworkVolume":
 
         raise ValueError("Deployment failed, no volume was created.")
 
+    async def update(self) -> "DeployableResource":
+        # TODO: impl
+        return self
+
     async def deploy(self) -> "DeployableResource":
         """
         Deploys the network volume resource using the provided configuration.
diff --git a/src/tetra_rp/core/resources/resource_manager.py b/src/tetra_rp/core/resources/resource_manager.py
index 93204456..5862d841 100644
--- a/src/tetra_rp/core/resources/resource_manager.py
+++ b/src/tetra_rp/core/resources/resource_manager.py
@@ -100,6 +100,7 @@ async def get_or_deploy_resource(
         async with resource_lock:
             # Double-check pattern: check again inside the lock
             if existing := self._resources.get(uid):
+                # if the old resource isn't actually deployed, then we can just deploy the new one
                 if not existing.is_deployed():
                     log.warning(f"{existing} is no longer valid, redeploying.")
                     self.remove_resource(uid)
@@ -109,6 +110,25 @@ async def get_or_deploy_resource(
                     self.add_resource(uid, deployed_resource)
                     return deployed_resource
 
+                # if the old resource is actually deployed, then we need to update it
+                if existing.resource_hash != config.resource_hash:
+                    log.info(f"change in resource configuration detected, updating resource.")
+                    for field in existing.__class__._hashed_fields:
+                        existing_value, new_value = getattr(existing, field), getattr(config, field)
+                        if existing_value != new_value:
+                            log.debug(f"field: {field}, existing value: {getattr(existing, field)}, new value: {getattr(config, field)}")
+                            config.fields_to_update.add(field)
+
+                    # there are some fields that should be stored in pickled state and should be loaded back to the new obj
+                    # these are used to make updates to platform endpoints/resources
+                    # TODO: clean this up
+                    await config.sync_config_with_deployed_resource(existing)
+                    deployed_resource = await config.update()
+                    self.remove_resource(uid)
+                    self.add_resource(uid, deployed_resource)
+                    return deployed_resource
+
+                # otherwise, nothing has changed and we just return what we already have
                 log.debug(f"{existing} exists, reusing.")
                 log.info(f"URL: {existing.url}")
                 return existing
diff --git a/src/tetra_rp/core/resources/serverless.py b/src/tetra_rp/core/resources/serverless.py
index 8409f46f..0e92b72a 100644
--- a/src/tetra_rp/core/resources/serverless.py
+++ b/src/tetra_rp/core/resources/serverless.py
@@ -61,6 +61,7 @@ class ServerlessResource(DeployableResource):
     Base class for GPU serverless resource
     """
 
+    # Fields marked as _input_only are excluded from gql requests to make the client impl simpler
     _input_only = {
         "id",
         "cudaVersions",
@@ -70,6 +71,29 @@ class ServerlessResource(DeployableResource):
         "flashboot",
         "imageName",
         "networkVolume",
+        "resource_hash",
+        "fields_to_update",
+    }
+
+    # hashed fields are fields that define configuration of an object. they are used for computing
+    # if a resource has changed and should only be mutable fields from the perspective of the platform.
+    # does not account for platform (Runpod) state fields (eg endpoint id) right now.
+    _hashed_fields = {
+        "datacenter",
+        "env",
+        "gpuIds",
+        "networkVolume",
+        "executionTimeoutMs",
+        "gpuCount",
+        "locations",
+        "name",
+        "networkVolumeId",
+        "scalerType",
+        "scalerValue",
+        "workersMax",
+        "workersMin",
+        "workersPFBTarget",
+        "allowedCudaVersions",
     }
 
     # === Input-only Fields ===
@@ -82,7 +106,7 @@ class ServerlessResource(DeployableResource):
     datacenter: DataCenter = Field(default=DataCenter.EU_RO_1)
 
     # === Input Fields ===
-    executionTimeoutMs: Optional[int] = None
+    executionTimeoutMs: Optional[int] = 0
     gpuCount: Optional[int] = 1
     idleTimeout: Optional[int] = 5
     locations: Optional[str] = None
@@ -93,12 +117,12 @@ class ServerlessResource(DeployableResource):
     templateId: Optional[str] = None
     workersMax: Optional[int] = 3
     workersMin: Optional[int] = 0
-    workersPFBTarget: Optional[int] = None
+    workersPFBTarget: Optional[int] = 0
 
     # === Runtime Fields ===
     activeBuildid: Optional[str] = None
     aiKey: Optional[str] = None
-    allowedCudaVersions: Optional[str] = None
+    allowedCudaVersions: str = ""
     computeType: Optional[str] = None
     createdAt: Optional[str] = None  # TODO: use datetime
     gpuIds: Optional[str] = ""
@@ -143,7 +167,7 @@ def validate_gpus(cls, value: List[GpuGroup]) -> List[GpuGroup]:
     @model_validator(mode="after")
     def sync_input_fields(self):
         """Sync between temporary inputs and exported fields"""
-        if self.flashboot:
+        if self.flashboot and not self.name.endswith("-fb"):
             self.name += "-fb"
 
         # Sync datacenter to locations field for API
@@ -167,7 +191,10 @@ def sync_input_fields(self):
 
     def _sync_input_fields_gpu(self):
         # GPU-specific fields
-        if self.gpus:
+        # the response from the api for gpus is none
+        # apply this path only if gpuIds is None, otherwise we overwrite gpuIds
+        # with ANY gpu because the default for gpus is any
+        if self.gpus and not self.gpuIds:
             # Convert gpus list to gpuIds string
             self.gpuIds = ",".join(gpu.value for gpu in self.gpus)
         elif self.gpuIds:
@@ -199,6 +226,43 @@ async def _ensure_network_volume_deployed(self) -> None:
             deployedNetworkVolume = await self.networkVolume.deploy()
             self.networkVolumeId = deployedNetworkVolume.id
 
+    async def _sync_graphql_object_with_inputs(self, returned_endpoint: "ServerlessResource"):
+        for _input_field in self._input_only:
+            if _input_field not in ["resource_hash"] and getattr(self, _input_field) is not None:
+                # sync input only fields stripped from gql request back to endpoint
+                setattr(returned_endpoint, _input_field, getattr(self, _input_field))
+
+        # assigning template info back to the object is needed for updating it in the future
+        returned_endpoint.template = self.template
+        if returned_endpoint.template:
+            returned_endpoint.template.id = returned_endpoint.templateId
+
+        return returned_endpoint
+
+    async def sync_config_with_deployed_resource(self, existing: "ServerlessResource") -> None:
+        self.id = existing.id
+        if not existing.template:
+            raise ValueError("Existing resource does not have a template, this is an invalid state. Update resources and try again")
+        self.template.id = existing.template.id
+
+    async def _update_template(self) -> "DeployableResource":
+        if not self.template:
+            raise ValueError("Tried to update a template that doesn't exist. Redeploy endpoint or attach a template to it")
+
+        try:
+            async with RunpodGraphQLClient() as client:
+                payload = self.template.model_dump(exclude={"resource_hash", "fields_to_update"}, exclude_none=True)
+                result = await client.update_template(payload)
+                if template := self.template.__class__(**result):
+                    return template
+
+                raise ValueError("Deployment failed, no endpoint was returned.")
+
+        except Exception as e:
+            log.error(f"{self} failed to update: {e}")
+            raise
+
+
     def is_deployed(self) -> bool:
         """
         Checks if the serverless resource is deployed and available.
@@ -228,11 +292,17 @@ async def deploy(self) -> "DeployableResource":
             await self._ensure_network_volume_deployed()
 
             async with RunpodGraphQLClient() as client:
-                payload = self.model_dump(exclude=self._input_only, exclude_none=True)
+                # some "input only" fields are specific to tetra and not used in gql
+                exclude = {
+                    f: ... for f in self._input_only} | {"template": {"resource_hash", "fields_to_update", "volumeInGb"}
+                }  # TODO: maybe include this as a class attr
+                payload = self.model_dump(exclude=exclude, exclude_none=True)
                 result = await client.create_endpoint(payload)
 
+                # we need to merge the returned fields from gql with what the inputs are here
                 if endpoint := self.__class__(**result):
-                    return endpoint
+                    endpoint = await self._sync_graphql_object_with_inputs(endpoint)
+                    return endpoint
 
             raise ValueError("Deployment failed, no endpoint was returned.")
 
@@ -240,6 +310,40 @@ async def deploy(self) -> "DeployableResource":
             log.error(f"{self} failed to deploy: {e}")
             raise
 
+    async def update(self) -> "DeployableResource":
+        # check if we need to update the template
+        # only update if the template exists already and there are fields to update for it
+        if self.template and self.fields_to_update & set(self.template.model_fields):
+            # we need to add the template id back here from hydrated state
+            log.debug(f"loaded template to update: {self.template.model_dump()}")
+            template = await self._update_template()
+            self.template = template
+
+            # if the only fields that need updated are template-only, just return now
+            if self.fields_to_update ^ set(template.model_fields):
+                log.debug("template-only update to endpoint complete")
+                return self
+
+        try:
+            async with RunpodGraphQLClient() as client:
+                exclude = {f: ... for f in self._input_only} | {"template": {"resource_hash", "fields_to_update", "volumeInGb", "id"}}  # TODO: maybe include this as a class attr
+                # we need to include the id here so we update the existing endpoint
+                del exclude["id"]
+                payload = self.model_dump(exclude=exclude, exclude_none=True)
+                result = await client.update_endpoint(payload)
+
+                if endpoint := self.__class__(**result):
+                    # TODO: should we check that the returned id = the input?
+                    # we could "soft fail" and notify the user if we fall back to making a new endpoint
+                    endpoint = await self._sync_graphql_object_with_inputs(endpoint)
+                    return endpoint
+
+                raise ValueError("Update failed, no endpoint was returned.")
+
+        except Exception as e:
+            log.error(f"{self} failed to update: {e}")
+            raise
+
     async def run_sync(self, payload: Dict[str, Any]) -> "JobOutput":
         """
         Executes a serverless endpoint request with the payload.
diff --git a/src/tetra_rp/core/resources/template.py b/src/tetra_rp/core/resources/template.py
index a4c0a254..bbfb38c5 100644
--- a/src/tetra_rp/core/resources/template.py
+++ b/src/tetra_rp/core/resources/template.py
@@ -30,6 +30,7 @@ class PodTemplate(BaseResource):
     name: Optional[str] = ""
     ports: Optional[str] = ""
     startScript: Optional[str] = ""
+    volumeInGb: Optional[int] = 20
 
     @model_validator(mode="after")
     def sync_input_fields(self):
diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py
index eb1a780e..2a6aaf2c 100644
--- a/tests/unit/resources/test_serverless.py
+++ b/tests/unit/resources/test_serverless.py
@@ -425,9 +425,10 @@ def deployment_response(self):
         return {
             "id": "endpoint-123",
             "name": "test-serverless-fb",
-            "gpuIds": "RTX4090",
+            "gpuIds": "ADA_24",
             "allowedCudaVersions": "12.1",
             "networkVolumeId": "vol-456",
+            "templateId": "abc",
         }
 
     def test_is_deployed_false_when_no_id(self):
@@ -484,7 +485,7 @@ async def test_deploy_success_with_network_volume(
         assert result.id == "endpoint-123"
         # The returned object gets the name from the API response, which gets processed again
         # result is a DeployableResource, so we need to cast it
-        assert hasattr(result, "name") and result.name == "test-serverless-fb-fb"
+        assert hasattr(result, "name") and result.name == "test-serverless-fb"
         # Verify locations was set from datacenter
         assert hasattr(result, "locations") and result.locations == "EU-RO-1"