diff --git a/src/gpuhunt/providers/hotaisle.py b/src/gpuhunt/providers/hotaisle.py index 22c4893..34a768a 100644 --- a/src/gpuhunt/providers/hotaisle.py +++ b/src/gpuhunt/providers/hotaisle.py @@ -1,12 +1,12 @@ import logging import os -from typing import Optional +from typing import Optional, TypedDict, cast import requests from requests import Response from gpuhunt._internal.constraints import find_accelerators -from gpuhunt._internal.models import AcceleratorVendor, QueryFilter, RawCatalogItem +from gpuhunt._internal.models import AcceleratorVendor, JSONObject, QueryFilter, RawCatalogItem from gpuhunt.providers import AbstractProvider logger = logging.getLogger(__name__) @@ -53,6 +53,10 @@ def _make_request(self, method: str, url: str) -> Response: return response +class HotAisleCatalogItemProviderData(TypedDict): + vm_specs: JSONObject + + def get_gpu_memory(gpu_name: str) -> Optional[float]: if accelerators := find_accelerators(names=[gpu_name], vendors=[AcceleratorVendor.AMD]): return float(accelerators[0].memory) @@ -96,6 +100,14 @@ def convert_response_to_raw_catalog_items(response: Response) -> list[RawCatalog gpu_vendor=gpu_vendor, spot=False, disk_size=disk_gb, + provider_data=cast( + JSONObject, + HotAisleCatalogItemProviderData( + # The specs object may duplicate some RawCatalogItem fields, but we store it in + # full because we need to pass it back to the API when creating VMs. + vm_specs=specs, + ), + ), ) offers.append(offer)