diff --git a/docs/docs/reference/api/python/index.md b/docs/docs/reference/api/python/index.md index 5d4f7c1f49..96e39467e3 100644 --- a/docs/docs/reference/api/python/index.md +++ b/docs/docs/reference/api/python/index.md @@ -44,7 +44,7 @@ finally: !!! info "NOTE:" 1. The `configuration` argument in the `apply_configuration` method can be either `dstack.api.Task`, `dstack.api.Service`, or `dstack.api.DevEnvironment`. 2. When you create `dstack.api.Task`, `dstack.api.Service`, or `dstack.api.DevEnvironment`, you can specify the `image` argument. If `image` isn't specified, the default image will be used. For a private Docker registry, ensure you also pass the `registry_auth` argument. - 3. The `repo` argument in the `apply_configuration` method allows the mounting of a local folder, a remote repo, or a + 3. The `repo` argument in the `apply_configuration` method allows the mounting of a remote repo or a programmatically created repo. In this case, the `commands` argument can refer to the files within this repo. 4. The `attach` method waits for the run to start and, for `dstack.api.Task` sets up an SSH tunnel and forwards configured `ports` to `localhost`. diff --git a/src/dstack/_internal/core/services/ssh/key_manager.py b/src/dstack/_internal/core/services/ssh/key_manager.py index 98e941638a..322ad323cc 100644 --- a/src/dstack/_internal/core/services/ssh/key_manager.py +++ b/src/dstack/_internal/core/services/ssh/key_manager.py @@ -2,9 +2,9 @@ from dataclasses import dataclass from datetime import datetime, timedelta from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING -from dstack._internal.core.models.users import UserWithCreds +from dstack._internal.core.errors import ClientError if TYPE_CHECKING: from dstack.api.server import APIClient @@ -24,9 +24,9 @@ def __init__(self, api_client: "APIClient", ssh_keys_dir: Path) -> None: self._key_path = ssh_keys_dir / api_client.get_token_hash() self._pub_key_path = self._key_path.with_suffix(".pub") - def get_user_key(self) -> Optional[UserSSHKey]: + def get_user_key(self) -> UserSSHKey: """ - Return the up-to-date user key, or None if the user has no key (if created before 0.19.33) + Return the up-to-date user key """ if ( not self._key_path.exists() @@ -34,16 +34,15 @@ def get_user_key(self) -> Optional[UserSSHKey]: or datetime.now() - datetime.fromtimestamp(self._key_path.stat().st_mtime) > KEY_REFRESH_RATE ): - if not self._download_user_key(): - return None + self._download_user_key() return UserSSHKey( public_key=self._pub_key_path.read_text(), private_key_path=self._key_path ) - def _download_user_key(self) -> bool: + def _download_user_key(self) -> None: user = self._api_client.users.get_my_user() - if not (isinstance(user, UserWithCreds) and user.ssh_public_key and user.ssh_private_key): - return False + if user.ssh_private_key is None or user.ssh_public_key is None: + raise ClientError("Server response does not contain user SSH key") def key_opener(path, flags): return os.open(path, flags, 0o600) @@ -52,5 +51,3 @@ def key_opener(path, flags): f.write(user.ssh_private_key) with open(self._pub_key_path, "w") as f: f.write(user.ssh_public_key) - - return True diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index a7d438b805..30c8dbbd44 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -118,7 +118,7 @@ async def get_plan( """ user, project = user_project if not user.ssh_public_key and not body.run_spec.ssh_key_pub: - await users.refresh_ssh_key(session=session, user=user, username=user.name) + await users.refresh_ssh_key(session=session, user=user) run_plan = await runs.get_plan( session=session, project=project, @@ -148,7 +148,7 @@ async def apply_plan( """ user, project = user_project if not user.ssh_public_key and not body.plan.run_spec.ssh_key_pub: - await users.refresh_ssh_key(session=session, user=user, username=user.name) + await users.refresh_ssh_key(session=session, user=user) return CustomORJSONResponse( await runs.apply_plan( session=session, diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index be04f83929..2568c6ac29 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -43,10 +43,7 @@ async def get_my_user( ): if user.ssh_private_key is None or user.ssh_public_key is None: # Generate keys for pre-0.19.33 users - updated_user = await users.refresh_ssh_key(session=session, user=user, username=user.name) - if updated_user is None: - raise ResourceNotExistsError() - user = updated_user + await users.refresh_ssh_key(session=session, user=user) return CustomORJSONResponse(users.user_model_to_user_with_creds(user)) diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index a42fb64a1a..62fcc848ea 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -147,8 +147,10 @@ async def update_user( async def refresh_ssh_key( session: AsyncSession, user: UserModel, - username: str, + username: Optional[str] = None, ) -> Optional[UserModel]: + if username is None: + username = user.name logger.debug("Refreshing SSH key for user [code]%s[/code]", username) if user.global_role != GlobalRole.ADMIN and user.name != username: raise error_forbidden() diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 8f9459a766..7ba70f821d 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -277,6 +277,7 @@ def get_run_spec( configuration_path: str = "dstack.yaml", profile: Union[Profile, Callable[[], Profile], None] = lambda: Profile(name="default"), configuration: Optional[AnyRunConfiguration] = None, + ssh_key_pub: Optional[str] = "user_ssh_key", ) -> RunSpec: if callable(profile): profile = profile() @@ -288,7 +289,7 @@ def get_run_spec( configuration_path=configuration_path, configuration=configuration or DevEnvironmentConfiguration(ide="vscode"), profile=profile, - ssh_key_pub="user_ssh_key", + ssh_key_pub=ssh_key_pub, ) diff --git a/src/dstack/api/_public/common.py b/src/dstack/api/_public/common.py new file mode 100644 index 0000000000..148d39d038 --- /dev/null +++ b/src/dstack/api/_public/common.py @@ -0,0 +1,5 @@ +import enum + + +class Deprecated(enum.Enum): + PLACEHOLDER = "DEPRECATED" diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 8725162ea9..72d31189f2 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -9,7 +9,7 @@ from copy import copy from datetime import datetime from pathlib import Path -from typing import BinaryIO, Dict, Iterable, List, Optional +from typing import BinaryIO, Dict, Iterable, List, Optional, Union from urllib.parse import urlencode, urlparse from websocket import WebSocketApp @@ -46,10 +46,10 @@ from dstack._internal.core.services.ssh.ports import PortsLock from dstack._internal.server.schemas.logs import PollLogsRequest from dstack._internal.utils.common import get_or_error, make_proxy_url -from dstack._internal.utils.crypto import generate_rsa_key_pair from dstack._internal.utils.files import create_file_archive from dstack._internal.utils.logging import get_logger from dstack._internal.utils.path import PathLike +from dstack.api._public.common import Deprecated from dstack.api.server import APIClient logger = get_logger(__name__) @@ -278,13 +278,11 @@ def attach( if not ssh_identity_file: config_manager = ConfigManager() key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir) - if ( - user_key := key_manager.get_user_key() - ) and user_key.public_key == self._run.run_spec.ssh_key_pub: + user_key = key_manager.get_user_key() + if user_key.public_key == self._run.run_spec.ssh_key_pub: ssh_identity_file = user_key.private_key_path else: if config_manager.dstack_key_path.exists(): - # TODO: Remove since 0.19.40 logger.debug(f"Using legacy [code]{config_manager.dstack_key_path}[/code].") ssh_identity_file = config_manager.dstack_key_path else: @@ -451,7 +449,7 @@ def get_run_plan( repo: Optional[Repo] = None, profile: Optional[Profile] = None, configuration_path: Optional[str] = None, - repo_dir: Optional[str] = None, + repo_dir: Union[Deprecated, str, None] = Deprecated.PLACEHOLDER, ssh_identity_file: Optional[PathLike] = None, ) -> RunPlan: """ @@ -465,9 +463,10 @@ def get_run_plan( profile: The profile to use for the run. configuration_path: The path to the configuration file. Omit if the configuration is not loaded from a file. - repo_dir: The path of the cloned repo inside the run container. If not set, - defaults first to the `repos[0].path` property of the configuration (for remote - repos only). + ssh_identity_file: Path to the private SSH key file. The corresponding public key + (`.pub` file) is read and included in the run plan, allowing SSH access to the instances. + If the `.pub` file does not exist, it is generated automatically. + If ssh_identity_file is not specified, the user key is used. Returns: Run plan. @@ -479,8 +478,15 @@ def get_run_plan( with _prepare_code_file(repo) as (_, repo_code_hash): pass - if repo_dir is None and configuration.repos: + if repo_dir is not Deprecated.PLACEHOLDER: + logger.warning( + "The repo_dir argument is deprecated, ignored, and will be removed soon." + " Remove it and use the repos[].path configuration property instead." + ) + if configuration.repos: repo_dir = configuration.repos[0].path + else: + repo_dir = None self._validate_configuration_files(configuration, configuration_path) file_archives: list[FileArchiveMapping] = [] @@ -497,20 +503,7 @@ def get_run_plan( if ssh_identity_file: ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text() else: - config_manager = ConfigManager() - key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir) - if key_manager.get_user_key(): - ssh_key_pub = None # using the server-managed user key - else: - if not config_manager.dstack_key_path.exists(): - generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path) - logger.warning( - f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]." - " You will only be able to attach to the run from this client." - " Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys" - " automatically replicated to all clients.", - ) - ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text() + ssh_key_pub = None # using the server-managed user key run_spec = RunSpec( run_name=configuration.name, repo_id=repo.repo_id, @@ -587,6 +580,10 @@ def apply_configuration( profile: The profile to use for the run. configuration_path: The path to the configuration file. Omit if the configuration is not loaded from a file. reserve_ports: Reserve local ports before applying. Use if you'll attach to the run. + ssh_identity_file: Path to the private SSH key file. The corresponding public key + (`.pub` file) is read and included in the run plan, allowing SSH access to the instances. + If the `.pub` file does not exist, it is generated automatically. + If ssh_identity_file is not specified, the user key is used. Returns: Submitted run. diff --git a/src/dstack/api/server/_users.py b/src/dstack/api/server/_users.py index 08dd118b4d..6082636c4b 100644 --- a/src/dstack/api/server/_users.py +++ b/src/dstack/api/server/_users.py @@ -1,6 +1,6 @@ from typing import List -from pydantic import ValidationError, parse_obj_as +from pydantic import parse_obj_as from dstack._internal.core.models.users import GlobalRole, User, UserWithCreds from dstack._internal.server.schemas.users import ( @@ -17,24 +17,9 @@ def list(self) -> List[User]: resp = self._request("/api/users/list") return parse_obj_as(List[User.__response__], resp.json()) - def get_my_user(self) -> User: - """ - Returns `User` with pre-0.19.33 servers, or `UserWithCreds` with newer servers. - """ - + def get_my_user(self) -> UserWithCreds: resp = self._request("/api/users/get_my_user") - try: - return parse_obj_as(UserWithCreds.__response__, resp.json()) - except ValidationError as e: - # Compatibility with pre-0.19.33 server - if ( - len(e.errors()) == 1 - and e.errors()[0]["loc"] == ("__root__", "creds") - and e.errors()[0]["type"] == "value_error.missing" - ): - return parse_obj_as(User.__response__, resp.json()) - else: - raise + return parse_obj_as(UserWithCreds.__response__, resp.json()) def get_user(self, username: str) -> User: body = GetUserRequest(username=username) diff --git a/src/tests/_internal/core/services/ssh/test_key_manager.py b/src/tests/_internal/core/services/ssh/test_key_manager.py index d727ba6258..7792c70377 100644 --- a/src/tests/_internal/core/services/ssh/test_key_manager.py +++ b/src/tests/_internal/core/services/ssh/test_key_manager.py @@ -44,17 +44,6 @@ def set_mtime(path: Path, ts: float): os.utime(path, (ts, ts)) -def test_get_user_key_returns_none_when_no_user_creds(tmp_path: Path): - api_client = make_api_client( - user=User.__response__.parse_obj(SAMPLE_USER.dict()), token_hash=SAMPLE_USER_TOKEN_HASH - ) - manager = UserSSHKeyManager(api_client, tmp_path) - - assert manager.get_user_key() is None - assert not (tmp_path / SAMPLE_USER_TOKEN_HASH).exists() - assert not (tmp_path / f"{SAMPLE_USER_TOKEN_HASH}.pub").exists() - - def test_get_user_key_downloads_keys(tmp_path: Path): api_client = make_api_client(user=SAMPLE_USER, token_hash=SAMPLE_USER_TOKEN_HASH) manager = UserSSHKeyManager(api_client, tmp_path) diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 570029a940..7de08b04aa 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -1364,6 +1364,32 @@ async def test_returns_update_or_create_action_on_conf_change( assert response_json["action"] == action assert response_json["current_resource"] == json.loads(run.json()) + @pytest.mark.asyncio + @pytest.mark.usefixtures("test_db") + async def test_generates_user_ssh_key(self, session: AsyncSession, client: AsyncClient): + user = await create_user( + session=session, global_role=GlobalRole.USER, ssh_public_key=None, ssh_private_key=None + ) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name, ssh_key_pub=None) + + response = await client.post( + f"/api/project/{project.name}/runs/get_plan", + headers=get_auth_headers(user.token), + json={"run_spec": run_spec.dict()}, + ) + + assert response.status_code == 200, response.json() + run_spec_ssh_public_key = response.json()["effective_run_spec"]["ssh_key_pub"] + assert run_spec_ssh_public_key is not None + await session.refresh(user) + assert user.ssh_public_key == run_spec_ssh_public_key + assert user.ssh_private_key is not None + class TestApplyPlan: @pytest.mark.asyncio @@ -1517,6 +1543,38 @@ async def test_creates_pending_run_if_run_is_scheduled( assert run.status == RunStatus.PENDING assert run.next_triggered_at == datetime(2023, 1, 2, 3, 10, tzinfo=timezone.utc) + @pytest.mark.asyncio + @pytest.mark.usefixtures("test_db") + async def test_generates_user_ssh_key(self, session: AsyncSession, client: AsyncClient): + user = await create_user( + session=session, global_role=GlobalRole.USER, ssh_public_key=None, ssh_private_key=None + ) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name, ssh_key_pub=None) + + response = await client.post( + f"/api/project/{project.name}/runs/apply", + headers=get_auth_headers(user.token), + json={ + "plan": { + "run_spec": run_spec.dict(), + "current_resource": None, + }, + "force": False, + }, + ) + + assert response.status_code == 200, response.json() + run_spec_ssh_public_key = response.json()["run_spec"]["ssh_key_pub"] + assert run_spec_ssh_public_key is not None + await session.refresh(user) + assert user.ssh_public_key == run_spec_ssh_public_key + assert user.ssh_private_key is not None + class TestSubmitRun: @pytest.mark.asyncio