diff --git a/docs/docs/reference/cli/dstack/login.md b/docs/docs/reference/cli/dstack/login.md
new file mode 100644
index 0000000000..d608476e27
--- /dev/null
+++ b/docs/docs/reference/cli/dstack/login.md
@@ -0,0 +1,17 @@
+# dstack login
+
+This command authorizes the CLI using Single Sign-On and automatically configures your projects.
+It provides an alternative to `dstack project add`.
+
+## Usage
+
+
+
+```shell
+$ dstack login --help
+#GENERATE#
+```
+
+
+
+[//]: # (TODO: Provide examples)
diff --git a/frontend/src/App/Login/EntraID/LoginByEntraIDCallback/index.tsx b/frontend/src/App/Login/EntraID/LoginByEntraIDCallback/index.tsx
index aa70d00797..036851c3cf 100644
--- a/frontend/src/App/Login/EntraID/LoginByEntraIDCallback/index.tsx
+++ b/frontend/src/App/Login/EntraID/LoginByEntraIDCallback/index.tsx
@@ -7,7 +7,7 @@ import { UnauthorizedLayout } from 'layouts/UnauthorizedLayout';
import { useAppDispatch } from 'hooks';
import { ROUTES } from 'routes';
-import { useEntraCallbackMutation } from 'services/auth';
+import { useEntraCallbackMutation, useGetNextRedirectMutation } from 'services/auth';
import { AuthErrorMessage } from 'App/AuthErrorMessage';
import { getBaseUrl } from 'App/helpers';
@@ -23,15 +23,27 @@ export const LoginByEntraIDCallback: React.FC = () => {
const [isInvalidCode, setIsInvalidCode] = useState(false);
const dispatch = useAppDispatch();
+ const [getNextRedirect] = useGetNextRedirectMutation();
const [entraCallback] = useEntraCallbackMutation();
const checkCode = () => {
if (code && state) {
- entraCallback({ code, state, base_url: getBaseUrl() })
+ getNextRedirect({ code, state })
.unwrap()
- .then(({ creds: { token } }) => {
- dispatch(setAuthData({ token }));
- navigate('/');
+ .then(({ redirect_url }) => {
+ if (redirect_url) {
+ window.location.href = redirect_url;
+ return;
+ }
+ entraCallback({ code, state, base_url: getBaseUrl() })
+ .unwrap()
+ .then(({ creds: { token } }) => {
+ dispatch(setAuthData({ token }));
+ navigate('/');
+ })
+ .catch(() => {
+ setIsInvalidCode(true);
+ });
})
.catch(() => {
setIsInvalidCode(true);
diff --git a/frontend/src/App/Login/LoginByGithubCallback/index.tsx b/frontend/src/App/Login/LoginByGithubCallback/index.tsx
index 27d5a755a7..af88aa72f1 100644
--- a/frontend/src/App/Login/LoginByGithubCallback/index.tsx
+++ b/frontend/src/App/Login/LoginByGithubCallback/index.tsx
@@ -7,7 +7,7 @@ import { UnauthorizedLayout } from 'layouts/UnauthorizedLayout';
import { useAppDispatch } from 'hooks';
import { ROUTES } from 'routes';
-import { useGithubCallbackMutation } from 'services/auth';
+import { useGetNextRedirectMutation, useGithubCallbackMutation } from 'services/auth';
import { useLazyGetProjectsQuery } from 'services/project';
import { AuthErrorMessage } from 'App/AuthErrorMessage';
@@ -23,26 +23,35 @@ export const LoginByGithubCallback: React.FC = () => {
const [isInvalidCode, setIsInvalidCode] = useState(false);
const dispatch = useAppDispatch();
+ const [getNextRedirect] = useGetNextRedirectMutation();
const [githubCallback] = useGithubCallbackMutation();
const [getProjects] = useLazyGetProjectsQuery();
const checkCode = () => {
if (code && state) {
- githubCallback({ code, state })
+ getNextRedirect({ code: code, state: state })
.unwrap()
- .then(async ({ creds: { token } }) => {
- dispatch(setAuthData({ token }));
-
- if (process.env.UI_VERSION === 'sky') {
- const result = await getProjects().unwrap();
-
- if (result?.length === 0) {
- navigate(ROUTES.PROJECT.ADD);
- return;
- }
+ .then(async ({ redirect_url }) => {
+ if (redirect_url) {
+ window.location.href = redirect_url;
+ return;
}
-
- navigate('/');
+ githubCallback({ code, state })
+ .unwrap()
+ .then(async ({ creds: { token } }) => {
+ dispatch(setAuthData({ token }));
+ if (process.env.UI_VERSION === 'sky') {
+ const result = await getProjects().unwrap();
+ if (result?.length === 0) {
+ navigate(ROUTES.PROJECT.ADD);
+ return;
+ }
+ }
+ navigate('/');
+ })
+ .catch(() => {
+ setIsInvalidCode(true);
+ });
})
.catch(() => {
setIsInvalidCode(true);
diff --git a/frontend/src/App/Login/LoginByGoogleCallback/index.tsx b/frontend/src/App/Login/LoginByGoogleCallback/index.tsx
index 465d0be3ee..4f95f94e27 100644
--- a/frontend/src/App/Login/LoginByGoogleCallback/index.tsx
+++ b/frontend/src/App/Login/LoginByGoogleCallback/index.tsx
@@ -7,7 +7,7 @@ import { UnauthorizedLayout } from 'layouts/UnauthorizedLayout';
import { useAppDispatch } from 'hooks';
import { ROUTES } from 'routes';
-import { useGoogleCallbackMutation } from 'services/auth';
+import { useGetNextRedirectMutation, useGoogleCallbackMutation } from 'services/auth';
import { AuthErrorMessage } from 'App/AuthErrorMessage';
import { Loading } from 'App/Loading';
@@ -22,15 +22,27 @@ export const LoginByGoogleCallback: React.FC = () => {
const [isInvalidCode, setIsInvalidCode] = useState(false);
const dispatch = useAppDispatch();
+ const [getNextRedirect] = useGetNextRedirectMutation();
const [googleCallback] = useGoogleCallbackMutation();
const checkCode = () => {
if (code && state) {
- googleCallback({ code, state })
+ getNextRedirect({ code, state })
.unwrap()
- .then(({ creds: { token } }) => {
- dispatch(setAuthData({ token }));
- navigate('/');
+ .then(({ redirect_url }) => {
+ if (redirect_url) {
+ window.location.href = redirect_url;
+ return;
+ }
+ googleCallback({ code, state })
+ .unwrap()
+ .then(({ creds: { token } }) => {
+ dispatch(setAuthData({ token }));
+ navigate('/');
+ })
+ .catch(() => {
+ setIsInvalidCode(true);
+ });
})
.catch(() => {
setIsInvalidCode(true);
diff --git a/frontend/src/App/Login/LoginByOktaCallback/index.tsx b/frontend/src/App/Login/LoginByOktaCallback/index.tsx
index ccc9fbc749..72cdc96185 100644
--- a/frontend/src/App/Login/LoginByOktaCallback/index.tsx
+++ b/frontend/src/App/Login/LoginByOktaCallback/index.tsx
@@ -7,7 +7,7 @@ import { UnauthorizedLayout } from 'layouts/UnauthorizedLayout';
import { useAppDispatch } from 'hooks';
import { ROUTES } from 'routes';
-import { useOktaCallbackMutation } from 'services/auth';
+import { useGetNextRedirectMutation, useOktaCallbackMutation } from 'services/auth';
import { AuthErrorMessage } from 'App/AuthErrorMessage';
import { Loading } from 'App/Loading';
@@ -22,15 +22,27 @@ export const LoginByOktaCallback: React.FC = () => {
const [isInvalidCode, setIsInvalidCode] = useState(false);
const dispatch = useAppDispatch();
+ const [getNextRedirect] = useGetNextRedirectMutation();
const [oktaCallback] = useOktaCallbackMutation();
const checkCode = () => {
if (code && state) {
- oktaCallback({ code, state })
+ getNextRedirect({ code, state })
.unwrap()
- .then(({ creds: { token } }) => {
- dispatch(setAuthData({ token }));
- navigate('/');
+ .then(({ redirect_url }) => {
+ if (redirect_url) {
+ window.location.href = redirect_url;
+ return;
+ }
+ oktaCallback({ code, state })
+ .unwrap()
+ .then(({ creds: { token } }) => {
+ dispatch(setAuthData({ token }));
+ navigate('/');
+ })
+ .catch(() => {
+ setIsInvalidCode(true);
+ });
})
.catch(() => {
setIsInvalidCode(true);
diff --git a/frontend/src/api.ts b/frontend/src/api.ts
index 2dea526601..262aa46b75 100644
--- a/frontend/src/api.ts
+++ b/frontend/src/api.ts
@@ -5,6 +5,7 @@ export const API = {
AUTH: {
BASE: () => `${API.BASE()}/auth`,
+ NEXT_REDIRECT: () => `${API.AUTH.BASE()}/get_next_redirect`,
GITHUB: {
BASE: () => `${API.AUTH.BASE()}/github`,
AUTHORIZE: () => `${API.AUTH.GITHUB.BASE()}/authorize`,
diff --git a/frontend/src/services/auth.ts b/frontend/src/services/auth.ts
index f65892911a..2512ed0a7d 100644
--- a/frontend/src/services/auth.ts
+++ b/frontend/src/services/auth.ts
@@ -12,6 +12,14 @@ export const authApi = createApi({
tagTypes: ['Auth'],
endpoints: (builder) => ({
+ getNextRedirect: builder.mutation<{ redirect_url?: string }, { code: string; state: string }>({
+ query: (body) => ({
+ url: API.AUTH.NEXT_REDIRECT(),
+ method: 'POST',
+ body,
+ }),
+ }),
+
githubAuthorize: builder.mutation<{ authorization_url: string }, void>({
query: () => ({
url: API.AUTH.GITHUB.AUTHORIZE(),
@@ -103,6 +111,7 @@ export const authApi = createApi({
});
export const {
+ useGetNextRedirectMutation,
useGithubAuthorizeMutation,
useGithubCallbackMutation,
useGetOktaInfoQuery,
diff --git a/mkdocs.yml b/mkdocs.yml
index e793bd23c2..74939703e3 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -112,67 +112,67 @@ plugins:
background_color: "black"
color: "#FFFFFF"
font_family: "Roboto"
-# debug: true
+ # debug: true
cards_layout_dir: docs/layouts
cards_layout: custom
- search
- redirects:
redirect_maps:
- 'blog/2024/02/08/resources-authentication-and-more.md': 'https://github.com/dstackai/dstack/releases/0.15.0'
- 'blog/2024/01/19/openai-endpoints-preview.md': 'https://github.com/dstackai/dstack/releases/0.14.0'
- 'blog/2023/12/22/disk-size-cuda-12-1-mixtral-and-more.md': 'https://github.com/dstackai/dstack/releases/0.13.0'
- 'blog/2023/11/21/vastai.md': 'https://github.com/dstackai/dstack/releases/0.12.3'
- 'blog/2023/10/31/tensordock.md': 'https://github.com/dstackai/dstack/releases/0.12.2'
- 'blog/2023/10/18/simplified-cloud-setup.md': 'https://github.com/dstackai/dstack/releases/0.12.0'
- 'blog/2023/08/22/multiple-clouds.md': 'https://github.com/dstackai/dstack/releases/0.11'
- 'blog/2023/08/07/services-preview.md': 'https://github.com/dstackai/dstack/releases/0.10.7'
- 'blog/2023/07/14/lambda-cloud-ga-and-docker-support.md': 'https://github.com/dstackai/dstack/releases/0.10.5'
- 'blog/2023/05/22/azure-support-better-ui-and-more.md': 'https://github.com/dstackai/dstack/releases/0.9.1'
- 'blog/2023/03/13/gcp-support-just-landed.md': 'https://github.com/dstackai/dstack/releases/0.2'
- 'blog/dstack-research.md': 'https://dstack.ai/#get-started'
- 'docs/dev-environments.md': 'docs/concepts/dev-environments.md'
- 'docs/tasks.md': 'docs/concepts/tasks.md'
- 'docs/services.md': 'docs/concepts/services.md'
- 'docs/fleets.md': 'docs/concepts/fleets.md'
- 'docs/examples/llms/llama31.md': 'examples/llms/llama/index.md'
- 'docs/examples/llms/llama32.md': 'examples/llms/llama/index.md'
- 'examples/llms/llama31/index.md': 'examples/llms/llama/index.md'
- 'examples/llms/llama32/index.md': 'examples/llms/llama/index.md'
- 'docs/examples/accelerators/amd/index.md': 'examples/accelerators/amd/index.md'
- 'docs/examples/deployment/nim/index.md': 'examples/inference/nim/index.md'
- 'docs/examples/deployment/vllm/index.md': 'examples/inference/vllm/index.md'
- 'docs/examples/deployment/tgi/index.md': 'examples/inference/tgi/index.md'
- 'providers.md': 'partners.md'
- 'backends.md': 'partners.md'
- 'blog/monitoring-gpu-usage.md': 'blog/posts/dstack-metrics.md'
- 'blog/inactive-dev-environments-auto-shutdown.md': 'blog/posts/inactivity-duration.md'
- 'blog/data-centers-and-private-clouds.md': 'blog/posts/gpu-blocks-and-proxy-jump.md'
- 'blog/distributed-training-with-aws-efa.md': 'examples/clusters/aws/index.md'
- 'blog/dstack-stats.md': 'blog/posts/dstack-metrics.md'
- 'docs/concepts/metrics.md': 'docs/guides/metrics.md'
- 'docs/guides/monitoring.md': 'docs/guides/metrics.md'
- 'blog/nvidia-and-amd-on-vultr.md.md': 'blog/posts/nvidia-and-amd-on-vultr.md'
- 'examples/misc/nccl-tests/index.md': 'examples/clusters/nccl-rccl-tests/index.md'
- 'examples/misc/a3high-clusters/index.md': 'examples/clusters/gcp/index.md'
- 'examples/misc/a3mega-clusters/index.md': 'examples/clusters/gcp/index.md'
- 'examples/distributed-training/nccl-tests/index.md': 'examples/clusters/nccl-rccl-tests/index.md'
- 'examples/distributed-training/rccl-tests/index.md': 'examples/clusters/nccl-rccl-tests/index.md'
- 'examples/deployment/nim/index.md': 'examples/inference/nim/index.md'
- 'examples/deployment/vllm/index.md': 'examples/inference/vllm/index.md'
- 'examples/deployment/tgi/index.md': 'examples/inference/tgi/index.md'
- 'examples/deployment/sglang/index.md': 'examples/inference/sglang/index.md'
- 'examples/deployment/trtllm/index.md': 'examples/inference/trtllm/index.md'
- 'examples/fine-tuning/trl/index.md': 'examples/single-node-training/trl/index.md'
- 'examples/fine-tuning/axolotl/index.md': 'examples/single-node-training/axolotl/index.md'
- 'blog/efa.md': 'examples/clusters/aws/index.md'
- 'docs/concepts/repos.md': 'docs/concepts/dev-environments.md#repos'
- 'examples/clusters/a3high/index.md': 'examples/clusters/gcp/index.md'
- 'examples/clusters/a3mega/index.md': 'examples/clusters/gcp/index.md'
- 'examples/clusters/a4/index.md': 'examples/clusters/gcp/index.md'
- 'examples/clusters/efa/index.md': 'examples/clusters/aws/index.md'
+ "blog/2024/02/08/resources-authentication-and-more.md": "https://github.com/dstackai/dstack/releases/0.15.0"
+ "blog/2024/01/19/openai-endpoints-preview.md": "https://github.com/dstackai/dstack/releases/0.14.0"
+ "blog/2023/12/22/disk-size-cuda-12-1-mixtral-and-more.md": "https://github.com/dstackai/dstack/releases/0.13.0"
+ "blog/2023/11/21/vastai.md": "https://github.com/dstackai/dstack/releases/0.12.3"
+ "blog/2023/10/31/tensordock.md": "https://github.com/dstackai/dstack/releases/0.12.2"
+ "blog/2023/10/18/simplified-cloud-setup.md": "https://github.com/dstackai/dstack/releases/0.12.0"
+ "blog/2023/08/22/multiple-clouds.md": "https://github.com/dstackai/dstack/releases/0.11"
+ "blog/2023/08/07/services-preview.md": "https://github.com/dstackai/dstack/releases/0.10.7"
+ "blog/2023/07/14/lambda-cloud-ga-and-docker-support.md": "https://github.com/dstackai/dstack/releases/0.10.5"
+ "blog/2023/05/22/azure-support-better-ui-and-more.md": "https://github.com/dstackai/dstack/releases/0.9.1"
+ "blog/2023/03/13/gcp-support-just-landed.md": "https://github.com/dstackai/dstack/releases/0.2"
+ "blog/dstack-research.md": "https://dstack.ai/#get-started"
+ "docs/dev-environments.md": "docs/concepts/dev-environments.md"
+ "docs/tasks.md": "docs/concepts/tasks.md"
+ "docs/services.md": "docs/concepts/services.md"
+ "docs/fleets.md": "docs/concepts/fleets.md"
+ "docs/examples/llms/llama31.md": "examples/llms/llama/index.md"
+ "docs/examples/llms/llama32.md": "examples/llms/llama/index.md"
+ "examples/llms/llama31/index.md": "examples/llms/llama/index.md"
+ "examples/llms/llama32/index.md": "examples/llms/llama/index.md"
+ "docs/examples/accelerators/amd/index.md": "examples/accelerators/amd/index.md"
+ "docs/examples/deployment/nim/index.md": "examples/inference/nim/index.md"
+ "docs/examples/deployment/vllm/index.md": "examples/inference/vllm/index.md"
+ "docs/examples/deployment/tgi/index.md": "examples/inference/tgi/index.md"
+ "providers.md": "partners.md"
+ "backends.md": "partners.md"
+ "blog/monitoring-gpu-usage.md": "blog/posts/dstack-metrics.md"
+ "blog/inactive-dev-environments-auto-shutdown.md": "blog/posts/inactivity-duration.md"
+ "blog/data-centers-and-private-clouds.md": "blog/posts/gpu-blocks-and-proxy-jump.md"
+ "blog/distributed-training-with-aws-efa.md": "examples/clusters/aws/index.md"
+ "blog/dstack-stats.md": "blog/posts/dstack-metrics.md"
+ "docs/concepts/metrics.md": "docs/guides/metrics.md"
+ "docs/guides/monitoring.md": "docs/guides/metrics.md"
+ "blog/nvidia-and-amd-on-vultr.md.md": "blog/posts/nvidia-and-amd-on-vultr.md"
+ "examples/misc/nccl-tests/index.md": "examples/clusters/nccl-rccl-tests/index.md"
+ "examples/misc/a3high-clusters/index.md": "examples/clusters/gcp/index.md"
+ "examples/misc/a3mega-clusters/index.md": "examples/clusters/gcp/index.md"
+ "examples/distributed-training/nccl-tests/index.md": "examples/clusters/nccl-rccl-tests/index.md"
+ "examples/distributed-training/rccl-tests/index.md": "examples/clusters/nccl-rccl-tests/index.md"
+ "examples/deployment/nim/index.md": "examples/inference/nim/index.md"
+ "examples/deployment/vllm/index.md": "examples/inference/vllm/index.md"
+ "examples/deployment/tgi/index.md": "examples/inference/tgi/index.md"
+ "examples/deployment/sglang/index.md": "examples/inference/sglang/index.md"
+ "examples/deployment/trtllm/index.md": "examples/inference/trtllm/index.md"
+ "examples/fine-tuning/trl/index.md": "examples/single-node-training/trl/index.md"
+ "examples/fine-tuning/axolotl/index.md": "examples/single-node-training/axolotl/index.md"
+ "blog/efa.md": "examples/clusters/aws/index.md"
+ "docs/concepts/repos.md": "docs/concepts/dev-environments.md#repos"
+ "examples/clusters/a3high/index.md": "examples/clusters/gcp/index.md"
+ "examples/clusters/a3mega/index.md": "examples/clusters/gcp/index.md"
+ "examples/clusters/a4/index.md": "examples/clusters/gcp/index.md"
+ "examples/clusters/efa/index.md": "examples/clusters/aws/index.md"
- typeset
- gen-files:
- scripts: # always relative to mkdocs.yml
+ scripts: # always relative to mkdocs.yml
- scripts/docs/gen_examples.py
- scripts/docs/gen_cli_reference.py
- scripts/docs/gen_openapi_reference.py
@@ -279,70 +279,71 @@ nav:
- Protips: docs/guides/protips.md
- Migration: docs/guides/migration.md
- Reference:
- - .dstack.yml:
- - dev-environment: docs/reference/dstack.yml/dev-environment.md
- - task: docs/reference/dstack.yml/task.md
- - service: docs/reference/dstack.yml/service.md
- - fleet: docs/reference/dstack.yml/fleet.md
- - gateway: docs/reference/dstack.yml/gateway.md
- - volume: docs/reference/dstack.yml/volume.md
- - server/config.yml: docs/reference/server/config.yml.md
- - CLI:
- - dstack server: docs/reference/cli/dstack/server.md
- - dstack init: docs/reference/cli/dstack/init.md
- - dstack apply: docs/reference/cli/dstack/apply.md
- - dstack delete: docs/reference/cli/dstack/delete.md
- - dstack ps: docs/reference/cli/dstack/ps.md
- - dstack stop: docs/reference/cli/dstack/stop.md
- - dstack attach: docs/reference/cli/dstack/attach.md
- - dstack logs: docs/reference/cli/dstack/logs.md
- - dstack metrics: docs/reference/cli/dstack/metrics.md
- - dstack event: docs/reference/cli/dstack/event.md
- - dstack project: docs/reference/cli/dstack/project.md
- - dstack fleet: docs/reference/cli/dstack/fleet.md
- - dstack offer: docs/reference/cli/dstack/offer.md
- - dstack volume: docs/reference/cli/dstack/volume.md
- - dstack gateway: docs/reference/cli/dstack/gateway.md
- - dstack secret: docs/reference/cli/dstack/secret.md
- - API:
- - Python API: docs/reference/api/python/index.md
- - REST API: docs/reference/api/rest/index.md
- - Environment variables: docs/reference/environment-variables.md
- - .dstack/profiles.yml: docs/reference/profiles.yml.md
- - Plugins:
- - Python API: docs/reference/plugins/python/index.md
- - REST API: docs/reference/plugins/rest/index.md
- - llms-full.txt: https://dstack.ai/llms-full.txt
+ - .dstack.yml:
+ - dev-environment: docs/reference/dstack.yml/dev-environment.md
+ - task: docs/reference/dstack.yml/task.md
+ - service: docs/reference/dstack.yml/service.md
+ - fleet: docs/reference/dstack.yml/fleet.md
+ - gateway: docs/reference/dstack.yml/gateway.md
+ - volume: docs/reference/dstack.yml/volume.md
+ - server/config.yml: docs/reference/server/config.yml.md
+ - CLI:
+ - dstack server: docs/reference/cli/dstack/server.md
+ - dstack init: docs/reference/cli/dstack/init.md
+ - dstack apply: docs/reference/cli/dstack/apply.md
+ - dstack delete: docs/reference/cli/dstack/delete.md
+ - dstack ps: docs/reference/cli/dstack/ps.md
+ - dstack stop: docs/reference/cli/dstack/stop.md
+ - dstack attach: docs/reference/cli/dstack/attach.md
+ - dstack login: docs/reference/cli/dstack/login.md
+ - dstack logs: docs/reference/cli/dstack/logs.md
+ - dstack metrics: docs/reference/cli/dstack/metrics.md
+ - dstack event: docs/reference/cli/dstack/event.md
+ - dstack project: docs/reference/cli/dstack/project.md
+ - dstack fleet: docs/reference/cli/dstack/fleet.md
+ - dstack offer: docs/reference/cli/dstack/offer.md
+ - dstack volume: docs/reference/cli/dstack/volume.md
+ - dstack gateway: docs/reference/cli/dstack/gateway.md
+ - dstack secret: docs/reference/cli/dstack/secret.md
+ - API:
+ - Python API: docs/reference/api/python/index.md
+ - REST API: docs/reference/api/rest/index.md
+ - Environment variables: docs/reference/environment-variables.md
+ - .dstack/profiles.yml: docs/reference/profiles.yml.md
+ - Plugins:
+ - Python API: docs/reference/plugins/python/index.md
+ - REST API: docs/reference/plugins/rest/index.md
+ - llms-full.txt: https://dstack.ai/llms-full.txt
- Examples:
- - examples.md
- - Single-node training:
- - TRL: examples/single-node-training/trl/index.md
- - Axolotl: examples/single-node-training/axolotl/index.md
- - Distributed training:
- - TRL: examples/distributed-training/trl/index.md
- - Axolotl: examples/distributed-training/axolotl/index.md
- - Ray+RAGEN: examples/distributed-training/ray-ragen/index.md
- - Clusters:
- - AWS: examples/clusters/aws/index.md
- - GCP: examples/clusters/gcp/index.md
- - Lambda: examples/clusters/lambda/index.md
- - Crusoe: examples/clusters/crusoe/index.md
- - NCCL/RCCL tests: examples/clusters/nccl-rccl-tests/index.md
- - Inference:
- - SGLang: examples/inference/sglang/index.md
- - vLLM: examples/inference/vllm/index.md
- - TGI: examples/inference/tgi/index.md
- - NIM: examples/inference/nim/index.md
- - TensorRT-LLM: examples/inference/trtllm/index.md
- - Accelerators:
- - AMD: examples/accelerators/amd/index.md
- - TPU: examples/accelerators/tpu/index.md
- - Intel Gaudi: examples/accelerators/intel/index.md
- - Tenstorrent: examples/accelerators/tenstorrent/index.md
- - Models:
- - Wan2.2: examples/models/wan22/index.md
- - Blog:
- - blog/index.md
+ - examples.md
+ - Single-node training:
+ - TRL: examples/single-node-training/trl/index.md
+ - Axolotl: examples/single-node-training/axolotl/index.md
+ - Distributed training:
+ - TRL: examples/distributed-training/trl/index.md
+ - Axolotl: examples/distributed-training/axolotl/index.md
+ - Ray+RAGEN: examples/distributed-training/ray-ragen/index.md
+ - Clusters:
+ - AWS: examples/clusters/aws/index.md
+ - GCP: examples/clusters/gcp/index.md
+ - Lambda: examples/clusters/lambda/index.md
+ - Crusoe: examples/clusters/crusoe/index.md
+ - NCCL/RCCL tests: examples/clusters/nccl-rccl-tests/index.md
+ - Inference:
+ - SGLang: examples/inference/sglang/index.md
+ - vLLM: examples/inference/vllm/index.md
+ - TGI: examples/inference/tgi/index.md
+ - NIM: examples/inference/nim/index.md
+ - TensorRT-LLM: examples/inference/trtllm/index.md
+ - Accelerators:
+ - AMD: examples/accelerators/amd/index.md
+ - TPU: examples/accelerators/tpu/index.md
+ - Intel Gaudi: examples/accelerators/intel/index.md
+ - Tenstorrent: examples/accelerators/tenstorrent/index.md
+ - Models:
+ - Wan2.2: examples/models/wan22/index.md
+ - Blog:
+ - blog/index.md
- Case studies: blog/case-studies.md
- Benchmarks: blog/benchmarks.md
# - Discord: https://discord.gg/u8SmfwPpMd" target="_blank
diff --git a/pyproject.toml b/pyproject.toml
index e69ec4d5aa..e540705d93 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,11 +100,12 @@ ignore = [
dev = [
"httpx>=0.28.1",
"pre-commit>=4.2.0",
+ "pytest~=7.2",
"pytest-asyncio>=0.23.8",
"pytest-httpbin>=2.1.0",
- "httpbin>=0.10.2", # indirect to make compatible with Werkzeug 3
- "pytest~=7.2",
"pytest-socket>=0.7.0",
+ "pytest-env>=1.1.0",
+ "httpbin>=0.10.2", # indirect to make compatible with Werkzeug 3
"requests-mock>=1.12.1",
"openai>=1.68.2",
"freezegun>=1.5.1",
diff --git a/pytest.ini b/pytest.ini
index 899f67a61b..30c0e62811 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -8,3 +8,5 @@ addopts =
markers =
shim_version
dockerized
+env =
+ DSTACK_CLI_RICH_FORCE_TERMINAL=0
diff --git a/src/dstack/_internal/cli/commands/login.py b/src/dstack/_internal/cli/commands/login.py
new file mode 100644
index 0000000000..54fdc0a0b6
--- /dev/null
+++ b/src/dstack/_internal/cli/commands/login.py
@@ -0,0 +1,237 @@
+import argparse
+import queue
+import threading
+import urllib.parse
+import webbrowser
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from typing import Optional
+
+from dstack._internal.cli.commands import BaseCommand
+from dstack._internal.cli.utils.common import console
+from dstack._internal.core.errors import ClientError, CLIError
+from dstack._internal.core.models.users import UserWithCreds
+from dstack.api._public.runs import ConfigManager
+from dstack.api.server import APIClient
+
+
+class LoginCommand(BaseCommand):
+ NAME = "login"
+ DESCRIPTION = "Authorize the CLI using Single Sign-On"
+
+ def _register(self):
+ super()._register()
+ self._parser.add_argument(
+ "--url",
+ help="The server URL, e.g. https://sky.dstack.ai",
+ required=True,
+ )
+ self._parser.add_argument(
+ "-p",
+ "--provider",
+ help=(
+ "The SSO provider name."
+ " Selected automatically if the server supports only one provider."
+ ),
+ )
+
+ def _command(self, args: argparse.Namespace):
+ super()._command(args)
+ base_url = _normalize_url_or_error(args.url)
+ api_client = APIClient(base_url=base_url)
+ provider = self._select_provider_or_error(api_client=api_client, provider=args.provider)
+ server = _LoginServer(api_client=api_client, provider=provider)
+ try:
+ server.start()
+ auth_resp = api_client.auth.authorize(provider=provider, local_port=server.port)
+ opened = webbrowser.open(auth_resp.authorization_url)
+ if opened:
+ console.print(
+ f"Your browser has been opened to log in with [code]{provider.title()}[/]:\n"
+ )
+ else:
+ console.print(f"Open the URL to log in with [code]{provider.title()}[/]:\n")
+ print(f"{auth_resp.authorization_url}\n")
+ user = server.get_logged_in_user()
+ finally:
+ server.shutdown()
+ if user is None:
+ raise CLIError("CLI authentication failed")
+ console.print(f"Logged in as [code]{user.username}[/].")
+ api_client = APIClient(base_url=base_url, token=user.creds.token)
+ self._configure_projects(api_client=api_client, user=user)
+
+ def _select_provider_or_error(self, api_client: APIClient, provider: Optional[str]) -> str:
+ providers = api_client.auth.list_providers()
+ available_providers = [p.name for p in providers if p.enabled]
+ if len(available_providers) == 0:
+ raise CLIError("No SSO providers configured on the server.")
+ if provider is None:
+ if len(available_providers) > 1:
+ raise CLIError(
+ "Specify -p/--provider to choose SSO provider"
+ f" Available providers: {', '.join(available_providers)}"
+ )
+ return available_providers[0]
+ if provider not in available_providers:
+ raise CLIError(
+ f"Provider {provider} not configured on the server."
+ f" Available providers: {', '.join(available_providers)}"
+ )
+ return provider
+
+ def _configure_projects(self, api_client: APIClient, user: UserWithCreds):
+ projects = api_client.projects.list(include_not_joined=False)
+ if len(projects) == 0:
+ console.print(
+ "No projects configured."
+ " Create your own project via the UI or contact a project manager to add you to the project."
+ )
+ return
+ config_manager = ConfigManager()
+ default_project = config_manager.get_project_config()
+ new_default_project = None
+ for i, project in enumerate(projects):
+ set_as_default = (
+ default_project is None
+ and i == 0
+ or default_project is not None
+ and default_project.name == project.project_name
+ )
+ if set_as_default:
+ new_default_project = project
+ config_manager.configure_project(
+ name=project.project_name,
+ url=api_client.base_url,
+ token=user.creds.token,
+ default=set_as_default,
+ )
+ config_manager.save()
+ console.print(
+ f"Configured projects: {', '.join(f'[code]{p.project_name}[/]' for p in projects)}."
+ )
+ if new_default_project:
+ console.print(
+ f"Set project [code]{new_default_project.project_name}[/] as default project."
+ )
+
+
+class _BadRequestError(Exception):
+ pass
+
+
+class _LoginServer:
+ def __init__(self, api_client: APIClient, provider: str):
+ self._api_client = api_client
+ self._provider = provider
+ self._result_queue: queue.Queue[Optional[UserWithCreds]] = queue.Queue()
+ # Using built-in HTTP server to avoid extra deps.
+ callback_handler = self._make_callback_handler(
+ result_queue=self._result_queue,
+ api_client=api_client,
+ provider=provider,
+ )
+ self._server = self._create_server(handler=callback_handler)
+
+ def start(self):
+ self._thread = threading.Thread(target=self._server.serve_forever)
+ self._thread.start()
+
+ def shutdown(self):
+ self._server.shutdown()
+
+ def get_logged_in_user(self) -> Optional[UserWithCreds]:
+ return self._result_queue.get()
+
+ @property
+ def port(self) -> int:
+ return self._server.server_port
+
+ def _make_callback_handler(
+ self,
+ result_queue: queue.Queue[Optional[UserWithCreds]],
+ api_client: APIClient,
+ provider: str,
+ ) -> type[BaseHTTPRequestHandler]:
+ class _CallbackHandler(BaseHTTPRequestHandler):
+ def do_GET(self):
+ parsed_path = urllib.parse.urlparse(self.path)
+ if parsed_path.path != "/auth/callback":
+ self.send_response(404)
+ self.end_headers()
+ return
+ try:
+ self._handle_auth_callback(parsed_path)
+ except _BadRequestError as e:
+ self.send_error(400, e.args[0])
+ result_queue.put(None)
+
+ def log_message(self, format: str, *args):
+ # Do not log server requests.
+ pass
+
+ def _handle_auth_callback(self, parsed_path: urllib.parse.ParseResult):
+ try:
+ params = urllib.parse.parse_qs(parsed_path.query, strict_parsing=True)
+ except ValueError:
+ raise _BadRequestError("Bad query params")
+ code = params.get("code", [None])[0]
+ state = params.get("state", [None])[0]
+ if code is None or state is None:
+ raise _BadRequestError("Missing required params")
+ try:
+ user = api_client.auth.callback(provider=provider, code=code, state=state)
+ except ClientError:
+ raise _BadRequestError("Authentication failed")
+ self._send_success_html()
+ result_queue.put(user)
+
+ def _send_success_html(self):
+ body = _SUCCESS_HTML.encode()
+ self.send_response(200)
+ self.send_header("Content-Type", "text/html; charset=utf-8")
+ self.send_header("Content-Length", str(len(body)))
+ self.end_headers()
+ self.wfile.write(body)
+
+ return _CallbackHandler
+
+ def _create_server(self, handler: type[BaseHTTPRequestHandler]) -> HTTPServer:
+ server_address = ("127.0.0.1", 0)
+ server = HTTPServer(server_address, handler)
+ return server
+
+
+def _normalize_url_or_error(url: str) -> str:
+ if not url.startswith("http://") and not url.startswith("https://"):
+ url = "http://" + url
+ parsed = urllib.parse.urlparse(url)
+ if (
+ not parsed.scheme
+ or not parsed.hostname
+ or parsed.path not in ("", "/")
+ or parsed.params
+ or parsed.query
+ or parsed.fragment
+ or (parsed.port is not None and not (1 <= parsed.port <= 65535))
+ ):
+ raise CLIError("Invalid server URL format. Format: --url https://sky.dstack.ai")
+ return url
+
+
+_SUCCESS_HTML = """\
+
+
+
+
+ CLI authenticated
+
+
+
+ dstack CLI authenticated
+ You may close this page.
+
+
+"""
diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py
index 98be45b8d5..61f3967ab7 100644
--- a/src/dstack/_internal/cli/main.py
+++ b/src/dstack/_internal/cli/main.py
@@ -12,6 +12,7 @@
from dstack._internal.cli.commands.fleet import FleetCommand
from dstack._internal.cli.commands.gateway import GatewayCommand
from dstack._internal.cli.commands.init import InitCommand
+from dstack._internal.cli.commands.login import LoginCommand
from dstack._internal.cli.commands.logs import LogsCommand
from dstack._internal.cli.commands.metrics import MetricsCommand
from dstack._internal.cli.commands.offer import OfferCommand
@@ -68,6 +69,7 @@ def main():
GatewayCommand.register(subparsers)
InitCommand.register(subparsers)
OfferCommand.register(subparsers)
+ LoginCommand.register(subparsers)
LogsCommand.register(subparsers)
MetricsCommand.register(subparsers)
ProjectCommand.register(subparsers)
diff --git a/src/dstack/_internal/cli/utils/common.py b/src/dstack/_internal/cli/utils/common.py
index c75f08b81b..87f0687e1b 100644
--- a/src/dstack/_internal/cli/utils/common.py
+++ b/src/dstack/_internal/cli/utils/common.py
@@ -21,7 +21,10 @@
"code": "bold sea_green3",
}
-console = Console(theme=Theme(_colors))
+console = Console(
+ theme=Theme(_colors),
+ force_terminal=settings.CLI_RICH_FORCE_TERMINAL,
+)
LIVE_TABLE_REFRESH_RATE_PER_SEC = 1
diff --git a/src/dstack/_internal/core/models/auth.py b/src/dstack/_internal/core/models/auth.py
new file mode 100644
index 0000000000..f6d09fbc73
--- /dev/null
+++ b/src/dstack/_internal/core/models/auth.py
@@ -0,0 +1,28 @@
+from typing import Annotated, Optional
+
+from pydantic import Field
+
+from dstack._internal.core.models.common import CoreModel
+
+
+class OAuthProviderInfo(CoreModel):
+ name: Annotated[str, Field(description="The OAuth2 provider name.")]
+ enabled: Annotated[
+ bool, Field(description="Whether the provider is configured on the server.")
+ ]
+
+
+class OAuthState(CoreModel):
+ """
+ A struct that the server puts in the OAuth2 state parameter.
+ """
+
+ value: Annotated[str, Field(description="A random string to protect against CSRF.")]
+ local_port: Annotated[
+ Optional[int],
+ Field(
+ description="If specified, the user is redirected to localhost:local_port after the redirect from the provider.",
+ ge=1,
+ le=65535,
+ ),
+ ] = None
diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py
index 9c83bac793..527dd128fe 100644
--- a/src/dstack/_internal/server/app.py
+++ b/src/dstack/_internal/server/app.py
@@ -25,6 +25,7 @@
from dstack._internal.server.background.tasks.process_probes import PROBES_SCHEDULER
from dstack._internal.server.db import get_db, get_session_ctx, migrate
from dstack._internal.server.routers import (
+ auth,
backends,
events,
files,
@@ -210,6 +211,7 @@ def add_no_api_version_check_routes(paths: List[str]):
def register_routes(app: FastAPI, ui: bool = True):
app.include_router(server.router)
app.include_router(users.router)
+ app.include_router(auth.router)
app.include_router(projects.router)
app.include_router(backends.root_router)
app.include_router(backends.project_router)
diff --git a/src/dstack/_internal/server/routers/auth.py b/src/dstack/_internal/server/routers/auth.py
new file mode 100644
index 0000000000..89fe2f57f5
--- /dev/null
+++ b/src/dstack/_internal/server/routers/auth.py
@@ -0,0 +1,34 @@
+from fastapi import APIRouter
+
+from dstack._internal.core.models.auth import OAuthProviderInfo
+from dstack._internal.server.schemas.auth import (
+ OAuthGetNextRedirectRequest,
+ OAuthGetNextRedirectResponse,
+)
+from dstack._internal.server.services import auth as auth_services
+from dstack._internal.server.utils.routers import CustomORJSONResponse
+
+router = APIRouter(prefix="/api/auth", tags=["auth"])
+
+
+@router.post("/list_providers", response_model=list[OAuthProviderInfo])
+async def list_providers():
+ """
+ Returns OAuth2 providers registered on the server.
+ """
+ return CustomORJSONResponse(auth_services.list_providers())
+
+
+@router.post("/get_next_redirect", response_model=OAuthGetNextRedirectResponse)
+async def get_next_redirect(body: OAuthGetNextRedirectRequest):
+ """
+ A helper endpoint that returns the next redirect URL in case the state encodes it.
+ Can be used by the UI after the redirect from the provider
+ to determine if the user needs to be redirected further (CLI login)
+ or the auth callback endpoint needs to be called directly (UI login).
+ """
+ return CustomORJSONResponse(
+ OAuthGetNextRedirectResponse(
+ redirect_url=auth_services.get_next_redirect_url(code=body.code, state=body.state)
+ )
+ )
diff --git a/src/dstack/_internal/server/routers/projects.py b/src/dstack/_internal/server/routers/projects.py
index 56d41b6ca0..d35b9535e8 100644
--- a/src/dstack/_internal/server/routers/projects.py
+++ b/src/dstack/_internal/server/routers/projects.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List, Optional, Tuple
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
@@ -10,6 +10,7 @@
AddProjectMemberRequest,
CreateProjectRequest,
DeleteProjectsRequest,
+ ListProjectsRequest,
RemoveProjectMemberRequest,
SetProjectMembersRequest,
UpdateProjectRequest,
@@ -37,6 +38,7 @@
@router.post("/list", response_model=List[Project])
async def list_projects(
+ body: Optional[ListProjectsRequest] = None,
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
):
@@ -45,8 +47,13 @@ async def list_projects(
`members` and `backends` are always empty - call `/api/projects/{project_name}/get` to retrieve them.
"""
+ if body is None:
+ # For backward compatibility
+ body = ListProjectsRequest()
return CustomORJSONResponse(
- await projects.list_user_accessible_projects(session=session, user=user)
+ await projects.list_user_accessible_projects(
+ session=session, user=user, include_not_joined=body.include_not_joined
+ )
)
diff --git a/src/dstack/_internal/server/schemas/auth.py b/src/dstack/_internal/server/schemas/auth.py
new file mode 100644
index 0000000000..942f1fb388
--- /dev/null
+++ b/src/dstack/_internal/server/schemas/auth.py
@@ -0,0 +1,83 @@
+from typing import Annotated, Optional
+
+from pydantic import Field
+
+from dstack._internal.core.models.common import CoreModel
+
+
+class OAuthInfoResponse(CoreModel):
+ enabled: Annotated[
+ bool, Field(description="Whether the OAuth2 provider is configured on the server.")
+ ]
+
+
+class OAuthAuthorizeRequest(CoreModel):
+ local_port: Annotated[
+ Optional[int],
+ Field(
+ description="If specified, the user is redirected to localhost:local_port after the redirect from the provider.",
+ ge=1,
+ le=65535,
+ ),
+ ] = None
+ base_url: Annotated[
+ Optional[str],
+ Field(
+ description=(
+ "The server base URL used to access the dstack server, e.g. `http://localhost:3000`."
+ " Used to build redirect URLs when the dstack server is available on multiple domains."
+ )
+ ),
+ ] = None
+
+
+class OAuthAuthorizeResponse(CoreModel):
+ authorization_url: Annotated[str, Field(description="An OAuth2 authorization URL.")]
+
+
+class OAuthCallbackRequest(CoreModel):
+ code: Annotated[
+ str,
+ Field(
+ description="The OAuth2 authorization code received from the provider in the redirect URL."
+ ),
+ ]
+ state: Annotated[
+ str,
+ Field(description="The state parameter received from the provider in the redirect URL."),
+ ]
+ base_url: Annotated[
+ Optional[str],
+ Field(
+ description=(
+ "The server base URL used to access the dstack server, e.g. `http://localhost:3000`."
+ " Used to build redirect URLs when the dstack server is available on multiple domains."
+ " It must match the base URL specified when generating the authorization URL."
+ )
+ ),
+ ] = None
+
+
+class OAuthGetNextRedirectRequest(CoreModel):
+ code: Annotated[
+ str,
+ Field(
+ description="The OAuth2 authorization code received from the provider in the redirect URL."
+ ),
+ ]
+ state: Annotated[
+ str,
+ Field(description="The state parameter received from the provider in the redirect URL."),
+ ]
+
+
+class OAuthGetNextRedirectResponse(CoreModel):
+ redirect_url: Annotated[
+ Optional[str],
+ Field(
+ description=(
+ "The URL that the user needs to be redirected to."
+ " If `null`, there is no next redirect."
+ )
+ ),
+ ]
diff --git a/src/dstack/_internal/server/schemas/projects.py b/src/dstack/_internal/server/schemas/projects.py
index 355bb3a770..ec05c1fb47 100644
--- a/src/dstack/_internal/server/schemas/projects.py
+++ b/src/dstack/_internal/server/schemas/projects.py
@@ -6,6 +6,12 @@
from dstack._internal.core.models.users import ProjectRole
+class ListProjectsRequest(CoreModel):
+ include_not_joined: Annotated[
+ bool, Field(description="Include public projects where user is not a member")
+ ] = True
+
+
class CreateProjectRequest(CoreModel):
project_name: str
is_public: bool = False
diff --git a/src/dstack/_internal/server/services/auth.py b/src/dstack/_internal/server/services/auth.py
new file mode 100644
index 0000000000..8ea40994f3
--- /dev/null
+++ b/src/dstack/_internal/server/services/auth.py
@@ -0,0 +1,77 @@
+import secrets
+import urllib.parse
+from base64 import b64decode, b64encode
+from typing import Optional
+
+from fastapi import Request, Response
+
+from dstack._internal.core.errors import ServerClientError
+from dstack._internal.core.models.auth import OAuthProviderInfo, OAuthState
+from dstack._internal.server import settings
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+_OAUTH_STATE_COOKIE_KEY = "oauth-state"
+
+_OAUTH_PROVIDERS: list[OAuthProviderInfo] = []
+
+
+def register_provider(provider_info: OAuthProviderInfo):
+ """
+ Registers an OAuth2 provider supported on the server.
+ If the provider is supported but not configured, it should be registered with `enabled=False`.
+ The provider must register endpoints `/api/auth/{provider}/authorize` and `/api/auth/{provider}/callback`
+ as defined by the client (see `dstack.api.server._auth.AuthAPIClient`).
+ """
+ _OAUTH_PROVIDERS.append(provider_info)
+
+
+def list_providers() -> list[OAuthProviderInfo]:
+ return _OAUTH_PROVIDERS
+
+
+def generate_oauth_state(local_port: Optional[int] = None) -> str:
+ value = str(secrets.token_hex(16))
+ state = OAuthState(value=value, local_port=local_port)
+ return b64encode(state.json().encode()).decode()
+
+
+def set_state_cookie(response: Response, state: str):
+ response.set_cookie(
+ key=_OAUTH_STATE_COOKIE_KEY,
+ value=state,
+ secure=settings.SERVER_URL.startswith("https://"),
+ samesite="strict",
+ httponly=True,
+ )
+
+
+def get_validated_state(request: Request, state: str) -> OAuthState:
+ state_cookie = request.cookies.get(_OAUTH_STATE_COOKIE_KEY)
+ if state != state_cookie:
+ raise ServerClientError("Invalid state token")
+ decoded_state = _decode_state(state)
+ if decoded_state is None:
+ raise ServerClientError("Invalid state token")
+ return decoded_state
+
+
+def get_next_redirect_url(code: str, state: str) -> Optional[str]:
+ decoded_state = _decode_state(state)
+ if decoded_state is None:
+ raise ServerClientError("Invalid state token")
+ if decoded_state.local_port is None:
+ return None
+ params = {"code": code, "state": state}
+ redirect_url = f"http://localhost:{decoded_state.local_port}/auth/callback?{urllib.parse.urlencode(params)}"
+ return redirect_url
+
+
+def _decode_state(state: str) -> Optional[OAuthState]:
+ try:
+ return OAuthState.parse_raw(b64decode(state, validate=True).decode())
+ except Exception as e:
+ logger.debug("Exception when decoding OAuth2 state parameter: %s", repr(e))
+ return None
diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py
index 5e4842df56..3ef6c32785 100644
--- a/src/dstack/_internal/server/services/projects.py
+++ b/src/dstack/_internal/server/services/projects.py
@@ -83,18 +83,22 @@ async def list_user_projects(
async def list_user_accessible_projects(
session: AsyncSession,
user: UserModel,
+ include_not_joined: bool,
) -> List[Project]:
"""
Returns all projects accessible to the user:
- Projects where user is a member (public or private)
- - Public projects where user is NOT a member
+ - if `include_not_joined`: Public projects where user is NOT a member
"""
if user.global_role == GlobalRole.ADMIN:
projects = await list_project_models(session=session)
else:
- member_projects = await list_member_project_models(session=session, user=user)
- public_projects = await list_public_non_member_project_models(session=session, user=user)
- projects = member_projects + public_projects
+ projects = await list_member_project_models(session=session, user=user)
+ if include_not_joined:
+ public_projects = await list_public_non_member_project_models(
+ session=session, user=user
+ )
+ projects += public_projects
projects = sorted(projects, key=lambda p: p.created_at)
return [
diff --git a/src/dstack/_internal/settings.py b/src/dstack/_internal/settings.py
index 81682480a2..6089e37c07 100644
--- a/src/dstack/_internal/settings.py
+++ b/src/dstack/_internal/settings.py
@@ -1,6 +1,7 @@
import os
from dstack import version
+from dstack._internal.utils.env import environ
from dstack._internal.utils.version import parse_version
DSTACK_VERSION = os.getenv("DSTACK_VERSION", version.__version__)
@@ -28,6 +29,8 @@
CLI_LOG_LEVEL = os.getenv("DSTACK_CLI_LOG_LEVEL", "INFO").upper()
CLI_FILE_LOG_LEVEL = os.getenv("DSTACK_CLI_FILE_LOG_LEVEL", "DEBUG").upper()
+# Can be used to disable control characters (e.g. for testing).
+CLI_RICH_FORCE_TERMINAL = environ.get_bool("DSTACK_CLI_RICH_FORCE_TERMINAL")
# Development settings
diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py
index 2ad94f0864..5d6ea08604 100644
--- a/src/dstack/api/server/__init__.py
+++ b/src/dstack/api/server/__init__.py
@@ -14,6 +14,7 @@
URLNotFoundError,
)
from dstack._internal.utils.logging import get_logger
+from dstack.api.server._auth import AuthAPIClient
from dstack.api.server._backends import BackendsAPIClient
from dstack.api.server._events import EventsAPIClient
from dstack.api.server._files import FilesAPIClient
@@ -52,16 +53,18 @@ class APIClient:
files: operations with files
"""
- def __init__(self, base_url: str, token: str):
+ def __init__(self, base_url: str, token: Optional[str] = None):
"""
Args:
base_url: The API endpoints prefix, e.g. `http://127.0.0.1:3000/`.
token: The API token.
"""
self._base_url = base_url.rstrip("/")
- self._token = token
self._s = requests.session()
- self._s.headers.update({"Authorization": f"Bearer {token}"})
+ self._token = None
+ if token is not None:
+ self._token = token
+ self._s.headers.update({"Authorization": f"Bearer {token}"})
client_api_version = os.getenv("DSTACK_CLIENT_API_VERSION", version.__version__)
if client_api_version is not None:
self._s.headers.update({"X-API-VERSION": client_api_version})
@@ -71,6 +74,10 @@ def __init__(self, base_url: str, token: str):
def base_url(self) -> str:
return self._base_url
+ @property
+ def auth(self) -> AuthAPIClient:
+ return AuthAPIClient(self._request, self._logger)
+
@property
def users(self) -> UsersAPIClient:
return UsersAPIClient(self._request, self._logger)
@@ -128,6 +135,8 @@ def events(self) -> EventsAPIClient:
return EventsAPIClient(self._request, self._logger)
def get_token_hash(self) -> str:
+ if self._token is None:
+ raise ValueError("Token not set")
return hashlib.sha1(self._token.encode()).hexdigest()[:8]
def _request(
diff --git a/src/dstack/api/server/_auth.py b/src/dstack/api/server/_auth.py
new file mode 100644
index 0000000000..b944a292a2
--- /dev/null
+++ b/src/dstack/api/server/_auth.py
@@ -0,0 +1,30 @@
+from typing import Optional
+
+from pydantic import parse_obj_as
+
+from dstack._internal.core.models.auth import OAuthProviderInfo
+from dstack._internal.core.models.users import UserWithCreds
+from dstack._internal.server.schemas.auth import (
+ OAuthAuthorizeRequest,
+ OAuthAuthorizeResponse,
+ OAuthCallbackRequest,
+)
+from dstack.api.server._group import APIClientGroup
+
+
+class AuthAPIClient(APIClientGroup):
+ def list_providers(self) -> list[OAuthProviderInfo]:
+ resp = self._request("/api/auth/list_providers")
+ return parse_obj_as(list[OAuthProviderInfo.__response__], resp.json())
+
+ def authorize(self, provider: str, local_port: Optional[int] = None) -> OAuthAuthorizeResponse:
+ body = OAuthAuthorizeRequest(local_port=local_port)
+ resp = self._request(f"/api/auth/{provider}/authorize", body=body.json())
+ return parse_obj_as(OAuthAuthorizeResponse.__response__, resp.json())
+
+ def callback(
+ self, provider: str, code: str, state: str, base_url: Optional[str] = None
+ ) -> UserWithCreds:
+ body = OAuthCallbackRequest(code=code, state=state, base_url=base_url)
+ resp = self._request(f"/api/auth/{provider}/callback", body=body.json())
+ return parse_obj_as(UserWithCreds.__response__, resp.json())
diff --git a/src/dstack/api/server/_projects.py b/src/dstack/api/server/_projects.py
index 0fb47c9ab5..31bdc3b2de 100644
--- a/src/dstack/api/server/_projects.py
+++ b/src/dstack/api/server/_projects.py
@@ -8,6 +8,7 @@
AddProjectMemberRequest,
CreateProjectRequest,
DeleteProjectsRequest,
+ ListProjectsRequest,
MemberSetting,
RemoveProjectMemberRequest,
SetProjectMembersRequest,
@@ -16,8 +17,9 @@
class ProjectsAPIClient(APIClientGroup):
- def list(self) -> List[Project]:
- resp = self._request("/api/projects/list")
+ def list(self, include_not_joined: bool = True) -> List[Project]:
+ body = ListProjectsRequest(include_not_joined=include_not_joined)
+ resp = self._request("/api/projects/list", body=body.json())
return parse_obj_as(List[Project.__response__], resp.json())
def create(self, project_name: str, is_public: bool = False) -> Project:
diff --git a/src/tests/_internal/cli/commands/test_login.py b/src/tests/_internal/cli/commands/test_login.py
new file mode 100644
index 0000000000..42b46c2b73
--- /dev/null
+++ b/src/tests/_internal/cli/commands/test_login.py
@@ -0,0 +1,103 @@
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import call, patch
+
+from pytest import CaptureFixture
+
+from tests._internal.cli.common import run_dstack_cli
+
+
+class TestLogin:
+ def test_login_no_projects(self, capsys: CaptureFixture, tmp_path: Path):
+ with (
+ patch("dstack._internal.cli.commands.login.webbrowser") as webbrowser_mock,
+ patch("dstack._internal.cli.commands.login.APIClient") as APIClientMock,
+ patch("dstack._internal.cli.commands.login._LoginServer") as LoginServerMock,
+ ):
+ webbrowser_mock.open.return_value = True
+ APIClientMock.return_value.auth.list_providers.return_value = [
+ SimpleNamespace(name="github", enabled=True)
+ ]
+ APIClientMock.return_value.auth.authorize.return_value = SimpleNamespace(
+ authorization_url="http://auth_url"
+ )
+ APIClientMock.return_value.projects.list.return_value = []
+ user = SimpleNamespace(username="me", creds=SimpleNamespace(token="token"))
+ LoginServerMock.return_value.get_logged_in_user.return_value = user
+ exit_code = run_dstack_cli(
+ [
+ "login",
+ "--url",
+ "http://127.0.0.1:31313",
+ "--provider",
+ "github",
+ ],
+ home_dir=tmp_path,
+ )
+
+ assert exit_code == 0
+ assert capsys.readouterr().out.replace("\n", "") == (
+ "Your browser has been opened to log in with Github:"
+ "http://auth_url"
+ "Logged in as me."
+ "No projects configured. Create your own project via the UI or contact a project manager to add you to the project."
+ )
+
+ def test_login_configures_projects(self, capsys: CaptureFixture, tmp_path: Path):
+ with (
+ patch("dstack._internal.cli.commands.login.webbrowser") as webbrowser_mock,
+ patch("dstack._internal.cli.commands.login.APIClient") as APIClientMock,
+ patch("dstack._internal.cli.commands.login.ConfigManager") as ConfigManagerMock,
+ patch("dstack._internal.cli.commands.login._LoginServer") as LoginServerMock,
+ ):
+ webbrowser_mock.open.return_value = True
+ APIClientMock.return_value.auth.list_providers.return_value = [
+ SimpleNamespace(name="github", enabled=True)
+ ]
+ APIClientMock.return_value.auth.authorize.return_value = SimpleNamespace(
+ authorization_url="http://auth_url"
+ )
+ APIClientMock.return_value.projects.list.return_value = [
+ SimpleNamespace(project_name="project1"),
+ SimpleNamespace(project_name="project2"),
+ ]
+ APIClientMock.return_value.base_url = "http://127.0.0.1:31313"
+ ConfigManagerMock.return_value.get_project_config.return_value = None
+ user = SimpleNamespace(username="me", creds=SimpleNamespace(token="token"))
+ LoginServerMock.return_value.get_logged_in_user.return_value = user
+ exit_code = run_dstack_cli(
+ [
+ "login",
+ "--url",
+ "http://127.0.0.1:31313",
+ "--provider",
+ "github",
+ ],
+ home_dir=tmp_path,
+ )
+ ConfigManagerMock.return_value.configure_project.assert_has_calls(
+ [
+ call(
+ name="project1",
+ url="http://127.0.0.1:31313",
+ token=user.creds.token,
+ default=True,
+ ),
+ call(
+ name="project2",
+ url="http://127.0.0.1:31313",
+ token=user.creds.token,
+ default=False,
+ ),
+ ]
+ )
+ ConfigManagerMock.return_value.save.assert_called()
+
+ assert exit_code == 0
+ assert capsys.readouterr().out.replace("\n", "") == (
+ "Your browser has been opened to log in with Github:"
+ "http://auth_url"
+ "Logged in as me."
+ "Configured projects: project1, project2."
+ "Set project project1 as default project."
+ )
diff --git a/src/tests/_internal/cli/common.py b/src/tests/_internal/cli/common.py
index 8b4a370ea6..09f4541c7e 100644
--- a/src/tests/_internal/cli/common.py
+++ b/src/tests/_internal/cli/common.py
@@ -7,7 +7,7 @@
def run_dstack_cli(
- args: List[str],
+ cli_args: List[str],
home_dir: Optional[Path] = None,
repo_dir: Optional[Path] = None,
) -> int:
@@ -18,13 +18,14 @@ def run_dstack_cli(
if home_dir is not None:
prev_home_dir = os.environ["HOME"]
os.environ["HOME"] = str(home_dir)
- with patch("sys.argv", ["dstack"] + args):
+ with patch("sys.argv", ["dstack"] + cli_args):
try:
main()
except SystemExit as e:
exit_code = e.code
- if home_dir is not None:
- os.environ["HOME"] = prev_home_dir
- if repo_dir is not None:
- os.chdir(cwd)
+ finally:
+ if home_dir is not None:
+ os.environ["HOME"] = prev_home_dir
+ if repo_dir is not None:
+ os.chdir(cwd)
return exit_code
diff --git a/src/tests/_internal/server/routers/test_auth.py b/src/tests/_internal/server/routers/test_auth.py
new file mode 100644
index 0000000000..f4c8bb0e59
--- /dev/null
+++ b/src/tests/_internal/server/routers/test_auth.py
@@ -0,0 +1,64 @@
+import json
+from base64 import b64encode
+
+import pytest
+from httpx import AsyncClient
+
+from dstack._internal.core.models.auth import OAuthProviderInfo
+from dstack._internal.server.services.auth import register_provider
+
+
+class TestListProviders:
+ @pytest.mark.asyncio
+ async def test_returns_no_providers(self, client: AsyncClient):
+ response = await client.post("/api/auth/list_providers")
+ assert response.status_code == 200
+ assert response.json() == []
+
+ @pytest.mark.asyncio
+ async def test_returns_registered_providers(self, client: AsyncClient):
+ register_provider(OAuthProviderInfo(name="provider1", enabled=True))
+ register_provider(OAuthProviderInfo(name="provider2", enabled=False))
+ response = await client.post("/api/auth/list_providers")
+ assert response.status_code == 200
+ assert response.json() == [
+ {
+ "name": "provider1",
+ "enabled": True,
+ },
+ {
+ "name": "provider2",
+ "enabled": False,
+ },
+ ]
+
+
+class TestGetNextRedirectURL:
+ @pytest.mark.asyncio
+ async def test_returns_no_redirect_url_if_local_port_not_set(self, client: AsyncClient):
+ state = b64encode(json.dumps({"value": "12356", "local_port": None}).encode()).decode()
+ response = await client.post(
+ "/api/auth/get_next_redirect", json={"code": "1234", "state": state}
+ )
+ assert response.status_code == 200
+ assert response.json() == {"redirect_url": None}
+
+ @pytest.mark.asyncio
+ async def test_returns_redirect_url_if_local_port_set(self, client: AsyncClient):
+ state = b64encode(json.dumps({"value": "12356", "local_port": 12345}).encode()).decode()
+ response = await client.post(
+ "/api/auth/get_next_redirect", json={"code": "1234", "state": state}
+ )
+ assert response.status_code == 200
+ assert response.json() == {
+ "redirect_url": f"http://localhost:12345/auth/callback?code=1234&state={state}"
+ }
+
+ @pytest.mark.asyncio
+ async def test_returns_400_if_state_invalid(self, client: AsyncClient):
+ state = "some_invalid_state"
+ response = await client.post(
+ "/api/auth/get_next_redirect", json={"code": "1234", "state": state}
+ )
+ assert response.status_code == 400
+ assert "Invalid state token" in response.json()["detail"][0]["msg"]