Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions .github/workflows/catalogs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ jobs:
path: azure.csv
retention-days: 1

catalog-datacrunch:
name: Collect DataCrunch catalog
catalog-verda:
name: Collect Verda catalog
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -81,17 +81,22 @@ jobs:
- name: Install dependencies
run: |
pip install pip -U
pip install -e '.[datacrunch]'
pip install -e '.[verda]'
- name: Collect catalog
working-directory: src
env:
DATACRUNCH_CLIENT_ID: ${{ secrets.DATACRUNCH_CLIENT_ID }}
DATACRUNCH_CLIENT_SECRET: ${{ secrets.DATACRUNCH_CLIENT_SECRET }}
run: python -m gpuhunt datacrunch --output ../datacrunch.csv
VERDA_CLIENT_ID: ${{ secrets.VERDA_CLIENT_ID }}
VERDA_CLIENT_SECRET: ${{ secrets.VERDA_CLIENT_SECRET }}
run: |
python -m gpuhunt verda --output ../verda.csv
# Copy for backward compatibility
cp ../verda.csv ../datacrunch.csv
- uses: actions/upload-artifact@v4
with:
name: catalogs-datacrunch
path: datacrunch.csv
name: catalogs-verda
path: |
verda.csv
datacrunch.csv
retention-days: 1

