diff --git a/docs/docs/reference/cli/dstack/login.md b/docs/docs/reference/cli/dstack/login.md new file mode 100644 index 0000000000..d608476e27 --- /dev/null +++ b/docs/docs/reference/cli/dstack/login.md @@ -0,0 +1,17 @@ +# dstack login + +This command authorizes the CLI using Single Sign-On and automatically configures your projects. +It provides an alternative to `dstack project add`. + +## Usage + +
+ +```shell +$ dstack login --help +#GENERATE# +``` + +
+ +[//]: # (TODO: Provide examples) diff --git a/frontend/src/App/Login/EntraID/LoginByEntraIDCallback/index.tsx b/frontend/src/App/Login/EntraID/LoginByEntraIDCallback/index.tsx index aa70d00797..036851c3cf 100644 --- a/frontend/src/App/Login/EntraID/LoginByEntraIDCallback/index.tsx +++ b/frontend/src/App/Login/EntraID/LoginByEntraIDCallback/index.tsx @@ -7,7 +7,7 @@ import { UnauthorizedLayout } from 'layouts/UnauthorizedLayout'; import { useAppDispatch } from 'hooks'; import { ROUTES } from 'routes'; -import { useEntraCallbackMutation } from 'services/auth'; +import { useEntraCallbackMutation, useGetNextRedirectMutation } from 'services/auth'; import { AuthErrorMessage } from 'App/AuthErrorMessage'; import { getBaseUrl } from 'App/helpers'; @@ -23,15 +23,27 @@ export const LoginByEntraIDCallback: React.FC = () => { const [isInvalidCode, setIsInvalidCode] = useState(false); const dispatch = useAppDispatch(); + const [getNextRedirect] = useGetNextRedirectMutation(); const [entraCallback] = useEntraCallbackMutation(); const checkCode = () => { if (code && state) { - entraCallback({ code, state, base_url: getBaseUrl() }) + getNextRedirect({ code, state }) .unwrap() - .then(({ creds: { token } }) => { - dispatch(setAuthData({ token })); - navigate('/'); + .then(({ redirect_url }) => { + if (redirect_url) { + window.location.href = redirect_url; + return; + } + entraCallback({ code, state, base_url: getBaseUrl() }) + .unwrap() + .then(({ creds: { token } }) => { + dispatch(setAuthData({ token })); + navigate('/'); + }) + .catch(() => { + setIsInvalidCode(true); + }); }) .catch(() => { setIsInvalidCode(true); diff --git a/frontend/src/App/Login/LoginByGithubCallback/index.tsx b/frontend/src/App/Login/LoginByGithubCallback/index.tsx index 27d5a755a7..af88aa72f1 100644 --- a/frontend/src/App/Login/LoginByGithubCallback/index.tsx +++ b/frontend/src/App/Login/LoginByGithubCallback/index.tsx @@ -7,7 +7,7 @@ import { UnauthorizedLayout } from 'layouts/UnauthorizedLayout'; import { useAppDispatch } from 'hooks'; import { ROUTES } from 'routes'; -import { useGithubCallbackMutation } from 'services/auth'; +import { useGetNextRedirectMutation, useGithubCallbackMutation } from 'services/auth'; import { useLazyGetProjectsQuery } from 'services/project'; import { AuthErrorMessage } from 'App/AuthErrorMessage'; @@ -23,26 +23,35 @@ export const LoginByGithubCallback: React.FC = () => { const [isInvalidCode, setIsInvalidCode] = useState(false); const dispatch = useAppDispatch(); + const [getNextRedirect] = useGetNextRedirectMutation(); const [githubCallback] = useGithubCallbackMutation(); const [getProjects] = useLazyGetProjectsQuery(); const checkCode = () => { if (code && state) { - githubCallback({ code, state }) + getNextRedirect({ code: code, state: state }) .unwrap() - .then(async ({ creds: { token } }) => { - dispatch(setAuthData({ token })); - - if (process.env.UI_VERSION === 'sky') { - const result = await getProjects().unwrap(); - - if (result?.length === 0) { - navigate(ROUTES.PROJECT.ADD); - return; - } + .then(async ({ redirect_url }) => { + if (redirect_url) { + window.location.href = redirect_url; + return; } - - navigate('/'); + githubCallback({ code, state }) + .unwrap() + .then(async ({ creds: { token } }) => { + dispatch(setAuthData({ token })); + if (process.env.UI_VERSION === 'sky') { + const result = await getProjects().unwrap(); + if (result?.length === 0) { + navigate(ROUTES.PROJECT.ADD); + return; + } + } + navigate('/'); + }) + .catch(() => { + setIsInvalidCode(true); + }); }) .catch(() => { setIsInvalidCode(true); diff --git a/frontend/src/App/Login/LoginByGoogleCallback/index.tsx b/frontend/src/App/Login/LoginByGoogleCallback/index.tsx index 465d0be3ee..4f95f94e27 100644 --- a/frontend/src/App/Login/LoginByGoogleCallback/index.tsx +++ b/frontend/src/App/Login/LoginByGoogleCallback/index.tsx @@ -7,7 +7,7 @@ import { UnauthorizedLayout } from 'layouts/UnauthorizedLayout'; import { useAppDispatch } from 'hooks'; import { ROUTES } from 'routes'; -import { useGoogleCallbackMutation } from 'services/auth'; +import { useGetNextRedirectMutation, useGoogleCallbackMutation } from 'services/auth'; import { AuthErrorMessage } from 'App/AuthErrorMessage'; import { Loading } from 'App/Loading'; @@ -22,15 +22,27 @@ export const LoginByGoogleCallback: React.FC = () => { const [isInvalidCode, setIsInvalidCode] = useState(false); const dispatch = useAppDispatch(); + const [getNextRedirect] = useGetNextRedirectMutation(); const [googleCallback] = useGoogleCallbackMutation(); const checkCode = () => { if (code && state) { - googleCallback({ code, state }) + getNextRedirect({ code, state }) .unwrap() - .then(({ creds: { token } }) => { - dispatch(setAuthData({ token })); - navigate('/'); + .then(({ redirect_url }) => { + if (redirect_url) { + window.location.href = redirect_url; + return; + } + googleCallback({ code, state }) + .unwrap() + .then(({ creds: { token } }) => { + dispatch(setAuthData({ token })); + navigate('/'); + }) + .catch(() => { + setIsInvalidCode(true); + }); }) .catch(() => { setIsInvalidCode(true); diff --git a/frontend/src/App/Login/LoginByOktaCallback/index.tsx b/frontend/src/App/Login/LoginByOktaCallback/index.tsx index ccc9fbc749..72cdc96185 100644 --- a/frontend/src/App/Login/LoginByOktaCallback/index.tsx +++ b/frontend/src/App/Login/LoginByOktaCallback/index.tsx @@ -7,7 +7,7 @@ import { UnauthorizedLayout } from 'layouts/UnauthorizedLayout'; import { useAppDispatch } from 'hooks'; import { ROUTES } from 'routes'; -import { useOktaCallbackMutation } from 'services/auth'; +import { useGetNextRedirectMutation, useOktaCallbackMutation } from 'services/auth'; import { AuthErrorMessage } from 'App/AuthErrorMessage'; import { Loading } from 'App/Loading'; @@ -22,15 +22,27 @@ export const LoginByOktaCallback: React.FC = () => { const [isInvalidCode, setIsInvalidCode] = useState(false); const dispatch = useAppDispatch(); + const [getNextRedirect] = useGetNextRedirectMutation(); const [oktaCallback] = useOktaCallbackMutation(); const checkCode = () => { if (code && state) { - oktaCallback({ code, state }) + getNextRedirect({ code, state }) .unwrap() - .then(({ creds: { token } }) => { - dispatch(setAuthData({ token })); - navigate('/'); + .then(({ redirect_url }) => { + if (redirect_url) { + window.location.href = redirect_url; + return; + } + oktaCallback({ code, state }) + .unwrap() + .then(({ creds: { token } }) => { + dispatch(setAuthData({ token })); + navigate('/'); + }) + .catch(() => { + setIsInvalidCode(true); + }); }) .catch(() => { setIsInvalidCode(true); diff --git a/frontend/src/api.ts b/frontend/src/api.ts index 2dea526601..262aa46b75 100644 --- a/frontend/src/api.ts +++ b/frontend/src/api.ts @@ -5,6 +5,7 @@ export const API = { AUTH: { BASE: () => `${API.BASE()}/auth`, + NEXT_REDIRECT: () => `${API.AUTH.BASE()}/get_next_redirect`, GITHUB: { BASE: () => `${API.AUTH.BASE()}/github`, AUTHORIZE: () => `${API.AUTH.GITHUB.BASE()}/authorize`, diff --git a/frontend/src/services/auth.ts b/frontend/src/services/auth.ts index f65892911a..2512ed0a7d 100644 --- a/frontend/src/services/auth.ts +++ b/frontend/src/services/auth.ts @@ -12,6 +12,14 @@ export const authApi = createApi({ tagTypes: ['Auth'], endpoints: (builder) => ({ + getNextRedirect: builder.mutation<{ redirect_url?: string }, { code: string; state: string }>({ + query: (body) => ({ + url: API.AUTH.NEXT_REDIRECT(), + method: 'POST', + body, + }), + }), + githubAuthorize: builder.mutation<{ authorization_url: string }, void>({ query: () => ({ url: API.AUTH.GITHUB.AUTHORIZE(), @@ -103,6 +111,7 @@ export const authApi = createApi({ }); export const { + useGetNextRedirectMutation, useGithubAuthorizeMutation, useGithubCallbackMutation, useGetOktaInfoQuery, diff --git a/mkdocs.yml b/mkdocs.yml index e793bd23c2..74939703e3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -112,67 +112,67 @@ plugins: background_color: "black" color: "#FFFFFF" font_family: "Roboto" -# debug: true + # debug: true cards_layout_dir: docs/layouts cards_layout: custom - search - redirects: redirect_maps: - 'blog/2024/02/08/resources-authentication-and-more.md': 'https://github.com/dstackai/dstack/releases/0.15.0' - 'blog/2024/01/19/openai-endpoints-preview.md': 'https://github.com/dstackai/dstack/releases/0.14.0' - 'blog/2023/12/22/disk-size-cuda-12-1-mixtral-and-more.md': 'https://github.com/dstackai/dstack/releases/0.13.0' - 'blog/2023/11/21/vastai.md': 'https://github.com/dstackai/dstack/releases/0.12.3' - 'blog/2023/10/31/tensordock.md': 'https://github.com/dstackai/dstack/releases/0.12.2' - 'blog/2023/10/18/simplified-cloud-setup.md': 'https://github.com/dstackai/dstack/releases/0.12.0' - 'blog/2023/08/22/multiple-clouds.md': 'https://github.com/dstackai/dstack/releases/0.11' - 'blog/2023/08/07/services-preview.md': 'https://github.com/dstackai/dstack/releases/0.10.7' - 'blog/2023/07/14/lambda-cloud-ga-and-docker-support.md': 'https://github.com/dstackai/dstack/releases/0.10.5' - 'blog/2023/05/22/azure-support-better-ui-and-more.md': 'https://github.com/dstackai/dstack/releases/0.9.1' - 'blog/2023/03/13/gcp-support-just-landed.md': 'https://github.com/dstackai/dstack/releases/0.2' - 'blog/dstack-research.md': 'https://dstack.ai/#get-started' - 'docs/dev-environments.md': 'docs/concepts/dev-environments.md' - 'docs/tasks.md': 'docs/concepts/tasks.md' - 'docs/services.md': 'docs/concepts/services.md' - 'docs/fleets.md': 'docs/concepts/fleets.md' - 'docs/examples/llms/llama31.md': 'examples/llms/llama/index.md' - 'docs/examples/llms/llama32.md': 'examples/llms/llama/index.md' - 'examples/llms/llama31/index.md': 'examples/llms/llama/index.md' - 'examples/llms/llama32/index.md': 'examples/llms/llama/index.md' - 'docs/examples/accelerators/amd/index.md': 'examples/accelerators/amd/index.md' - 'docs/examples/deployment/nim/index.md': 'examples/inference/nim/index.md' - 'docs/examples/deployment/vllm/index.md': 'examples/inference/vllm/index.md' - 'docs/examples/deployment/tgi/index.md': 'examples/inference/tgi/index.md' - 'providers.md': 'partners.md' - 'backends.md': 'partners.md' - 'blog/monitoring-gpu-usage.md': 'blog/posts/dstack-metrics.md' - 'blog/inactive-dev-environments-auto-shutdown.md': 'blog/posts/inactivity-duration.md' - 'blog/data-centers-and-private-clouds.md': 'blog/posts/gpu-blocks-and-proxy-jump.md' - 'blog/distributed-training-with-aws-efa.md': 'examples/clusters/aws/index.md' - 'blog/dstack-stats.md': 'blog/posts/dstack-metrics.md' - 'docs/concepts/metrics.md': 'docs/guides/metrics.md' - 'docs/guides/monitoring.md': 'docs/guides/metrics.md' - 'blog/nvidia-and-amd-on-vultr.md.md': 'blog/posts/nvidia-and-amd-on-vultr.md' - 'examples/misc/nccl-tests/index.md': 'examples/clusters/nccl-rccl-tests/index.md' - 'examples/misc/a3high-clusters/index.md': 'examples/clusters/gcp/index.md' - 'examples/misc/a3mega-clusters/index.md': 'examples/clusters/gcp/index.md' - 'examples/distributed-training/nccl-tests/index.md': 'examples/clusters/nccl-rccl-tests/index.md' - 'examples/distributed-training/rccl-tests/index.md': 'examples/clusters/nccl-rccl-tests/index.md' - 'examples/deployment/nim/index.md': 'examples/inference/nim/index.md' - 'examples/deployment/vllm/index.md': 'examples/inference/vllm/index.md' - 'examples/deployment/tgi/index.md': 'examples/inference/tgi/index.md' - 'examples/deployment/sglang/index.md': 'examples/inference/sglang/index.md' - 'examples/deployment/trtllm/index.md': 'examples/inference/trtllm/index.md' - 'examples/fine-tuning/trl/index.md': 'examples/single-node-training/trl/index.md' - 'examples/fine-tuning/axolotl/index.md': 'examples/single-node-training/axolotl/index.md' - 'blog/efa.md': 'examples/clusters/aws/index.md' - 'docs/concepts/repos.md': 'docs/concepts/dev-environments.md#repos' - 'examples/clusters/a3high/index.md': 'examples/clusters/gcp/index.md' - 'examples/clusters/a3mega/index.md': 'examples/clusters/gcp/index.md' - 'examples/clusters/a4/index.md': 'examples/clusters/gcp/index.md' - 'examples/clusters/efa/index.md': 'examples/clusters/aws/index.md' + "blog/2024/02/08/resources-authentication-and-more.md": "https://github.com/dstackai/dstack/releases/0.15.0" + "blog/2024/01/19/openai-endpoints-preview.md": "https://github.com/dstackai/dstack/releases/0.14.0" + "blog/2023/12/22/disk-size-cuda-12-1-mixtral-and-more.md": "https://github.com/dstackai/dstack/releases/0.13.0" + "blog/2023/11/21/vastai.md": "https://github.com/dstackai/dstack/releases/0.12.3" + "blog/2023/10/31/tensordock.md": "https://github.com/dstackai/dstack/releases/0.12.2" + "blog/2023/10/18/simplified-cloud-setup.md": "https://github.com/dstackai/dstack/releases/0.12.0" + "blog/2023/08/22/multiple-clouds.md": "https://github.com/dstackai/dstack/releases/0.11" + "blog/2023/08/07/services-preview.md": "https://github.com/dstackai/dstack/releases/0.10.7" + "blog/2023/07/14/lambda-cloud-ga-and-docker-support.md": "https://github.com/dstackai/dstack/releases/0.10.5" + "blog/2023/05/22/azure-support-better-ui-and-more.md": "https://github.com/dstackai/dstack/releases/0.9.1" + "blog/2023/03/13/gcp-support-just-landed.md": "https://github.com/dstackai/dstack/releases/0.2" + "blog/dstack-research.md": "https://dstack.ai/#get-started" + "docs/dev-environments.md": "docs/concepts/dev-environments.md" + "docs/tasks.md": "docs/concepts/tasks.md" + "docs/services.md": "docs/concepts/services.md" + "docs/fleets.md": "docs/concepts/fleets.md" + "docs/examples/llms/llama31.md": "examples/llms/llama/index.md" + "docs/examples/llms/llama32.md": "examples/llms/llama/index.md" + "examples/llms/llama31/index.md": "examples/llms/llama/index.md" + "examples/llms/llama32/index.md": "examples/llms/llama/index.md" + "docs/examples/accelerators/amd/index.md": "examples/accelerators/amd/index.md" + "docs/examples/deployment/nim/index.md": "examples/inference/nim/index.md" + "docs/examples/deployment/vllm/index.md": "examples/inference/vllm/index.md" + "docs/examples/deployment/tgi/index.md": "examples/inference/tgi/index.md" + "providers.md": "partners.md" + "backends.md": "partners.md" + "blog/monitoring-gpu-usage.md": "blog/posts/dstack-metrics.md" + "blog/inactive-dev-environments-auto-shutdown.md": "blog/posts/inactivity-duration.md" + "blog/data-centers-and-private-clouds.md": "blog/posts/gpu-blocks-and-proxy-jump.md" + "blog/distributed-training-with-aws-efa.md": "examples/clusters/aws/index.md" + "blog/dstack-stats.md": "blog/posts/dstack-metrics.md" + "docs/concepts/metrics.md": "docs/guides/metrics.md" + "docs/guides/monitoring.md": "docs/guides/metrics.md" + "blog/nvidia-and-amd-on-vultr.md.md": "blog/posts/nvidia-and-amd-on-vultr.md" + "examples/misc/nccl-tests/index.md": "examples/clusters/nccl-rccl-tests/index.md" + "examples/misc/a3high-clusters/index.md": "examples/clusters/gcp/index.md" + "examples/misc/a3mega-clusters/index.md": "examples/clusters/gcp/index.md" + "examples/distributed-training/nccl-tests/index.md": "examples/clusters/nccl-rccl-tests/index.md" + "examples/distributed-training/rccl-tests/index.md": "examples/clusters/nccl-rccl-tests/index.md" + "examples/deployment/nim/index.md": "examples/inference/nim/index.md" + "examples/deployment/vllm/index.md": "examples/inference/vllm/index.md" + "examples/deployment/tgi/index.md": "examples/inference/tgi/index.md" + "examples/deployment/sglang/index.md": "examples/inference/sglang/index.md" + "examples/deployment/trtllm/index.md": "examples/inference/trtllm/index.md" + "examples/fine-tuning/trl/index.md": "examples/single-node-training/trl/index.md" + "examples/fine-tuning/axolotl/index.md": "examples/single-node-training/axolotl/index.md" + "blog/efa.md": "examples/clusters/aws/index.md" + "docs/concepts/repos.md": "docs/concepts/dev-environments.md#repos" + "examples/clusters/a3high/index.md": "examples/clusters/gcp/index.md" + "examples/clusters/a3mega/index.md": "examples/clusters/gcp/index.md" + "examples/clusters/a4/index.md": "examples/clusters/gcp/index.md" + "examples/clusters/efa/index.md": "examples/clusters/aws/index.md" - typeset - gen-files: - scripts: # always relative to mkdocs.yml + scripts: # always relative to mkdocs.yml - scripts/docs/gen_examples.py - scripts/docs/gen_cli_reference.py - scripts/docs/gen_openapi_reference.py @@ -279,70 +279,71 @@ nav: - Protips: docs/guides/protips.md - Migration: docs/guides/migration.md - Reference: - - .dstack.yml: - - dev-environment: docs/reference/dstack.yml/dev-environment.md - - task: docs/reference/dstack.yml/task.md - - service: docs/reference/dstack.yml/service.md - - fleet: docs/reference/dstack.yml/fleet.md - - gateway: docs/reference/dstack.yml/gateway.md - - volume: docs/reference/dstack.yml/volume.md - - server/config.yml: docs/reference/server/config.yml.md - - CLI: - - dstack server: docs/reference/cli/dstack/server.md - - dstack init: docs/reference/cli/dstack/init.md - - dstack apply: docs/reference/cli/dstack/apply.md - - dstack delete: docs/reference/cli/dstack/delete.md - - dstack ps: docs/reference/cli/dstack/ps.md - - dstack stop: docs/reference/cli/dstack/stop.md - - dstack attach: docs/reference/cli/dstack/attach.md - - dstack logs: docs/reference/cli/dstack/logs.md - - dstack metrics: docs/reference/cli/dstack/metrics.md - - dstack event: docs/reference/cli/dstack/event.md - - dstack project: docs/reference/cli/dstack/project.md - - dstack fleet: docs/reference/cli/dstack/fleet.md - - dstack offer: docs/reference/cli/dstack/offer.md - - dstack volume: docs/reference/cli/dstack/volume.md - - dstack gateway: docs/reference/cli/dstack/gateway.md - - dstack secret: docs/reference/cli/dstack/secret.md - - API: - - Python API: docs/reference/api/python/index.md - - REST API: docs/reference/api/rest/index.md - - Environment variables: docs/reference/environment-variables.md - - .dstack/profiles.yml: docs/reference/profiles.yml.md - - Plugins: - - Python API: docs/reference/plugins/python/index.md - - REST API: docs/reference/plugins/rest/index.md - - llms-full.txt: https://dstack.ai/llms-full.txt + - .dstack.yml: + - dev-environment: docs/reference/dstack.yml/dev-environment.md + - task: docs/reference/dstack.yml/task.md + - service: docs/reference/dstack.yml/service.md + - fleet: docs/reference/dstack.yml/fleet.md + - gateway: docs/reference/dstack.yml/gateway.md + - volume: docs/reference/dstack.yml/volume.md + - server/config.yml: docs/reference/server/config.yml.md + - CLI: + - dstack server: docs/reference/cli/dstack/server.md + - dstack init: docs/reference/cli/dstack/init.md + - dstack apply: docs/reference/cli/dstack/apply.md + - dstack delete: docs/reference/cli/dstack/delete.md + - dstack ps: docs/reference/cli/dstack/ps.md + - dstack stop: docs/reference/cli/dstack/stop.md + - dstack attach: docs/reference/cli/dstack/attach.md + - dstack login: docs/reference/cli/dstack/login.md + - dstack logs: docs/reference/cli/dstack/logs.md + - dstack metrics: docs/reference/cli/dstack/metrics.md + - dstack event: docs/reference/cli/dstack/event.md + - dstack project: docs/reference/cli/dstack/project.md + - dstack fleet: docs/reference/cli/dstack/fleet.md + - dstack offer: docs/reference/cli/dstack/offer.md + - dstack volume: docs/reference/cli/dstack/volume.md + - dstack gateway: docs/reference/cli/dstack/gateway.md + - dstack secret: docs/reference/cli/dstack/secret.md + - API: + - Python API: docs/reference/api/python/index.md + - REST API: docs/reference/api/rest/index.md + - Environment variables: docs/reference/environment-variables.md + - .dstack/profiles.yml: docs/reference/profiles.yml.md + - Plugins: + - Python API: docs/reference/plugins/python/index.md + - REST API: docs/reference/plugins/rest/index.md + - llms-full.txt: https://dstack.ai/llms-full.txt - Examples: - - examples.md - - Single-node training: - - TRL: examples/single-node-training/trl/index.md - - Axolotl: examples/single-node-training/axolotl/index.md - - Distributed training: - - TRL: examples/distributed-training/trl/index.md - - Axolotl: examples/distributed-training/axolotl/index.md - - Ray+RAGEN: examples/distributed-training/ray-ragen/index.md - - Clusters: - - AWS: examples/clusters/aws/index.md - - GCP: examples/clusters/gcp/index.md - - Lambda: examples/clusters/lambda/index.md - - Crusoe: examples/clusters/crusoe/index.md - - NCCL/RCCL tests: examples/clusters/nccl-rccl-tests/index.md - - Inference: - - SGLang: examples/inference/sglang/index.md - - vLLM: examples/inference/vllm/index.md - - TGI: examples/inference/tgi/index.md - - NIM: examples/inference/nim/index.md - - TensorRT-LLM: examples/inference/trtllm/index.md - - Accelerators: - - AMD: examples/accelerators/amd/index.md - - TPU: examples/accelerators/tpu/index.md - - Intel Gaudi: examples/accelerators/intel/index.md - - Tenstorrent: examples/accelerators/tenstorrent/index.md - - Models: - - Wan2.2: examples/models/wan22/index.md - - Blog: - - blog/index.md + - examples.md + - Single-node training: + - TRL: examples/single-node-training/trl/index.md + - Axolotl: examples/single-node-training/axolotl/index.md + - Distributed training: + - TRL: examples/distributed-training/trl/index.md + - Axolotl: examples/distributed-training/axolotl/index.md + - Ray+RAGEN: examples/distributed-training/ray-ragen/index.md + - Clusters: + - AWS: examples/clusters/aws/index.md + - GCP: examples/clusters/gcp/index.md + - Lambda: examples/clusters/lambda/index.md + - Crusoe: examples/clusters/crusoe/index.md + - NCCL/RCCL tests: examples/clusters/nccl-rccl-tests/index.md + - Inference: + - SGLang: examples/inference/sglang/index.md + - vLLM: examples/inference/vllm/index.md + - TGI: examples/inference/tgi/index.md + - NIM: examples/inference/nim/index.md + - TensorRT-LLM: examples/inference/trtllm/index.md + - Accelerators: + - AMD: examples/accelerators/amd/index.md + - TPU: examples/accelerators/tpu/index.md + - Intel Gaudi: examples/accelerators/intel/index.md + - Tenstorrent: examples/accelerators/tenstorrent/index.md + - Models: + - Wan2.2: examples/models/wan22/index.md + - Blog: + - blog/index.md - Case studies: blog/case-studies.md - Benchmarks: blog/benchmarks.md # - Discord: https://discord.gg/u8SmfwPpMd" target="_blank diff --git a/pyproject.toml b/pyproject.toml index e69ec4d5aa..e540705d93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,11 +100,12 @@ ignore = [ dev = [ "httpx>=0.28.1", "pre-commit>=4.2.0", + "pytest~=7.2", "pytest-asyncio>=0.23.8", "pytest-httpbin>=2.1.0", - "httpbin>=0.10.2", # indirect to make compatible with Werkzeug 3 - "pytest~=7.2", "pytest-socket>=0.7.0", + "pytest-env>=1.1.0", + "httpbin>=0.10.2", # indirect to make compatible with Werkzeug 3 "requests-mock>=1.12.1", "openai>=1.68.2", "freezegun>=1.5.1", diff --git a/pytest.ini b/pytest.ini index 899f67a61b..30c0e62811 100644 --- a/pytest.ini +++ b/pytest.ini @@ -8,3 +8,5 @@ addopts = markers = shim_version dockerized +env = + DSTACK_CLI_RICH_FORCE_TERMINAL=0 diff --git a/src/dstack/_internal/cli/commands/login.py b/src/dstack/_internal/cli/commands/login.py new file mode 100644 index 0000000000..54fdc0a0b6 --- /dev/null +++ b/src/dstack/_internal/cli/commands/login.py @@ -0,0 +1,237 @@ +import argparse +import queue +import threading +import urllib.parse +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Optional + +from dstack._internal.cli.commands import BaseCommand +from dstack._internal.cli.utils.common import console +from dstack._internal.core.errors import ClientError, CLIError +from dstack._internal.core.models.users import UserWithCreds +from dstack.api._public.runs import ConfigManager +from dstack.api.server import APIClient + + +class LoginCommand(BaseCommand): + NAME = "login" + DESCRIPTION = "Authorize the CLI using Single Sign-On" + + def _register(self): + super()._register() + self._parser.add_argument( + "--url", + help="The server URL, e.g. https://sky.dstack.ai", + required=True, + ) + self._parser.add_argument( + "-p", + "--provider", + help=( + "The SSO provider name." + " Selected automatically if the server supports only one provider." + ), + ) + + def _command(self, args: argparse.Namespace): + super()._command(args) + base_url = _normalize_url_or_error(args.url) + api_client = APIClient(base_url=base_url) + provider = self._select_provider_or_error(api_client=api_client, provider=args.provider) + server = _LoginServer(api_client=api_client, provider=provider) + try: + server.start() + auth_resp = api_client.auth.authorize(provider=provider, local_port=server.port) + opened = webbrowser.open(auth_resp.authorization_url) + if opened: + console.print( + f"Your browser has been opened to log in with [code]{provider.title()}[/]:\n" + ) + else: + console.print(f"Open the URL to log in with [code]{provider.title()}[/]:\n") + print(f"{auth_resp.authorization_url}\n") + user = server.get_logged_in_user() + finally: + server.shutdown() + if user is None: + raise CLIError("CLI authentication failed") + console.print(f"Logged in as [code]{user.username}[/].") + api_client = APIClient(base_url=base_url, token=user.creds.token) + self._configure_projects(api_client=api_client, user=user) + + def _select_provider_or_error(self, api_client: APIClient, provider: Optional[str]) -> str: + providers = api_client.auth.list_providers() + available_providers = [p.name for p in providers if p.enabled] + if len(available_providers) == 0: + raise CLIError("No SSO providers configured on the server.") + if provider is None: + if len(available_providers) > 1: + raise CLIError( + "Specify -p/--provider to choose SSO provider" + f" Available providers: {', '.join(available_providers)}" + ) + return available_providers[0] + if provider not in available_providers: + raise CLIError( + f"Provider {provider} not configured on the server." + f" Available providers: {', '.join(available_providers)}" + ) + return provider + + def _configure_projects(self, api_client: APIClient, user: UserWithCreds): + projects = api_client.projects.list(include_not_joined=False) + if len(projects) == 0: + console.print( + "No projects configured." + " Create your own project via the UI or contact a project manager to add you to the project." + ) + return + config_manager = ConfigManager() + default_project = config_manager.get_project_config() + new_default_project = None + for i, project in enumerate(projects): + set_as_default = ( + default_project is None + and i == 0 + or default_project is not None + and default_project.name == project.project_name + ) + if set_as_default: + new_default_project = project + config_manager.configure_project( + name=project.project_name, + url=api_client.base_url, + token=user.creds.token, + default=set_as_default, + ) + config_manager.save() + console.print( + f"Configured projects: {', '.join(f'[code]{p.project_name}[/]' for p in projects)}." + ) + if new_default_project: + console.print( + f"Set project [code]{new_default_project.project_name}[/] as default project." + ) + + +class _BadRequestError(Exception): + pass + + +class _LoginServer: + def __init__(self, api_client: APIClient, provider: str): + self._api_client = api_client + self._provider = provider + self._result_queue: queue.Queue[Optional[UserWithCreds]] = queue.Queue() + # Using built-in HTTP server to avoid extra deps. + callback_handler = self._make_callback_handler( + result_queue=self._result_queue, + api_client=api_client, + provider=provider, + ) + self._server = self._create_server(handler=callback_handler) + + def start(self): + self._thread = threading.Thread(target=self._server.serve_forever) + self._thread.start() + + def shutdown(self): + self._server.shutdown() + + def get_logged_in_user(self) -> Optional[UserWithCreds]: + return self._result_queue.get() + + @property + def port(self) -> int: + return self._server.server_port + + def _make_callback_handler( + self, + result_queue: queue.Queue[Optional[UserWithCreds]], + api_client: APIClient, + provider: str, + ) -> type[BaseHTTPRequestHandler]: + class _CallbackHandler(BaseHTTPRequestHandler): + def do_GET(self): + parsed_path = urllib.parse.urlparse(self.path) + if parsed_path.path != "/auth/callback": + self.send_response(404) + self.end_headers() + return + try: + self._handle_auth_callback(parsed_path) + except _BadRequestError as e: + self.send_error(400, e.args[0]) + result_queue.put(None) + + def log_message(self, format: str, *args): + # Do not log server requests. + pass + + def _handle_auth_callback(self, parsed_path: urllib.parse.ParseResult): + try: + params = urllib.parse.parse_qs(parsed_path.query, strict_parsing=True) + except ValueError: + raise _BadRequestError("Bad query params") + code = params.get("code", [None])[0] + state = params.get("state", [None])[0] + if code is None or state is None: + raise _BadRequestError("Missing required params") + try: + user = api_client.auth.callback(provider=provider, code=code, state=state) + except ClientError: + raise _BadRequestError("Authentication failed") + self._send_success_html() + result_queue.put(user) + + def _send_success_html(self): + body = _SUCCESS_HTML.encode() + self.send_response(200) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + return _CallbackHandler + + def _create_server(self, handler: type[BaseHTTPRequestHandler]) -> HTTPServer: + server_address = ("127.0.0.1", 0) + server = HTTPServer(server_address, handler) + return server + + +def _normalize_url_or_error(url: str) -> str: + if not url.startswith("http://") and not url.startswith("https://"): + url = "http://" + url + parsed = urllib.parse.urlparse(url) + if ( + not parsed.scheme + or not parsed.hostname + or parsed.path not in ("", "/") + or parsed.params + or parsed.query + or parsed.fragment + or (parsed.port is not None and not (1 <= parsed.port <= 65535)) + ): + raise CLIError("Invalid server URL format. Format: --url https://sky.dstack.ai") + return url + + +_SUCCESS_HTML = """\ + + + + + CLI authenticated + + + +

dstack CLI authenticated

+

You may close this page.

+ + +""" diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index 98be45b8d5..61f3967ab7 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -12,6 +12,7 @@ from dstack._internal.cli.commands.fleet import FleetCommand from dstack._internal.cli.commands.gateway import GatewayCommand from dstack._internal.cli.commands.init import InitCommand +from dstack._internal.cli.commands.login import LoginCommand from dstack._internal.cli.commands.logs import LogsCommand from dstack._internal.cli.commands.metrics import MetricsCommand from dstack._internal.cli.commands.offer import OfferCommand @@ -68,6 +69,7 @@ def main(): GatewayCommand.register(subparsers) InitCommand.register(subparsers) OfferCommand.register(subparsers) + LoginCommand.register(subparsers) LogsCommand.register(subparsers) MetricsCommand.register(subparsers) ProjectCommand.register(subparsers) diff --git a/src/dstack/_internal/cli/utils/common.py b/src/dstack/_internal/cli/utils/common.py index c75f08b81b..87f0687e1b 100644 --- a/src/dstack/_internal/cli/utils/common.py +++ b/src/dstack/_internal/cli/utils/common.py @@ -21,7 +21,10 @@ "code": "bold sea_green3", } -console = Console(theme=Theme(_colors)) +console = Console( + theme=Theme(_colors), + force_terminal=settings.CLI_RICH_FORCE_TERMINAL, +) LIVE_TABLE_REFRESH_RATE_PER_SEC = 1 diff --git a/src/dstack/_internal/core/models/auth.py b/src/dstack/_internal/core/models/auth.py new file mode 100644 index 0000000000..f6d09fbc73 --- /dev/null +++ b/src/dstack/_internal/core/models/auth.py @@ -0,0 +1,28 @@ +from typing import Annotated, Optional + +from pydantic import Field + +from dstack._internal.core.models.common import CoreModel + + +class OAuthProviderInfo(CoreModel): + name: Annotated[str, Field(description="The OAuth2 provider name.")] + enabled: Annotated[ + bool, Field(description="Whether the provider is configured on the server.") + ] + + +class OAuthState(CoreModel): + """ + A struct that the server puts in the OAuth2 state parameter. + """ + + value: Annotated[str, Field(description="A random string to protect against CSRF.")] + local_port: Annotated[ + Optional[int], + Field( + description="If specified, the user is redirected to localhost:local_port after the redirect from the provider.", + ge=1, + le=65535, + ), + ] = None diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 9c83bac793..527dd128fe 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -25,6 +25,7 @@ from dstack._internal.server.background.tasks.process_probes import PROBES_SCHEDULER from dstack._internal.server.db import get_db, get_session_ctx, migrate from dstack._internal.server.routers import ( + auth, backends, events, files, @@ -210,6 +211,7 @@ def add_no_api_version_check_routes(paths: List[str]): def register_routes(app: FastAPI, ui: bool = True): app.include_router(server.router) app.include_router(users.router) + app.include_router(auth.router) app.include_router(projects.router) app.include_router(backends.root_router) app.include_router(backends.project_router) diff --git a/src/dstack/_internal/server/routers/auth.py b/src/dstack/_internal/server/routers/auth.py new file mode 100644 index 0000000000..89fe2f57f5 --- /dev/null +++ b/src/dstack/_internal/server/routers/auth.py @@ -0,0 +1,34 @@ +from fastapi import APIRouter + +from dstack._internal.core.models.auth import OAuthProviderInfo +from dstack._internal.server.schemas.auth import ( + OAuthGetNextRedirectRequest, + OAuthGetNextRedirectResponse, +) +from dstack._internal.server.services import auth as auth_services +from dstack._internal.server.utils.routers import CustomORJSONResponse + +router = APIRouter(prefix="/api/auth", tags=["auth"]) + + +@router.post("/list_providers", response_model=list[OAuthProviderInfo]) +async def list_providers(): + """ + Returns OAuth2 providers registered on the server. + """ + return CustomORJSONResponse(auth_services.list_providers()) + + +@router.post("/get_next_redirect", response_model=OAuthGetNextRedirectResponse) +async def get_next_redirect(body: OAuthGetNextRedirectRequest): + """ + A helper endpoint that returns the next redirect URL in case the state encodes it. + Can be used by the UI after the redirect from the provider + to determine if the user needs to be redirected further (CLI login) + or the auth callback endpoint needs to be called directly (UI login). + """ + return CustomORJSONResponse( + OAuthGetNextRedirectResponse( + redirect_url=auth_services.get_next_redirect_url(code=body.code, state=body.state) + ) + ) diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py index 56d41b6ca0..d35b9535e8 100644 --- a/src/dstack/_internal/server/routers/projects.py +++ b/src/dstack/_internal/server/routers/projects.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession @@ -10,6 +10,7 @@ AddProjectMemberRequest, CreateProjectRequest, DeleteProjectsRequest, + ListProjectsRequest, RemoveProjectMemberRequest, SetProjectMembersRequest, UpdateProjectRequest, @@ -37,6 +38,7 @@ @router.post("/list", response_model=List[Project]) async def list_projects( + body: Optional[ListProjectsRequest] = None, session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), ): @@ -45,8 +47,13 @@ async def list_projects( `members` and `backends` are always empty - call `/api/projects/{project_name}/get` to retrieve them. """ + if body is None: + # For backward compatibility + body = ListProjectsRequest() return CustomORJSONResponse( - await projects.list_user_accessible_projects(session=session, user=user) + await projects.list_user_accessible_projects( + session=session, user=user, include_not_joined=body.include_not_joined + ) ) diff --git a/src/dstack/_internal/server/schemas/auth.py b/src/dstack/_internal/server/schemas/auth.py new file mode 100644 index 0000000000..942f1fb388 --- /dev/null +++ b/src/dstack/_internal/server/schemas/auth.py @@ -0,0 +1,83 @@ +from typing import Annotated, Optional + +from pydantic import Field + +from dstack._internal.core.models.common import CoreModel + + +class OAuthInfoResponse(CoreModel): + enabled: Annotated[ + bool, Field(description="Whether the OAuth2 provider is configured on the server.") + ] + + +class OAuthAuthorizeRequest(CoreModel): + local_port: Annotated[ + Optional[int], + Field( + description="If specified, the user is redirected to localhost:local_port after the redirect from the provider.", + ge=1, + le=65535, + ), + ] = None + base_url: Annotated[ + Optional[str], + Field( + description=( + "The server base URL used to access the dstack server, e.g. `http://localhost:3000`." + " Used to build redirect URLs when the dstack server is available on multiple domains." + ) + ), + ] = None + + +class OAuthAuthorizeResponse(CoreModel): + authorization_url: Annotated[str, Field(description="An OAuth2 authorization URL.")] + + +class OAuthCallbackRequest(CoreModel): + code: Annotated[ + str, + Field( + description="The OAuth2 authorization code received from the provider in the redirect URL." + ), + ] + state: Annotated[ + str, + Field(description="The state parameter received from the provider in the redirect URL."), + ] + base_url: Annotated[ + Optional[str], + Field( + description=( + "The server base URL used to access the dstack server, e.g. `http://localhost:3000`." + " Used to build redirect URLs when the dstack server is available on multiple domains." + " It must match the base URL specified when generating the authorization URL." + ) + ), + ] = None + + +class OAuthGetNextRedirectRequest(CoreModel): + code: Annotated[ + str, + Field( + description="The OAuth2 authorization code received from the provider in the redirect URL." + ), + ] + state: Annotated[ + str, + Field(description="The state parameter received from the provider in the redirect URL."), + ] + + +class OAuthGetNextRedirectResponse(CoreModel): + redirect_url: Annotated[ + Optional[str], + Field( + description=( + "The URL that the user needs to be redirected to." + " If `null`, there is no next redirect." + ) + ), + ] diff --git a/src/dstack/_internal/server/schemas/projects.py b/src/dstack/_internal/server/schemas/projects.py index 355bb3a770..ec05c1fb47 100644 --- a/src/dstack/_internal/server/schemas/projects.py +++ b/src/dstack/_internal/server/schemas/projects.py @@ -6,6 +6,12 @@ from dstack._internal.core.models.users import ProjectRole +class ListProjectsRequest(CoreModel): + include_not_joined: Annotated[ + bool, Field(description="Include public projects where user is not a member") + ] = True + + class CreateProjectRequest(CoreModel): project_name: str is_public: bool = False diff --git a/src/dstack/_internal/server/services/auth.py b/src/dstack/_internal/server/services/auth.py new file mode 100644 index 0000000000..8ea40994f3 --- /dev/null +++ b/src/dstack/_internal/server/services/auth.py @@ -0,0 +1,77 @@ +import secrets +import urllib.parse +from base64 import b64decode, b64encode +from typing import Optional + +from fastapi import Request, Response + +from dstack._internal.core.errors import ServerClientError +from dstack._internal.core.models.auth import OAuthProviderInfo, OAuthState +from dstack._internal.server import settings +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +_OAUTH_STATE_COOKIE_KEY = "oauth-state" + +_OAUTH_PROVIDERS: list[OAuthProviderInfo] = [] + + +def register_provider(provider_info: OAuthProviderInfo): + """ + Registers an OAuth2 provider supported on the server. + If the provider is supported but not configured, it should be registered with `enabled=False`. + The provider must register endpoints `/api/auth/{provider}/authorize` and `/api/auth/{provider}/callback` + as defined by the client (see `dstack.api.server._auth.AuthAPIClient`). + """ + _OAUTH_PROVIDERS.append(provider_info) + + +def list_providers() -> list[OAuthProviderInfo]: + return _OAUTH_PROVIDERS + + +def generate_oauth_state(local_port: Optional[int] = None) -> str: + value = str(secrets.token_hex(16)) + state = OAuthState(value=value, local_port=local_port) + return b64encode(state.json().encode()).decode() + + +def set_state_cookie(response: Response, state: str): + response.set_cookie( + key=_OAUTH_STATE_COOKIE_KEY, + value=state, + secure=settings.SERVER_URL.startswith("https://"), + samesite="strict", + httponly=True, + ) + + +def get_validated_state(request: Request, state: str) -> OAuthState: + state_cookie = request.cookies.get(_OAUTH_STATE_COOKIE_KEY) + if state != state_cookie: + raise ServerClientError("Invalid state token") + decoded_state = _decode_state(state) + if decoded_state is None: + raise ServerClientError("Invalid state token") + return decoded_state + + +def get_next_redirect_url(code: str, state: str) -> Optional[str]: + decoded_state = _decode_state(state) + if decoded_state is None: + raise ServerClientError("Invalid state token") + if decoded_state.local_port is None: + return None + params = {"code": code, "state": state} + redirect_url = f"http://localhost:{decoded_state.local_port}/auth/callback?{urllib.parse.urlencode(params)}" + return redirect_url + + +def _decode_state(state: str) -> Optional[OAuthState]: + try: + return OAuthState.parse_raw(b64decode(state, validate=True).decode()) + except Exception as e: + logger.debug("Exception when decoding OAuth2 state parameter: %s", repr(e)) + return None diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 5e4842df56..3ef6c32785 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -83,18 +83,22 @@ async def list_user_projects( async def list_user_accessible_projects( session: AsyncSession, user: UserModel, + include_not_joined: bool, ) -> List[Project]: """ Returns all projects accessible to the user: - Projects where user is a member (public or private) - - Public projects where user is NOT a member + - if `include_not_joined`: Public projects where user is NOT a member """ if user.global_role == GlobalRole.ADMIN: projects = await list_project_models(session=session) else: - member_projects = await list_member_project_models(session=session, user=user) - public_projects = await list_public_non_member_project_models(session=session, user=user) - projects = member_projects + public_projects + projects = await list_member_project_models(session=session, user=user) + if include_not_joined: + public_projects = await list_public_non_member_project_models( + session=session, user=user + ) + projects += public_projects projects = sorted(projects, key=lambda p: p.created_at) return [ diff --git a/src/dstack/_internal/settings.py b/src/dstack/_internal/settings.py index 81682480a2..6089e37c07 100644 --- a/src/dstack/_internal/settings.py +++ b/src/dstack/_internal/settings.py @@ -1,6 +1,7 @@ import os from dstack import version +from dstack._internal.utils.env import environ from dstack._internal.utils.version import parse_version DSTACK_VERSION = os.getenv("DSTACK_VERSION", version.__version__) @@ -28,6 +29,8 @@ CLI_LOG_LEVEL = os.getenv("DSTACK_CLI_LOG_LEVEL", "INFO").upper() CLI_FILE_LOG_LEVEL = os.getenv("DSTACK_CLI_FILE_LOG_LEVEL", "DEBUG").upper() +# Can be used to disable control characters (e.g. for testing). +CLI_RICH_FORCE_TERMINAL = environ.get_bool("DSTACK_CLI_RICH_FORCE_TERMINAL") # Development settings diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index 2ad94f0864..5d6ea08604 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -14,6 +14,7 @@ URLNotFoundError, ) from dstack._internal.utils.logging import get_logger +from dstack.api.server._auth import AuthAPIClient from dstack.api.server._backends import BackendsAPIClient from dstack.api.server._events import EventsAPIClient from dstack.api.server._files import FilesAPIClient @@ -52,16 +53,18 @@ class APIClient: files: operations with files """ - def __init__(self, base_url: str, token: str): + def __init__(self, base_url: str, token: Optional[str] = None): """ Args: base_url: The API endpoints prefix, e.g. `http://127.0.0.1:3000/`. token: The API token. """ self._base_url = base_url.rstrip("/") - self._token = token self._s = requests.session() - self._s.headers.update({"Authorization": f"Bearer {token}"}) + self._token = None + if token is not None: + self._token = token + self._s.headers.update({"Authorization": f"Bearer {token}"}) client_api_version = os.getenv("DSTACK_CLIENT_API_VERSION", version.__version__) if client_api_version is not None: self._s.headers.update({"X-API-VERSION": client_api_version}) @@ -71,6 +74,10 @@ def __init__(self, base_url: str, token: str): def base_url(self) -> str: return self._base_url + @property + def auth(self) -> AuthAPIClient: + return AuthAPIClient(self._request, self._logger) + @property def users(self) -> UsersAPIClient: return UsersAPIClient(self._request, self._logger) @@ -128,6 +135,8 @@ def events(self) -> EventsAPIClient: return EventsAPIClient(self._request, self._logger) def get_token_hash(self) -> str: + if self._token is None: + raise ValueError("Token not set") return hashlib.sha1(self._token.encode()).hexdigest()[:8] def _request( diff --git a/src/dstack/api/server/_auth.py b/src/dstack/api/server/_auth.py new file mode 100644 index 0000000000..b944a292a2 --- /dev/null +++ b/src/dstack/api/server/_auth.py @@ -0,0 +1,30 @@ +from typing import Optional + +from pydantic import parse_obj_as + +from dstack._internal.core.models.auth import OAuthProviderInfo +from dstack._internal.core.models.users import UserWithCreds +from dstack._internal.server.schemas.auth import ( + OAuthAuthorizeRequest, + OAuthAuthorizeResponse, + OAuthCallbackRequest, +) +from dstack.api.server._group import APIClientGroup + + +class AuthAPIClient(APIClientGroup): + def list_providers(self) -> list[OAuthProviderInfo]: + resp = self._request("/api/auth/list_providers") + return parse_obj_as(list[OAuthProviderInfo.__response__], resp.json()) + + def authorize(self, provider: str, local_port: Optional[int] = None) -> OAuthAuthorizeResponse: + body = OAuthAuthorizeRequest(local_port=local_port) + resp = self._request(f"/api/auth/{provider}/authorize", body=body.json()) + return parse_obj_as(OAuthAuthorizeResponse.__response__, resp.json()) + + def callback( + self, provider: str, code: str, state: str, base_url: Optional[str] = None + ) -> UserWithCreds: + body = OAuthCallbackRequest(code=code, state=state, base_url=base_url) + resp = self._request(f"/api/auth/{provider}/callback", body=body.json()) + return parse_obj_as(UserWithCreds.__response__, resp.json()) diff --git a/src/dstack/api/server/_projects.py b/src/dstack/api/server/_projects.py index 0fb47c9ab5..31bdc3b2de 100644 --- a/src/dstack/api/server/_projects.py +++ b/src/dstack/api/server/_projects.py @@ -8,6 +8,7 @@ AddProjectMemberRequest, CreateProjectRequest, DeleteProjectsRequest, + ListProjectsRequest, MemberSetting, RemoveProjectMemberRequest, SetProjectMembersRequest, @@ -16,8 +17,9 @@ class ProjectsAPIClient(APIClientGroup): - def list(self) -> List[Project]: - resp = self._request("/api/projects/list") + def list(self, include_not_joined: bool = True) -> List[Project]: + body = ListProjectsRequest(include_not_joined=include_not_joined) + resp = self._request("/api/projects/list", body=body.json()) return parse_obj_as(List[Project.__response__], resp.json()) def create(self, project_name: str, is_public: bool = False) -> Project: diff --git a/src/tests/_internal/cli/commands/test_login.py b/src/tests/_internal/cli/commands/test_login.py new file mode 100644 index 0000000000..42b46c2b73 --- /dev/null +++ b/src/tests/_internal/cli/commands/test_login.py @@ -0,0 +1,103 @@ +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import call, patch + +from pytest import CaptureFixture + +from tests._internal.cli.common import run_dstack_cli + + +class TestLogin: + def test_login_no_projects(self, capsys: CaptureFixture, tmp_path: Path): + with ( + patch("dstack._internal.cli.commands.login.webbrowser") as webbrowser_mock, + patch("dstack._internal.cli.commands.login.APIClient") as APIClientMock, + patch("dstack._internal.cli.commands.login._LoginServer") as LoginServerMock, + ): + webbrowser_mock.open.return_value = True + APIClientMock.return_value.auth.list_providers.return_value = [ + SimpleNamespace(name="github", enabled=True) + ] + APIClientMock.return_value.auth.authorize.return_value = SimpleNamespace( + authorization_url="http://auth_url" + ) + APIClientMock.return_value.projects.list.return_value = [] + user = SimpleNamespace(username="me", creds=SimpleNamespace(token="token")) + LoginServerMock.return_value.get_logged_in_user.return_value = user + exit_code = run_dstack_cli( + [ + "login", + "--url", + "http://127.0.0.1:31313", + "--provider", + "github", + ], + home_dir=tmp_path, + ) + + assert exit_code == 0 + assert capsys.readouterr().out.replace("\n", "") == ( + "Your browser has been opened to log in with Github:" + "http://auth_url" + "Logged in as me." + "No projects configured. Create your own project via the UI or contact a project manager to add you to the project." + ) + + def test_login_configures_projects(self, capsys: CaptureFixture, tmp_path: Path): + with ( + patch("dstack._internal.cli.commands.login.webbrowser") as webbrowser_mock, + patch("dstack._internal.cli.commands.login.APIClient") as APIClientMock, + patch("dstack._internal.cli.commands.login.ConfigManager") as ConfigManagerMock, + patch("dstack._internal.cli.commands.login._LoginServer") as LoginServerMock, + ): + webbrowser_mock.open.return_value = True + APIClientMock.return_value.auth.list_providers.return_value = [ + SimpleNamespace(name="github", enabled=True) + ] + APIClientMock.return_value.auth.authorize.return_value = SimpleNamespace( + authorization_url="http://auth_url" + ) + APIClientMock.return_value.projects.list.return_value = [ + SimpleNamespace(project_name="project1"), + SimpleNamespace(project_name="project2"), + ] + APIClientMock.return_value.base_url = "http://127.0.0.1:31313" + ConfigManagerMock.return_value.get_project_config.return_value = None + user = SimpleNamespace(username="me", creds=SimpleNamespace(token="token")) + LoginServerMock.return_value.get_logged_in_user.return_value = user + exit_code = run_dstack_cli( + [ + "login", + "--url", + "http://127.0.0.1:31313", + "--provider", + "github", + ], + home_dir=tmp_path, + ) + ConfigManagerMock.return_value.configure_project.assert_has_calls( + [ + call( + name="project1", + url="http://127.0.0.1:31313", + token=user.creds.token, + default=True, + ), + call( + name="project2", + url="http://127.0.0.1:31313", + token=user.creds.token, + default=False, + ), + ] + ) + ConfigManagerMock.return_value.save.assert_called() + + assert exit_code == 0 + assert capsys.readouterr().out.replace("\n", "") == ( + "Your browser has been opened to log in with Github:" + "http://auth_url" + "Logged in as me." + "Configured projects: project1, project2." + "Set project project1 as default project." + ) diff --git a/src/tests/_internal/cli/common.py b/src/tests/_internal/cli/common.py index 8b4a370ea6..09f4541c7e 100644 --- a/src/tests/_internal/cli/common.py +++ b/src/tests/_internal/cli/common.py @@ -7,7 +7,7 @@ def run_dstack_cli( - args: List[str], + cli_args: List[str], home_dir: Optional[Path] = None, repo_dir: Optional[Path] = None, ) -> int: @@ -18,13 +18,14 @@ def run_dstack_cli( if home_dir is not None: prev_home_dir = os.environ["HOME"] os.environ["HOME"] = str(home_dir) - with patch("sys.argv", ["dstack"] + args): + with patch("sys.argv", ["dstack"] + cli_args): try: main() except SystemExit as e: exit_code = e.code - if home_dir is not None: - os.environ["HOME"] = prev_home_dir - if repo_dir is not None: - os.chdir(cwd) + finally: + if home_dir is not None: + os.environ["HOME"] = prev_home_dir + if repo_dir is not None: + os.chdir(cwd) return exit_code diff --git a/src/tests/_internal/server/routers/test_auth.py b/src/tests/_internal/server/routers/test_auth.py new file mode 100644 index 0000000000..f4c8bb0e59 --- /dev/null +++ b/src/tests/_internal/server/routers/test_auth.py @@ -0,0 +1,64 @@ +import json +from base64 import b64encode + +import pytest +from httpx import AsyncClient + +from dstack._internal.core.models.auth import OAuthProviderInfo +from dstack._internal.server.services.auth import register_provider + + +class TestListProviders: + @pytest.mark.asyncio + async def test_returns_no_providers(self, client: AsyncClient): + response = await client.post("/api/auth/list_providers") + assert response.status_code == 200 + assert response.json() == [] + + @pytest.mark.asyncio + async def test_returns_registered_providers(self, client: AsyncClient): + register_provider(OAuthProviderInfo(name="provider1", enabled=True)) + register_provider(OAuthProviderInfo(name="provider2", enabled=False)) + response = await client.post("/api/auth/list_providers") + assert response.status_code == 200 + assert response.json() == [ + { + "name": "provider1", + "enabled": True, + }, + { + "name": "provider2", + "enabled": False, + }, + ] + + +class TestGetNextRedirectURL: + @pytest.mark.asyncio + async def test_returns_no_redirect_url_if_local_port_not_set(self, client: AsyncClient): + state = b64encode(json.dumps({"value": "12356", "local_port": None}).encode()).decode() + response = await client.post( + "/api/auth/get_next_redirect", json={"code": "1234", "state": state} + ) + assert response.status_code == 200 + assert response.json() == {"redirect_url": None} + + @pytest.mark.asyncio + async def test_returns_redirect_url_if_local_port_set(self, client: AsyncClient): + state = b64encode(json.dumps({"value": "12356", "local_port": 12345}).encode()).decode() + response = await client.post( + "/api/auth/get_next_redirect", json={"code": "1234", "state": state} + ) + assert response.status_code == 200 + assert response.json() == { + "redirect_url": f"http://localhost:12345/auth/callback?code=1234&state={state}" + } + + @pytest.mark.asyncio + async def test_returns_400_if_state_invalid(self, client: AsyncClient): + state = "some_invalid_state" + response = await client.post( + "/api/auth/get_next_redirect", json={"code": "1234", "state": state} + ) + assert response.status_code == 400 + assert "Invalid state token" in response.json()["detail"][0]["msg"]