Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (
"github.com/databricks/cli/libs/auth"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/spf13/cobra"
)

Expand All @@ -28,6 +28,8 @@ func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, arg
return nil
}

const minimalDbConnectVersion = "13.1"

func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
cmd := &cobra.Command{
Use: "login [HOST]",
Expand Down Expand Up @@ -95,19 +97,12 @@ func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
return err
}
ctx := cmd.Context()

promptSpinner := cmdio.Spinner(ctx)
promptSpinner <- "Loading list of clusters to select from"
names, err := w.Clusters.ClusterDetailsClusterNameToClusterIdMap(ctx, compute.ListClustersRequest{})
close(promptSpinner)
if err != nil {
return fmt.Errorf("failed to load clusters list. Original error: %w", err)
}
clusterId, err := cmdio.Select(ctx, names, "Choose cluster")
clusterID, err := cfgpickers.AskForCluster(ctx, w,
cfgpickers.WithDatabricksConnect(minimalDbConnectVersion))
if err != nil {
return err
}
cfg.ClusterID = clusterId
cfg.ClusterID = clusterID
}

if profileName != "" {
Expand Down
192 changes: 192 additions & 0 deletions libs/databrickscfg/cfgpickers/clusters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package cfgpickers

import (
"context"
"errors"
"fmt"
"regexp"
"strings"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/fatih/color"
"github.com/manifoldco/promptui"
"golang.org/x/mod/semver"
)

// minUcRuntime is the oldest Databricks Runtime with Unity Catalog support;
// clusters below it (and minimum-version requests at or below it) are
// rejected by IsCompatibleWithUC.
var minUcRuntime = canonicalVersion("v12.0")

// dbrVersionRegex captures the "major.minor" prefix of a released runtime
// version string, e.g. "13.2.x-scala2.12" -> "13.2".
var dbrVersionRegex = regexp.MustCompile(`^(\d+\.\d+)\.x-.*`)

// dbrSnapshotVersionRegex captures the major component of snapshot builds,
// e.g. "14.x-snapshot-cpu-ml-scala2.12" -> "14".
var dbrSnapshotVersionRegex = regexp.MustCompile(`^(\d+)\.x-snapshot.*`)

// canonicalVersion normalizes v into canonical semver form with a leading
// "v" (e.g. "13.1" -> "v13.1.0"), whether or not v already has the prefix.
func canonicalVersion(v string) string {
	if !strings.HasPrefix(v, "v") {
		v = "v" + v
	}
	return semver.Canonical(v)
}

// GetRuntimeVersion extracts the numeric Databricks Runtime version from
// the cluster's Spark version string. Snapshot builds ("14.x-snapshot-...")
// are reported as "<major>.999" so semver.Compare() orders them after every
// released runtime of that major line. The boolean is false when the
// version string matches neither pattern.
func GetRuntimeVersion(cluster compute.ClusterDetails) (string, bool) {
	if m := dbrVersionRegex.FindStringSubmatch(cluster.SparkVersion); len(m) > 1 {
		return m[1], true
	}
	if m := dbrSnapshotVersionRegex.FindStringSubmatch(cluster.SparkVersion); len(m) > 1 {
		// we return 14.999 for 14.x-snapshot for semver.Compare() to work properly
		return fmt.Sprintf("%s.999", m[1]), true
	}
	return "", false
}

// IsCompatibleWithUC reports whether the cluster can be used with Unity
// Catalog given a minimum runtime version. The requested minimum must be
// strictly newer than the UC baseline runtime, the cluster's runtime must
// be at least the minimum, and the cluster must run in a UC-capable access
// mode (user isolation or single user).
func IsCompatibleWithUC(cluster compute.ClusterDetails, minVersion string) bool {
	required := canonicalVersion(minVersion)
	// Minimums at or below the UC baseline are never satisfiable.
	if semver.Compare(minUcRuntime, required) >= 0 {
		return false
	}
	rawVersion, ok := GetRuntimeVersion(cluster)
	if !ok {
		return false
	}
	if semver.Compare(required, canonicalVersion(rawVersion)) > 0 {
		return false
	}
	switch cluster.DataSecurityMode {
	case compute.DataSecurityModeUserIsolation, compute.DataSecurityModeSingleUser:
		return true
	}
	return false
}

// ErrNoCompatibleClusters is returned by AskForCluster when no cluster in
// the workspace passes all supplied filters.
var ErrNoCompatibleClusters = errors.New("no compatible clusters found")

// compatibleCluster pairs cluster details with the human-readable name of
// its Spark version (looked up from the spark-versions API), for rendering
// in the interactive picker.
type compatibleCluster struct {
	compute.ClusterDetails
	versionName string
}

// Access returns a human-readable label for the cluster's data security
// (access) mode, shown in the selection prompt.
func (v compatibleCluster) Access() string {
	if v.DataSecurityMode == compute.DataSecurityModeUserIsolation {
		return "Shared"
	}
	if v.DataSecurityMode == compute.DataSecurityModeSingleUser {
		return "Assigned"
	}
	return "Unknown"
}

// Runtime returns the short runtime label: the version display name with
// any trailing " (..." suffix removed.
func (v compatibleCluster) Runtime() string {
	if i := strings.Index(v.versionName, " ("); i >= 0 {
		return v.versionName[:i]
	}
	return v.versionName
}

// State renders the cluster state color-coded for the prompt: green for
// usable states (running/resizing), red for failed or terminated states,
// blue for anything else (e.g. pending/restarting).
func (v compatibleCluster) State() string {
	s := v.ClusterDetails.State
	label := s.String()
	switch s {
	case compute.StateRunning, compute.StateResizing:
		return color.GreenString(label)
	case compute.StateError, compute.StateTerminated, compute.StateTerminating, compute.StateUnknown:
		return color.RedString(label)
	}
	return color.BlueString(label)
}

// clusterFilter decides whether a cluster should be offered to the current
// user in the interactive picker.
type clusterFilter func(cluster *compute.ClusterDetails, me *iam.User) bool

// WithDatabricksConnect returns a filter keeping only clusters usable with
// Databricks Connect: UC-compatible clusters at or above minVersion that
// were created via the UI or API, and — for single-user clusters — are
// assigned to the current user.
//
// Fix: the return type now uses the declared clusterFilter type instead of
// spelling out the identical function signature, matching AskForCluster's
// parameter type (backward-compatible: underlying types are identical).
func WithDatabricksConnect(minVersion string) clusterFilter {
	return func(cluster *compute.ClusterDetails, me *iam.User) bool {
		if !IsCompatibleWithUC(*cluster, minVersion) {
			return false
		}
		switch cluster.ClusterSource {
		case compute.ClusterSourceJob,
			compute.ClusterSourceModels,
			compute.ClusterSourcePipeline,
			compute.ClusterSourcePipelineMaintenance,
			compute.ClusterSourceSql:
			// only UI and API clusters are usable for DBConnect.
			// `CanUseClient: "NOTEBOOKS"` didn't seem to have an effect.
			return false
		}
		if cluster.SingleUserName != "" && cluster.SingleUserName != me.UserName {
			return false
		}
		return true
	}
}

// loadInteractiveClusters lists all clusters visible to the current user,
// drops those rejected by any filter, and resolves each remaining cluster's
// Spark version key to its display name. A spinner is shown while the
// (potentially slow) API calls run.
func loadInteractiveClusters(ctx context.Context, w *databricks.WorkspaceClient, filters []clusterFilter) ([]compatibleCluster, error) {
	promptSpinner := cmdio.Spinner(ctx)
	promptSpinner <- "Loading list of clusters to select from"
	defer close(promptSpinner)
	all, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{
		CanUseClient: "NOTEBOOKS",
	})
	if err != nil {
		return nil, fmt.Errorf("list clusters: %w", err)
	}
	me, err := w.CurrentUser.Me(ctx)
	if err != nil {
		return nil, fmt.Errorf("current user: %w", err)
	}
	// Map Spark version keys (e.g. "13.1.x-scala2.12") to display names.
	sv, err := w.Clusters.SparkVersions(ctx)
	if err != nil {
		return nil, fmt.Errorf("list runtime versions: %w", err)
	}
	versions := make(map[string]string, len(sv.Versions))
	for _, v := range sv.Versions {
		versions[v.Key] = v.Name
	}
	var compatible []compatibleCluster
	for _, cluster := range all {
		var skip bool
		for _, filter := range filters {
			if !filter(&cluster, me) {
				skip = true
				break // one failing filter excludes the cluster; skip the rest
			}
		}
		if skip {
			continue
		}
		compatible = append(compatible, compatibleCluster{
			ClusterDetails: cluster,
			versionName:    versions[cluster.SparkVersion],
		})
	}
	return compatible, nil
}