catalog-gcp:
Expand Down Expand Up @@ -252,7 +257,7 @@ jobs:
needs:
- catalog-aws
- catalog-azure
- catalog-datacrunch
- catalog-verda
- catalog-gcp
- catalog-lambdalabs
- catalog-nebius
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,13 @@ jobs:
run: |
IGNORE=
if [[ "${{ matrix.python-version }}" == "3.9" ]]; then
IGNORE="--ignore src/gpuhunt/providers/nebius.py"
IGNORE="--ignore src/gpuhunt/providers/nebius.py --ignore src/gpuhunt/providers/verda.py"
fi
pytest --doctest-modules src/gpuhunt $IGNORE
- name: Run pytest
run: |
pytest src/tests
IGNORE=
if [[ "${{ matrix.python-version }}" == "3.9" ]]; then
IGNORE="--ignore src/tests/providers/test_verda.py"
fi
pytest src/tests $IGNORE
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ print(*items, sep="\n")
* Azure
* CloudRift
* Cudo Compute
* DataCrunch
* Verda
* GCP
* LambdaLabs
* Nebius
Expand Down
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ oci = [
"oci",
"pydantic>=1.10.10,<2.0.0",
]
datacrunch = [
"datacrunch"
verda = [
'verda',
]
all = ["gpuhunt[aws,azure,datacrunch,gcp,maybe_nebius,oci]"]
maybe_verda = [
'verda; python_version>="3.10"',
]
all = ["gpuhunt[aws,azure,maybe_verda,gcp,maybe_nebius,oci]"]
dev = [
"pre-commit",
"pytest~=7.0",
Expand Down
10 changes: 4 additions & 6 deletions src/gpuhunt/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def main():
"azure",
"cloudrift",
"cudo",
"datacrunch",
"verda",
"digitalocean",
"gcp",
"hotaisle",
Expand Down Expand Up @@ -49,12 +49,10 @@ def main():
from gpuhunt.providers.cloudrift import CloudRiftProvider

provider = CloudRiftProvider()
elif args.provider == "datacrunch":
from gpuhunt.providers.datacrunch import DataCrunchProvider
elif args.provider == "verda":
from gpuhunt.providers.verda import VerdaProvider

provider = DataCrunchProvider(
os.getenv("DATACRUNCH_CLIENT_ID"), os.getenv("DATACRUNCH_CLIENT_SECRET")
)
provider = VerdaProvider(os.getenv("VERDA_CLIENT_ID"), os.getenv("VERDA_CLIENT_SECRET"))
elif args.provider == "digitalocean":
from gpuhunt.providers.digitalocean import DigitalOceanProvider

Expand Down
2 changes: 1 addition & 1 deletion src/gpuhunt/_internal/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
OFFLINE_PROVIDERS = [
"aws",
"azure",
"datacrunch",
"verda",
"gcp",
"lambdalabs",
"nebius",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from collections.abc import Iterable
from typing import Optional

from datacrunch import DataCrunchClient
from datacrunch.instance_types.instance_types import InstanceType
from verda import VerdaClient
from verda.instance_types import InstanceType

from gpuhunt import QueryFilter, RawCatalogItem
from gpuhunt.providers import AbstractProvider
Expand All @@ -19,11 +19,11 @@
]


class DataCrunchProvider(AbstractProvider):
NAME = "datacrunch"
class VerdaProvider(AbstractProvider):
NAME = "verda"

def __init__(self, client_id: str, client_secret: str) -> None:
self.datacrunch_client = DataCrunchClient(client_id, client_secret)
self.verda_client = VerdaClient(client_id, client_secret)

def get(
self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True
Expand All @@ -38,10 +38,10 @@ def get(
return sorted(instances, key=lambda x: x.price)

def _get_instance_types(self) -> list[InstanceType]:
return self.datacrunch_client.instance_types.get()
return self.verda_client.instance_types.get()

def _get_locations(self) -> list[dict]:
return self.datacrunch_client.locations.get()
return self.verda_client.locations.get()

@classmethod
def filter(cls, offers: list[RawCatalogItem]) -> list[RawCatalogItem]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import pytest

from gpuhunt.providers.datacrunch import ALL_AMD_GPUS, GPU_MAP
from gpuhunt.providers.verda import ALL_AMD_GPUS, GPU_MAP


@pytest.fixture
def data_rows(catalog_dir: Path) -> list[dict]:
file = catalog_dir / "datacrunch.csv"
file = catalog_dir / "verda.csv"
reader = csv.DictReader(file.open())
return list(reader)

Expand Down
16 changes: 8 additions & 8 deletions src/tests/_internal/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def item() -> CatalogItem:

@pytest.fixture
def cpu_items() -> list[CatalogItem]:
datacrunch = CatalogItem(
verda = CatalogItem(
instance_name="CPU.120V.480G",
location="ICE-01",
price=3.0,
Expand All @@ -38,7 +38,7 @@ def cpu_items() -> list[CatalogItem]:
gpu_name=None,
gpu_memory=None,
spot=False,
provider="datacrunch",
provider="verda",
disk_size=None,
)
aws = CatalogItem(
Expand All @@ -55,7 +55,7 @@ def cpu_items() -> list[CatalogItem]:
provider="aws",
disk_size=None,
)
return [datacrunch, aws]
return [verda, aws]


class TestMatches:
Expand Down Expand Up @@ -154,19 +154,19 @@ def test_ti_gpu(self):
assert matches(item, QueryFilter(gpu_name=["RTX3060TI"]))

def test_provider(self, cpu_items):
assert matches(cpu_items[0], QueryFilter(provider=["datacrunch"]))
assert matches(cpu_items[0], QueryFilter(provider=["DataCrunch"]))
assert matches(cpu_items[0], QueryFilter(provider=["verda"]))
assert matches(cpu_items[0], QueryFilter(provider=["verda"]))
assert not matches(cpu_items[0], QueryFilter(provider=["aws"]))

assert matches(cpu_items[1], QueryFilter(provider=["aws"]))
assert matches(cpu_items[1], QueryFilter(provider=["AWS"]))
assert not matches(cpu_items[1], QueryFilter(provider=["datacrunch"]))
assert not matches(cpu_items[1], QueryFilter(provider=["verda"]))

def test_provider_with_filter_setattr(self, cpu_items):
q = QueryFilter()
q.provider = ["datacrunch"]
q.provider = ["verda"]
assert matches(cpu_items[0], q)
q.provider = ["DataCrunch"]
q.provider = ["verda"]
assert matches(cpu_items[0], q)
q.provider = ["aws"]
assert not matches(cpu_items[0], q)
Expand Down
4 changes: 2 additions & 2 deletions src/tests/providers/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def providers():
"""List of all provider classes"""
members = []
for module_info in pkgutil.walk_packages(gpuhunt.providers.__path__):
if sys.version_info < (3, 10) and module_info.name == "nebius":
if sys.version_info < (3, 10) and module_info.name in ["nebius", "verda"]:
continue
module = importlib.import_module(
f".{module_info.name}",
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_all_providers_have_a_names(providers):
def test_catalog_providers(providers):
CATALOG_PROVIDERS = OFFLINE_PROVIDERS + ONLINE_PROVIDERS
if sys.version_info < (3, 10):
CATALOG_PROVIDERS = [p for p in CATALOG_PROVIDERS if p != "nebius"]
CATALOG_PROVIDERS = [p for p in CATALOG_PROVIDERS if p not in ["nebius", "verda"]]
names = [p.NAME for p in providers]
assert set(CATALOG_PROVIDERS) == set(names)
assert len(CATALOG_PROVIDERS) == len(names)
Loading