Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 2 additions & 198 deletions src/dstack/api/_public/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from copy import copy
from datetime import datetime
from pathlib import Path
from typing import BinaryIO, Dict, Iterable, List, Optional, Union
from typing import BinaryIO, Dict, Iterable, List, Optional
from urllib.parse import urlencode, urlparse

from websocket import WebSocketApp
Expand All @@ -25,16 +25,10 @@
)
from dstack._internal.core.models.files import FileArchiveMapping
from dstack._internal.core.models.profiles import (
CreationPolicy,
Profile,
ProfileRetryPolicy,
SpotPolicy,
TerminationPolicy,
UtilizationPolicy,
)
from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.models.repos.virtual import VirtualRepo
from dstack._internal.core.models.resources import ResourcesSpec
from dstack._internal.core.models.runs import (
Job,
JobSpec,
Expand All @@ -55,7 +49,7 @@
from dstack._internal.utils.crypto import generate_rsa_key_pair
from dstack._internal.utils.files import create_file_archive
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.path import PathLike, path_in_dir
from dstack._internal.utils.path import PathLike
from dstack.api.server import APIClient

logger = get_logger(__name__)
Expand Down Expand Up @@ -616,196 +610,6 @@ def apply_configuration(
)
return run

def submit(
    self,
    configuration: AnyRunConfiguration,
    configuration_path: Optional[str] = None,
    repo: Optional[Repo] = None,
    backends: Optional[List[BackendType]] = None,
    regions: Optional[List[str]] = None,
    instance_types: Optional[List[str]] = None,
    resources: Optional[ResourcesSpec] = None,
    spot_policy: Optional[SpotPolicy] = None,
    retry_policy: Optional[ProfileRetryPolicy] = None,
    max_duration: Optional[Union[int, str]] = None,
    max_price: Optional[float] = None,
    working_dir: Optional[str] = None,
    run_name: Optional[str] = None,
    reserve_ports: bool = True,
) -> Run:
    # """
    # Deprecated: build a run plan with `get_plan()` and execute it with `exec_plan()`.
    #
    # Args:
    #     configuration: The run configuration, forwarded to `get_plan()`.
    #     configuration_path: Forwarded to `get_plan()`.
    #     repo: The repo to use; a `VirtualRepo` is created when omitted.
    #         The repo is initialized via `self._client.repos.init()` before planning.
    #     backends: Forwarded to `get_plan()`.
    #     regions: Forwarded to `get_plan()`.
    #     instance_types: Forwarded to `get_plan()`.
    #     resources: Forwarded to `get_plan()`.
    #     spot_policy: Forwarded to `get_plan()`.
    #     retry_policy: Forwarded to `get_plan()`.
    #     max_duration: Forwarded to `get_plan()`.
    #     max_price: Forwarded to `get_plan()`.
    #     working_dir: Forwarded to `get_plan()`.
    #     run_name: Forwarded to `get_plan()`.
    #     reserve_ports: Forwarded to `exec_plan()`.
    #
    # Returns:
    #     The run returned by `exec_plan()`.
    # """
    logger.warning("The submit() method is deprecated in favor of apply_configuration().")

    # Fall back to a virtual repo when the caller did not supply one.
    target_repo = VirtualRepo() if repo is None else repo
    # TODO: Add Git credentials to RemoteRepo and if they are set, pass them here to RepoCollection.init
    self._client.repos.init(target_repo)

    # Everything except `reserve_ports` is handed to get_plan() unchanged.
    plan_options = dict(
        configuration=configuration,
        repo=target_repo,
        configuration_path=configuration_path,
        backends=backends,
        regions=regions,
        instance_types=instance_types,
        resources=resources,
        spot_policy=spot_policy,
        retry_policy=retry_policy,
        max_duration=max_duration,
        max_price=max_price,
        working_dir=working_dir,
        run_name=run_name,
    )
    plan = self.get_plan(**plan_options)
    return self.exec_plan(plan, target_repo, reserve_ports=reserve_ports)

# Deprecated in favor of get_run_plan()
def get_plan(
    self,
    configuration: AnyRunConfiguration,
    repo: Optional[Repo] = None,
    configuration_path: Optional[str] = None,
    # Unused profile args are deprecated and removed but
    # kept for signature backward compatibility.
    backends: Optional[List[BackendType]] = None,
    regions: Optional[List[str]] = None,
    instance_types: Optional[List[str]] = None,
    resources: Optional[ResourcesSpec] = None,
    spot_policy: Optional[SpotPolicy] = None,
    retry_policy: Optional[ProfileRetryPolicy] = None,
    utilization_policy: Optional[UtilizationPolicy] = None,
    max_duration: Optional[Union[int, str]] = None,
    max_price: Optional[float] = None,
    working_dir: Optional[str] = None,
    run_name: Optional[str] = None,
    pool_name: Optional[str] = None,
    instance_name: Optional[str] = None,
    creation_policy: Optional[CreationPolicy] = None,
    termination_policy: Optional[TerminationPolicy] = None,
    termination_policy_idle: Optional[Union[str, int]] = None,
    reservation: Optional[str] = None,
    idle_duration: Optional[Union[str, int]] = None,
    stop_duration: Optional[Union[str, int]] = None,
) -> RunPlan:
    # """
    # Deprecated: assemble a run spec from the arguments and request a run plan.
    #
    # `retry_policy`, `pool_name`, `instance_name`, `termination_policy` and
    # `termination_policy_idle` are accepted but not read by this method;
    # the profile is always built with `retry=None`.
    #
    # Returns:
    #     The run plan returned by `self._api_client.runs.get_plan()`.
    #
    # Raises:
    #     ConfigurationError: if `working_dir` resolves outside of `repo.repo_dir`.
    # """
    logger.warning("The get_plan() method is deprecated in favor of get_run_plan().")

    # No code hash is computed when the method falls back to a virtual repo.
    repo_code_hash = None
    if repo is None:
        repo = VirtualRepo()
    else:
        with _prepare_code_file(repo) as prepared:
            _, repo_code_hash = prepared

    if working_dir is None:
        working_dir = "."
    elif repo.repo_dir is not None:
        # Re-express the working dir relative to the repo dir, rejecting escapes.
        candidate_dir = Path(repo.repo_dir) / working_dir
        if not path_in_dir(candidate_dir, repo.repo_dir):
            raise ConfigurationError("Working directory is outside of the repo")
        working_dir = candidate_dir.relative_to(repo.repo_dir).as_posix()

    if configuration_path is None:
        configuration_path = "(python)"

    if resources is not None:
        # Work on a deep copy so the caller's configuration object is not mutated.
        configuration = configuration.copy(deep=True)
        configuration.resources = resources

    # TODO: [Andrey] the "(python)" placeholder name looks like a hack
    profile = Profile(
        name="(python)",
        backends=backends,
        regions=regions,
        instance_types=instance_types,
        reservation=reservation,
        spot_policy=spot_policy,
        retry=None,
        utilization_policy=utilization_policy,
        max_duration=max_duration,  # type: ignore[assignment]
        stop_duration=stop_duration,  # type: ignore[assignment]
        max_price=max_price,
        creation_policy=creation_policy,
        idle_duration=idle_duration,  # type: ignore[assignment]
    )

    config_manager = ConfigManager()
    key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
    # None means the server-managed user key is used.
    ssh_key_pub = None
    if not key_manager.get_user_key():
        # No user key available: fall back to the legacy local key pair,
        # generating it first when the private key file does not exist.
        legacy_key_path = config_manager.dstack_key_path
        if not legacy_key_path.exists():
            generate_rsa_key_pair(private_key_path=legacy_key_path)
        legacy_pub_path = legacy_key_path.with_suffix(".pub")
        logger.warning(
            f"Using legacy [code]{legacy_pub_path}[/code]."
            " You will only be able to attach to the run from this client."
            " Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys"
            " automatically replicated to all clients.",
        )
        ssh_key_pub = legacy_pub_path.read_text()

    run_spec = RunSpec(
        run_name=run_name,
        repo_id=repo.repo_id,
        repo_data=repo.run_repo_data,
        repo_code_hash=repo_code_hash,
        working_dir=working_dir,
        configuration_path=configuration_path,
        configuration=configuration,
        profile=profile,
        ssh_key_pub=ssh_key_pub,
    )
    logger.debug("Getting run plan")
    plan = self._api_client.runs.get_plan(self._project, run_spec)

    if plan.current_resource is not None or run_name is None:
        return plan

    # If run_plan.current_resource is missing, this can mean old (0.18.x) server.
    # TODO: Remove in 0.19
    try:
        plan.current_resource = self._api_client.runs.get(self._project, run_name)
    except ResourceNotExistsError:
        # No existing run with this name; leave current_resource unset.
        pass
    return plan

def exec_plan(
    self,
    run_plan: RunPlan,
    repo: Repo,
    reserve_ports: bool = True,
) -> Run:
    # """
    # Deprecated: delegate to `apply_plan()` with the same arguments.
    #
    # Args:
    #     run_plan: The run plan, forwarded to `apply_plan()`.
    #     repo: The repo, forwarded to `apply_plan()`.
    #     reserve_ports: Forwarded to `apply_plan()`.
    #
    # Returns:
    #     The run returned by `apply_plan()`.
    # """
    logger.warning("The exec_plan() method is deprecated in favor of apply_plan().")
    submitted_run = self.apply_plan(
        run_plan=run_plan,
        repo=repo,
        reserve_ports=reserve_ports,
    )
    return submitted_run

def list(self, all: bool = False, limit: Optional[int] = None) -> List[Run]:
"""
List runs.
Expand Down