// AskForCluster returns the ID of a cluster passing all filters. When no
// cluster qualifies it returns ErrNoCompatibleClusters; when exactly one
// qualifies it is returned without prompting; otherwise the user picks one
// from an interactive, searchable list.
//
// Fix: the search now lowercases the user's input as well as the cluster
// name, so matching is genuinely case-insensitive (previously typing any
// uppercase character could never match the pre-lowercased name).
func AskForCluster(ctx context.Context, w *databricks.WorkspaceClient, filters ...clusterFilter) (string, error) {
	compatible, err := loadInteractiveClusters(ctx, w, filters)
	if err != nil {
		return "", fmt.Errorf("load: %w", err)
	}
	if len(compatible) == 0 {
		return "", ErrNoCompatibleClusters
	}
	if len(compatible) == 1 {
		return compatible[0].ClusterId, nil
	}
	i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
		Label: "Choose compatible cluster",
		Items: compatible,
		Searcher: func(input string, idx int) bool {
			lower := strings.ToLower(compatible[idx].ClusterName)
			return strings.Contains(lower, strings.ToLower(input))
		},
		StartInSearchMode: true,
		Templates: &promptui.SelectTemplates{
			Label:    "{{.ClusterName | faint}}",
			Active:   `{{.ClusterName | bold}} ({{.State}} {{.Access}} Runtime {{.Runtime}}) ({{.ClusterId | faint}})`,
			Inactive: `{{.ClusterName}} ({{.State}} {{.Access}} Runtime {{.Runtime}})`,
			Selected: `{{ "Configured cluster" | faint }}: {{ .ClusterName | bold }} ({{.ClusterId | faint}})`,
		},
	})
	if err != nil {
		return "", err
	}
	return compatible[i].ClusterId, nil
}
146 changes: 146 additions & 0 deletions libs/databrickscfg/cfgpickers/clusters_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package cfgpickers

