From be6646e5985e1bbb22d2e833b2d7269f50e1c8ef Mon Sep 17 00:00:00 2001 From: jhcipar Date: Fri, 19 Sep 2025 17:12:10 -0400 Subject: [PATCH 1/2] feat: update endpoints if config has changed changed default resource "identifier" to be a resource id that includes an input config name not generated from config args so a logical resource isn't defined by its config (so we can change config for the same resource) added a "resource hash" to the base sls resource class so we can detect changes to input config args and update resources in place instead of redeploying adds update methods to serverless resources changed deploy method to add platform-related state (eg durable resource ids) back to pickled state and config objects at runtime so we can fetch and interact with runpod sls endpoints created via tetra add update template methods to sls resource so we can update template-only variables via gql (eg env vars) changed the defaults for some sls resource configs to reflect existing defaults in runpod add update path to resource manager class when existing and new config have different resource hashes changed the behavior of sync gpu and gpuIds fields because there was a bug where gpus would always get created and pickled as the ANY gpu group --- src/tetra_rp/core/api/runpod.py | 86 +++++++++++++ src/tetra_rp/core/resources/base.py | 34 ++++- src/tetra_rp/core/resources/network_volume.py | 4 + .../core/resources/resource_manager.py | 20 +++ src/tetra_rp/core/resources/serverless.py | 118 ++++++++++++++++-- src/tetra_rp/core/resources/template.py | 1 + tests/unit/resources/test_serverless.py | 5 +- 7 files changed, 254 insertions(+), 14 deletions(-) diff --git a/src/tetra_rp/core/api/runpod.py b/src/tetra_rp/core/api/runpod.py index af1754d0..2b1ea5e6 100644 --- a/src/tetra_rp/core/api/runpod.py +++ b/src/tetra_rp/core/api/runpod.py @@ -110,6 +110,7 @@ async def create_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]: instanceIds activeBuildid idePodId + 
templateId } } """ @@ -132,6 +133,91 @@ async def create_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]: return endpoint_data + async def update_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]: + mutation = """ + mutation saveEndpoint($input: EndpointInput!) { + saveEndpoint(input: $input) { + aiKey + gpuIds + id + idleTimeout + locations + name + networkVolumeId + scalerType + scalerValue + templateId + type + userId + version + workersMax + workersMin + workersStandby + workersPFBTarget + gpuCount + allowedCudaVersions + executionTimeoutMs + instanceIds + activeBuildid + idePodId + } + } + """ + + variables = {"input": input_data} + + log.debug( + f"Updating endpoint with GraphQL: {input_data.get('name', 'unnamed')}" + ) + + result = await self._execute_graphql(mutation, variables) + + if "saveEndpoint" not in result: + raise Exception("Unexpected GraphQL response structure") + + endpoint_data = result["saveEndpoint"] + log.info( + f"Updated endpoint: {endpoint_data.get('id', 'unknown')} - {endpoint_data.get('name', 'unnamed')}" + ) + + return endpoint_data + + async def update_template(self, input_data: Dict[str, Any]) -> Dict[str, Any]: + mutation = """ + mutation saveTemplate($input: SaveTemplateInput) { + saveTemplate(input: $input) { + id + containerDiskInGb + dockerArgs + env { + key + value + } + imageName + name + readme + } + } + """ + + variables = {"input": input_data} + + log.debug( + f"Updating template with GraphQL: {input_data.get('name', 'unnamed')}" + ) + + result = await self._execute_graphql(mutation, variables) + + if "saveTemplate" not in result: + raise Exception("Unexpected GraphQL response structure") + + template_data = result["saveTemplate"] + log.info( + f"Updated template: {template_data.get('id', 'unknown')} - {template_data.get('name', 'unnamed')}" + ) + + return template_data + async def get_cpu_types(self) -> Dict[str, Any]: """Get available CPU types.""" query = """ diff --git 
a/src/tetra_rp/core/resources/base.py b/src/tetra_rp/core/resources/base.py index 2e2e28c6..c5fe36c0 100644 --- a/src/tetra_rp/core/resources/base.py +++ b/src/tetra_rp/core/resources/base.py @@ -1,8 +1,7 @@ import hashlib from abc import ABC, abstractmethod -from typing import Optional -from pydantic import BaseModel, ConfigDict - +from typing import Optional, ClassVar +from pydantic import BaseModel, ConfigDict, computed_field class BaseResource(BaseModel): """Base class for all resources.""" @@ -14,15 +13,35 @@ class BaseResource(BaseModel): ) id: Optional[str] = None + _hashed_fields: ClassVar[set] = set() + + # diffed fields is a temporary holder for fields that are "out of sync" - + # where a local instance representation of an endpoint is not up to date with the remote resource. + # it's needed for determining how updates are applied (eg, if we need to update a pod template) + fields_to_update: set[str] = set() + + @computed_field @property - def resource_id(self) -> str: + def resource_hash(self) -> str: """Unique resource ID based on configuration.""" resource_type = self.__class__.__name__ - config_str = self.model_dump_json(exclude_none=True) + # don't self reference and exclude any deployment state (eg id) + config_str = self.model_dump_json(include=self.__class__._hashed_fields) hash_obj = hashlib.md5(f"{resource_type}:{config_str}".encode()) return f"{resource_type}_{hash_obj.hexdigest()}" + @property + def resource_id(self) -> str: + """Logical Tetra resource id defined by resource type and name. + Distinct from a server-side Runpod id. 
+ """ + resource_type = self.__class__.__name__ + # TODO: eventually we could namespace this to user ids or team ids + if not self.name: + self.name = "unnamed" + return f"{resource_type}_{self.name}" + class DeployableResource(BaseResource, ABC): """Base class for deployable resources.""" @@ -45,3 +64,8 @@ def is_deployed(self) -> bool: async def deploy(self) -> "DeployableResource": """Deploy the resource.""" raise NotImplementedError("Subclasses should implement this method.") + + @abstractmethod + async def update(self) -> "DeployableResource": + """Update the resource.""" + raise NotImplementedError("Subclasses should implement this method.") diff --git a/src/tetra_rp/core/resources/network_volume.py b/src/tetra_rp/core/resources/network_volume.py index 240fa584..84b0dc13 100644 --- a/src/tetra_rp/core/resources/network_volume.py +++ b/src/tetra_rp/core/resources/network_volume.py @@ -124,6 +124,10 @@ async def _create_new_volume(self, client) -> "NetworkVolume": raise ValueError("Deployment failed, no volume was created.") + async def update(self) -> "DeployableResource": + # TODO: impl + return self + async def deploy(self) -> "DeployableResource": """ Deploys the network volume resource using the provided configuration. 
diff --git a/src/tetra_rp/core/resources/resource_manager.py b/src/tetra_rp/core/resources/resource_manager.py index 93204456..5862d841 100644 --- a/src/tetra_rp/core/resources/resource_manager.py +++ b/src/tetra_rp/core/resources/resource_manager.py @@ -100,6 +100,7 @@ async def get_or_deploy_resource( async with resource_lock: # Double-check pattern: check again inside the lock if existing := self._resources.get(uid): + # if the old resource isn't actually deployed, then we can just deploy the new one if not existing.is_deployed(): log.warning(f"{existing} is no longer valid, redeploying.") self.remove_resource(uid) @@ -109,6 +110,25 @@ async def get_or_deploy_resource( self.add_resource(uid, deployed_resource) return deployed_resource + # if the old resource is actually deployed, then we need to update it + if existing.resource_hash != config.resource_hash: + log.info(f"change in resource configuration detected, updating resource.") + for field in existing.__class__._hashed_fields: + existing_value, new_value = getattr(existing, field), getattr(config, field) + if existing_value != new_value: + log.debug(f"field: {field}, existing value: {getattr(existing, field)}, new value: {getattr(config, field)}") + config.fields_to_update.add(field) + + # there are some fields that should be stored in pickled state and should be loaded back to the new obj + # these are used to make updates to platform endpoints/resources + # TODO: clean this up + await config.sync_config_with_deployed_resource(existing) + deployed_resource = await config.update() + self.remove_resource(uid) + self.add_resource(uid, deployed_resource) + return deployed_resource + + # otherwise, nothing has changed and we just return what we already have log.debug(f"{existing} exists, reusing.") log.info(f"URL: {existing.url}") return existing diff --git a/src/tetra_rp/core/resources/serverless.py b/src/tetra_rp/core/resources/serverless.py index 8409f46f..30c29778 100644 --- 
a/src/tetra_rp/core/resources/serverless.py +++ b/src/tetra_rp/core/resources/serverless.py @@ -61,6 +61,7 @@ class ServerlessResource(DeployableResource): Base class for GPU serverless resource """ + # Fields marked as _input_only are excluded from gql requests to make the client impl simpler _input_only = { "id", "cudaVersions", @@ -70,6 +71,29 @@ class ServerlessResource(DeployableResource): "flashboot", "imageName", "networkVolume", + "resource_hash", + "fields_to_update", + } + + # hashed fields are fields that define configuration of an object. they are used for computing + # if a resource has changed and should only be mutable fields from the perspective of the platform. + # does not account for platform (Runpod) state fields (eg endpoint id) right now. + _hashed_fields = { + "datacenter", + "env", + "gpuIds", + "networkVolume", + "executionTimeoutMs", + "gpuCount", + "locations", + "name", + "networkVolumeId", + "scalerType", + "scalerValue", + "workersMax", + "workersMin", + "workersPFBTarget", + "allowedCudaVersions", } # === Input-only Fields === @@ -82,7 +106,7 @@ class ServerlessResource(DeployableResource): datacenter: DataCenter = Field(default=DataCenter.EU_RO_1) # === Input Fields === - executionTimeoutMs: Optional[int] = None + executionTimeoutMs: Optional[int] = 0 gpuCount: Optional[int] = 1 idleTimeout: Optional[int] = 5 locations: Optional[str] = None @@ -93,12 +117,12 @@ class ServerlessResource(DeployableResource): templateId: Optional[str] = None workersMax: Optional[int] = 3 workersMin: Optional[int] = 0 - workersPFBTarget: Optional[int] = None + workersPFBTarget: Optional[int] = 0 # === Runtime Fields === activeBuildid: Optional[str] = None aiKey: Optional[str] = None - allowedCudaVersions: Optional[str] = None + allowedCudaVersions: str = "" computeType: Optional[str] = None createdAt: Optional[str] = None # TODO: use datetime gpuIds: Optional[str] = "" @@ -143,7 +167,7 @@ def validate_gpus(cls, value: List[GpuGroup]) -> List[GpuGroup]: 
@model_validator(mode="after") def sync_input_fields(self): """Sync between temporary inputs and exported fields""" - if self.flashboot: + if self.flashboot and not self.name.endswith("-fb"): self.name += "-fb" # Sync datacenter to locations field for API @@ -167,7 +191,10 @@ def sync_input_fields(self): def _sync_input_fields_gpu(self): # GPU-specific fields - if self.gpus: + # the response from the api for gpus is none + # apply this path only if gpuIds is None, otherwise we overwrite gpuIds + # with ANY gpu because the default for gpus is any + if self.gpus and not self.gpuIds: # Convert gpus list to gpuIds string self.gpuIds = ",".join(gpu.value for gpu in self.gpus) elif self.gpuIds: @@ -199,6 +226,43 @@ async def _ensure_network_volume_deployed(self) -> None: deployedNetworkVolume = await self.networkVolume.deploy() self.networkVolumeId = deployedNetworkVolume.id + async def _sync_graphql_object_with_inputs(self, returned_endpoint: "ServerlessResource"): + for _input_field in self._input_only: + if _input_field not in ["resource_hash"] and getattr(self, _input_field) is not None: + # sync input only fields stripped from gql request back to endpoint + setattr(returned_endpoint, _input_field, getattr(self, _input_field)) + + # assigning template info back to the object is needed for updating it in the future + returned_endpoint.template = self.template + if returned_endpoint.template: + returned_endpoint.template.id = returned_endpoint.templateId + + return returned_endpoint + + async def sync_config_with_deployed_resource(self, existing: "ServerlessResource") -> None: + self.id = existing.id + if not existing.template: + raise ValueError("Existing resource does not have a template, this is an invalid state. Update resources and try again") + self.template.id = existing.template.id + + async def _update_template(self) -> "DeployableResource": + if not self.template: + raise ValueError("Tried to update a template that doesn't exist. 
Redeploy endpoint or attach a template to it") + + try: + async with RunpodGraphQLClient() as client: + payload = self.template.model_dump(exclude={"resource_hash", "fields_to_update"}, exclude_none=True) + result = await client.update_template(payload) + if template := self.template.__class__(**result): + return template + + raise ValueError("Deployment failed, no endpoint was returned.") + + except Exception as e: + log.error(f"{self} failed to update: {e}") + raise + + def is_deployed(self) -> bool: """ Checks if the serverless resource is deployed and available. @@ -228,11 +292,17 @@ async def deploy(self) -> "DeployableResource": await self._ensure_network_volume_deployed() async with RunpodGraphQLClient() as client: - payload = self.model_dump(exclude=self._input_only, exclude_none=True) + # some "input only" fields are specific to tetra and not used in gql + exclude = { + f: ... for f in self._input_only} | {"template": {"resource_hash", "fields_to_update", "volumeInGb"} + } # TODO: maybe include this as a class attr + payload = self.model_dump(exclude=exclude, exclude_none=True) result = await client.create_endpoint(payload) + # we need to merge the returned fields from gql with what the inputs are here if endpoint := self.__class__(**result): - return endpoint + endpoint = await self._sync_graphql_object_with_inputs(endpoint) + return endpoint raise ValueError("Deployment failed, no endpoint was returned.") @@ -240,6 +310,40 @@ async def deploy(self) -> "DeployableResource": log.error(f"{self} failed to deploy: {e}") raise + async def update(self) -> "DeployableResource": + # check if we need to update the template + # only update if the template exists already and there are fields to update for it + if self.template and self.fields_to_update & set(self.template.model_fields): + # we need to add the template id back here from hydrated state + log.debug(f"loaded template to update: {self.template.model_dump()}") + template = await self._update_template() + 
self.template = template + + # if the only fields that need updated are template-only, just return now + if self.fields_to_update ^ set(template.model_fields): + log.debug("template-only update to endpoint complete") + return self + + try: + async with RunpodGraphQLClient() as client: + exclude = {f: ... for f in self._input_only} | {"template": {"resource_hash"}} # TODO: maybe include this as a class attr + # we need to include the id here so we update the existing endpoint + del exclude["id"] + payload = self.model_dump(exclude=exclude, exclude_none=True) + result = await client.update_endpoint(payload) + + if endpoint := self.__class__(**result): + # TODO: should we check that the returned id = the input? + # we could "soft fail" and notify the user if we fall back to making a new endpoint + endpoint = await self._sync_graphql_object_with_inputs(endpoint) + return endpoint + + raise ValueError("Update failed, no endpoint was returned.") + + except Exception as e: + log.error(f"{self} failed to update: {e}") + raise + async def run_sync(self, payload: Dict[str, Any]) -> "JobOutput": """ Executes a serverless endpoint request with the payload. 
diff --git a/src/tetra_rp/core/resources/template.py b/src/tetra_rp/core/resources/template.py index a4c0a254..bbfb38c5 100644 --- a/src/tetra_rp/core/resources/template.py +++ b/src/tetra_rp/core/resources/template.py @@ -30,6 +30,7 @@ class PodTemplate(BaseResource): name: Optional[str] = "" ports: Optional[str] = "" startScript: Optional[str] = "" + volumeInGb: Optional[int] = 20 @model_validator(mode="after") def sync_input_fields(self): diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index eb1a780e..2a6aaf2c 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -425,9 +425,10 @@ def deployment_response(self): return { "id": "endpoint-123", "name": "test-serverless-fb", - "gpuIds": "RTX4090", + "gpuIds": "ADA_24", "allowedCudaVersions": "12.1", "networkVolumeId": "vol-456", + "templateId": "abc", } def test_is_deployed_false_when_no_id(self): @@ -484,7 +485,7 @@ async def test_deploy_success_with_network_volume( assert result.id == "endpoint-123" # The returned object gets the name from the API response, which gets processed again # result is a DeployableResource, so we need to cast it - assert hasattr(result, "name") and result.name == "test-serverless-fb-fb" + assert hasattr(result, "name") and result.name == "test-serverless-fb" # Verify locations was set from datacenter assert hasattr(result, "locations") and result.locations == "EU-RO-1" From 1c585f906735079c98b6ffbd72e4327fdc60f1ad Mon Sep 17 00:00:00 2001 From: jhcipar Date: Fri, 19 Sep 2025 17:45:37 -0400 Subject: [PATCH 2/2] fix: exclude additional template fields for endpoint update path --- src/tetra_rp/core/resources/serverless.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tetra_rp/core/resources/serverless.py b/src/tetra_rp/core/resources/serverless.py index 30c29778..0e92b72a 100644 --- a/src/tetra_rp/core/resources/serverless.py +++ b/src/tetra_rp/core/resources/serverless.py 
@@ -326,7 +326,7 @@ async def update(self) -> "DeployableResource": try: async with RunpodGraphQLClient() as client: - exclude = {f: ... for f in self._input_only} | {"template": {"resource_hash"}} # TODO: maybe include this as a class attr + exclude = {f: ... for f in self._input_only} | {"template": {"resource_hash", "fields_to_update", "volumeInGb", "id"}} # TODO: maybe include this as a class attr # we need to include the id here so we update the existing endpoint del exclude["id"] payload = self.model_dump(exclude=exclude, exclude_none=True)