import (
"bytes"
"context"
"testing"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/flags"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/qa"
"github.com/databricks/databricks-sdk-go/service/compute"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/stretchr/testify/require"
)

func TestIsCompatible(t *testing.T) {
require.True(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "13.2.x-aarch64-scala2.12",
DataSecurityMode: compute.DataSecurityModeUserIsolation,
}, "13.0"))
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "13.2.x-aarch64-scala2.12",
DataSecurityMode: compute.DataSecurityModeNone,
}, "13.0"))
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "9.1.x-photon-scala2.12",
DataSecurityMode: compute.DataSecurityModeNone,
}, "13.0"))
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "9.1.x-photon-scala2.12",
DataSecurityMode: compute.DataSecurityModeNone,
}, "10.0"))
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
SparkVersion: "custom-9.1.x-photon-scala2.12",
DataSecurityMode: compute.DataSecurityModeNone,
}, "14.0"))
}

// TestIsCompatibleWithSnapshots checks that a snapshot runtime like
// "14.x-snapshot-..." satisfies a minimum of the same major version
// (snapshots are internally treated as "<major>.999").
func TestIsCompatibleWithSnapshots(t *testing.T) {
	require.True(t, IsCompatibleWithUC(compute.ClusterDetails{
		SparkVersion:     "14.x-snapshot-cpu-ml-scala2.12",
		DataSecurityMode: compute.DataSecurityModeUserIsolation,
	}, "14.0"))
}

// TestFirstCompatibleCluster verifies that AskForCluster filters out an
// incompatible cluster and returns the single remaining compatible one
// without prompting the user.
func TestFirstCompatibleCluster(t *testing.T) {
	cfg, server := qa.HTTPFixtures{
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/list?can_use_client=NOTEBOOKS",
			Response: compute.ListClustersResponse{
				Clusters: []compute.ClusterDetails{
					{
						// Incompatible: runtime 12.2 is below the requested 13.1 minimum.
						ClusterId:        "abc-id",
						ClusterName:      "first shared",
						DataSecurityMode: compute.DataSecurityModeUserIsolation,
						SparkVersion:     "12.2.x-whatever",
						State:            compute.StateRunning,
					},
					{
						// Compatible: new enough runtime, assigned to the current user.
						ClusterId:        "bcd-id",
						ClusterName:      "second personal",
						DataSecurityMode: compute.DataSecurityModeSingleUser,
						SparkVersion:     "14.5.x-whatever",
						State:            compute.StateRunning,
						SingleUserName:   "serge",
					},
				},
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/preview/scim/v2/Me",
			Response: iam.User{
				UserName: "serge",
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/spark-versions",
			Response: compute.GetSparkVersionsResponse{
				Versions: []compute.SparkVersion{
					{
						Key:  "14.5.x-whatever",
						Name: "14.5 (Awesome)",
					},
				},
			},
		},
	}.Config(t)
	defer server.Close()
	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))

	ctx := context.Background()
	ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "..."))
	clusterID, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1"))
	require.NoError(t, err)
	require.Equal(t, "bcd-id", clusterID)
}

// TestNoCompatibleClusters verifies that AskForCluster returns
// ErrNoCompatibleClusters when every listed cluster is filtered out
// (here: the only cluster's runtime is below the requested minimum).
func TestNoCompatibleClusters(t *testing.T) {
	cfg, server := qa.HTTPFixtures{
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/list?can_use_client=NOTEBOOKS",
			Response: compute.ListClustersResponse{
				Clusters: []compute.ClusterDetails{
					{
						// Runtime 12.2 is below the 13.1 minimum -> filtered out.
						ClusterId:        "abc-id",
						ClusterName:      "first shared",
						DataSecurityMode: compute.DataSecurityModeUserIsolation,
						SparkVersion:     "12.2.x-whatever",
						State:            compute.StateRunning,
					},
				},
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/preview/scim/v2/Me",
			Response: iam.User{
				UserName: "serge",
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/spark-versions",
			Response: compute.GetSparkVersionsResponse{
				Versions: []compute.SparkVersion{
					{
						Key:  "14.5.x-whatever",
						Name: "14.5 (Awesome)",
					},
				},
			},
		},
	}.Config(t)
	defer server.Close()
	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))

	ctx := context.Background()
	ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "..."))
	_, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1"))
	require.Equal(t, ErrNoCompatibleClusters, err)
}
Loading