diff --git a/.github/workflows/validation-sfcompute.yml b/.github/workflows/validation-sfcompute.yml new file mode 100644 index 0000000..4813366 --- /dev/null +++ b/.github/workflows/validation-sfcompute.yml @@ -0,0 +1,59 @@ +name: SFCompute Validation Tests + +on: + schedule: + # Run daily at 2 AM UTC + - cron: '0 2 * * *' + workflow_dispatch: + # Allow manual triggering + pull_request: + paths: + - 'v1/providers/sfcompute/**' + - 'internal/validation/**' + - 'v1/**' + branches: [ main ] + +jobs: + sfcompute-validation: + name: SFCompute Provider Validation + runs-on: ubuntu-latest + if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version-file: 'go.mod' + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Install dependencies + run: make deps + + - name: Run SFCompute validation tests + env: + SFCOMPUTE_API_KEY: ${{ secrets.SFCOMPUTE_API_KEY }} + TEST_PRIVATE_KEY_BASE64: ${{ secrets.TEST_PRIVATE_KEY_BASE64 }} + TEST_PUBLIC_KEY_BASE64: ${{ secrets.TEST_PUBLIC_KEY_BASE64 }} + VALIDATION_TEST: true + run: | + cd v1/providers/sfcompute + go test -v -short=false -timeout=30m ./... + + - name: Upload test results + uses: actions/upload-artifact@v4 + if: always() + with: + name: sfcompute-validation-results + path: | + v1/providers/sfcompute/coverage.out diff --git a/go.mod b/go.mod index bcf4b5e..a695f95 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/jarcoal/httpmock v1.4.0 github.com/nebius/gosdk v0.0.0-20250826102719-940ad1dfb5de github.com/pkg/errors v0.9.1 + github.com/sfcompute/nodes-go v0.1.0-alpha.4 github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.47.0 golang.org/x/text v0.33.0 @@ -83,6 +84,10 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/pflag v1.0.10 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/x448/float16 v0.8.4 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect diff --git a/go.sum b/go.sum index 5c23c5c..443dd04 100644 --- a/go.sum +++ b/go.sum @@ -160,6 +160,9 @@ github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7D github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/sfcompute/nodes-go v0.1.0-alpha.3/go.mod h1:dF3O8MCxLz3FTVYhjCa876Z9O3EAM8E8fONivDpfmkM= +github.com/sfcompute/nodes-go v0.1.0-alpha.4 h1:oFBWcMPSpqLYm/NDs5I1jTvzgx9rsXDL9Ghsm30Hc0Q= +github.com/sfcompute/nodes-go v0.1.0-alpha.4/go.mod h1:nUviHgK+Fgt2hDFcRL3M8VoyiypC8fc0dsY8C30QU8M= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= @@ -180,6 +183,16 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/v1/instance_validation.go b/v1/instance_validation.go index b2f16ec..5b06f4e 100644 --- a/v1/instance_validation.go +++ b/v1/instance_validation.go @@ -58,19 +58,47 @@ func ValidateCreateInstance(ctx context.Context, client CloudCreateTerminateInst } func ValidateListCreatedInstance(ctx context.Context, client CloudCreateTerminateInstance, i *Instance) error { + // List instances by location and search for the instance by CloudID ins, err := client.ListInstances(ctx, ListInstancesArgs{ Locations: []string{i.Location}, }) if err != nil { return err } - var validationErr error if len(ins) == 0 { - validationErr = errors.Join(validationErr, fmt.Errorf("no instances found")) + return fmt.Errorf("no instances found") } foundInstance := collections.Find(ins, func(inst Instance) bool { return inst.CloudID == i.CloudID }) + err = validateInstance(i, foundInstance) + if err != nil { + return err + } + + // List instances by instance ID and search for the instance by CloudID + ins, err = client.ListInstances(ctx, ListInstancesArgs{ + InstanceIDs: []CloudProviderInstanceID{i.CloudID}, + }) + if err != nil { + return err + } + if len(ins) == 0 { + return fmt.Errorf("instance not found: %s", i.CloudID) + } + + foundInstance = collections.Find(ins, func(inst Instance) bool { + return inst.CloudID == i.CloudID + }) + err = validateInstance(i, foundInstance) + if err != nil { + return err + } + return nil +} + +func validateInstance(i *Instance, foundInstance *Instance) error { + var validationErr error if foundInstance == nil { validationErr = errors.Join(validationErr, fmt.Errorf("instance not found: %s", i.CloudID)) return validationErr diff --git a/v1/instancetype.go b/v1/instancetype.go index 3c09440..5401a76 100644 --- a/v1/instancetype.go +++ b/v1/instancetype.go @@ -439,3 +439,35 @@ func ValidateStableInstanceTypeIDs(ctx context.Context, client CloudInstanceType return nil } + +func IsSelectedByArgs(instanceType InstanceType, args GetInstanceTypeArgs) bool { + if args.Locations != nil { + if !args.Locations.IsAllowed(instanceType.Location) { + return false + } + } + + if args.GPUManufactererFilter != nil { + for _, supportedGPU := range instanceType.SupportedGPUs { + if !args.GPUManufactererFilter.IsAllowed(supportedGPU.Manufacturer) { + return false + } + } + } + + if args.CloudFilter != nil { + if !args.CloudFilter.IsAllowed(instanceType.Cloud) { + return false + } + } + + if args.ArchitectureFilter != nil { + for _, architecture := range instanceType.SupportedArchitectures { + if !args.ArchitectureFilter.IsAllowed(architecture) { + return false + } + } + } + + return true +} diff --git a/v1/providers/launchpad/instancetype.go b/v1/providers/launchpad/instancetype.go index ca9b317..a66f941 100644 --- a/v1/providers/launchpad/instancetype.go +++ b/v1/providers/launchpad/instancetype.go @@ -44,7 +44,7 @@ func (c *LaunchpadClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta } // Collect the instance type if it is selected by the args - if isSelectedByArgs(*instanceType, args) { + if v1.IsSelectedByArgs(*instanceType, args) { instanceTypes = append(instanceTypes, *instanceType) } else { continue @@ -55,40 +55,6 @@ func (c *LaunchpadClient) GetInstanceTypes(ctx context.Context, args v1.GetInsta return instanceTypes, nil } -func isSelectedByArgs(instanceType v1.InstanceType, args v1.GetInstanceTypeArgs) bool { - if args.Locations != nil { - for _, location := range instanceType.Location { - if !args.Locations.IsAllowed(string(location)) { - return false - } - } - } - - if args.GPUManufactererFilter != nil { - for _, supportedGPU := range instanceType.SupportedGPUs { - if !args.GPUManufactererFilter.IsAllowed(supportedGPU.Manufacturer) { - return false - } - } - } - - if args.CloudFilter != nil { - if !args.CloudFilter.IsAllowed(instanceType.Cloud) { - return false - } - } - - if args.ArchitectureFilter != nil { - for _, architecture := range instanceType.SupportedArchitectures { - if !args.ArchitectureFilter.IsAllowed(architecture) { - return false - } - } - } - - return true -} - func (c *LaunchpadClient) paginateInstanceTypes(ctx context.Context, pageSize int32) ([]openapi.InstanceType, error) { instanceTypes := make([]openapi.InstanceType, 0, pageSize) var page int32 = 1 diff --git a/v1/providers/sfcompute/capabilities.go b/v1/providers/sfcompute/capabilities.go new file mode 100644 index 0000000..cad2ca7 --- /dev/null +++ b/v1/providers/sfcompute/capabilities.go @@ -0,0 +1,23 @@ +package v1 + +import ( + "context" + + v1 "github.com/brevdev/cloud/v1" +) + +func getSFCCapabilities() v1.Capabilities { + return v1.Capabilities{ + v1.CapabilityCreateInstance, + v1.CapabilityTerminateInstance, + v1.CapabilityCreateTerminateInstance, + } +} + +func (c *SFCClient) GetCapabilities(_ context.Context) (v1.Capabilities, error) { + return getSFCCapabilities(), nil +} + +func (c *SFCCredential) GetCapabilities(_ context.Context) (v1.Capabilities, error) { + return getSFCCapabilities(), nil +} diff --git a/v1/providers/sfcompute/client.go b/v1/providers/sfcompute/client.go new file mode 100644 index 0000000..7dc2031 --- /dev/null +++ b/v1/providers/sfcompute/client.go @@ -0,0 +1,104 @@ +package v1 + +import ( + "context" + + v1 "github.com/brevdev/cloud/v1" + "github.com/sfcompute/nodes-go/option" + + sfcnodes "github.com/sfcompute/nodes-go" +) + +const CloudProviderID = "sfcompute" + +type SFCCredential struct { + RefID string + APIKey string `json:"api_key"` +} + +var _ v1.CloudCredential = &SFCCredential{} + +func NewSFCCredential(refID string, apiKey string) *SFCCredential { + return &SFCCredential{ + RefID: refID, + APIKey: apiKey, + } +} + +func (c *SFCCredential) GetReferenceID() string { + return c.RefID +} + +func (c *SFCCredential) GetAPIType() v1.APIType { + return v1.APITypeGlobal +} + +func (c *SFCCredential) GetCloudProviderID() v1.CloudProviderID { + return CloudProviderID +} + +func (c *SFCCredential) GetTenantID() (string, error) { + // sfc does not have a tenant system, return empty string + return "", nil +} + +type SFCClient struct { + v1.NotImplCloudClient + refID string + location string + apiKey string + client sfcnodes.Client + logger v1.Logger +} + +var _ v1.CloudClient = &SFCClient{} + +type SFCClientOption func(c *SFCClient) + +func WithLogger(logger v1.Logger) SFCClientOption { + return func(c *SFCClient) { + c.logger = logger + } +} + +func (c *SFCCredential) MakeClientWithOptions(_ context.Context, location string, opts ...SFCClientOption) (v1.CloudClient, error) { + sfcClient := &SFCClient{ + refID: c.RefID, + apiKey: c.APIKey, + client: sfcnodes.NewClient(option.WithBearerToken(c.APIKey)), + location: location, + logger: &v1.NoopLogger{}, + } + + for _, opt := range opts { + opt(sfcClient) + } + + return sfcClient, nil +} + +func (c *SFCCredential) MakeClient(ctx context.Context, location string) (v1.CloudClient, error) { + return c.MakeClientWithOptions(ctx, location) +} + +func (c *SFCClient) GetAPIType() v1.APIType { + return v1.APITypeGlobal +} + +func (c *SFCClient) GetCloudProviderID() v1.CloudProviderID { + return CloudProviderID +} + +func (c *SFCClient) GetReferenceID() string { + return c.refID +} + +func (c *SFCClient) GetTenantID() (string, error) { + // sfc does not have a tenant system, return empty string + return "", nil +} + +func (c *SFCClient) MakeClient(_ context.Context, location string) (v1.CloudClient, error) { + c.location = location + return c, nil +} diff --git a/v1/providers/sfcompute/instance.go b/v1/providers/sfcompute/instance.go new file mode 100644 index 0000000..77e282f --- /dev/null +++ b/v1/providers/sfcompute/instance.go @@ -0,0 +1,408 @@ +package v1 + +import ( + "context" + "encoding/base64" + "fmt" + "slices" + "strings" + "time" + + "github.com/alecthomas/units" + "github.com/brevdev/cloud/internal/errors" + v1 "github.com/brevdev/cloud/v1" + sfcnodes "github.com/sfcompute/nodes-go" + "github.com/sfcompute/nodes-go/packages/param" +) + +const ( + maxPricePerNodeHour = 1600 + defaultPort = 2222 + defaultSSHUsername = "ubuntu" +) + +func (c *SFCClient) CreateInstance(ctx context.Context, attrs v1.CreateInstanceAttrs) (*v1.Instance, error) { + // Get the zone for the location (do not include unavailable zones) + zone, err := c.getZone(ctx, attrs.Location, false) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + // Create a name for the node + name := brevDataToSFCName(attrs.RefID, attrs.Name) + + // Create the node + resp, err := c.client.Nodes.New(ctx, sfcnodes.NodeNewParams{ + CreateNodesRequest: sfcnodes.CreateNodesRequestParam{ + DesiredCount: 1, + MaxPricePerNodeHour: maxPricePerNodeHour, + Zone: zone.Name, + Names: []string{name}, + CloudInitUserData: param.Opt[string]{Value: sshKeyCloudInit(attrs.PublicKey)}, + }, + }) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + if len(resp.Data) == 0 { + return nil, errors.WrapAndTrace(fmt.Errorf("no nodes returned")) + } + node := resp.Data[0] + + // Get the instance + instance, err := c.GetInstance(ctx, v1.CloudProviderInstanceID(node.ID)) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + return instance, nil +} + +func sshKeyCloudInit(sshKey string) string { + script := fmt.Sprintf("#cloud-config\nssh_authorized_keys:\n - %s", sshKey) + return base64.StdEncoding.EncodeToString([]byte(script)) +} + +func (c *SFCClient) GetInstance(ctx context.Context, id v1.CloudProviderInstanceID) (*v1.Instance, error) { + c.logger.Debug(ctx, "sfc: GetInstance start", + v1.LogField("instanceID", id), + v1.LogField("location", c.location), + ) + + // Get the node from the API + node, err := c.client.Nodes.Get(ctx, string(id)) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + // Get the zone for the location (include unavailable zones, in case the zone is not available but the node is still running) + zone, err := c.getZone(ctx, node.Zone, true) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + nodeInfo, err := c.sfcNodeInfoFromNode(ctx, node, zone) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + instance, err := c.sfcNodeToBrevInstance(*nodeInfo) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + c.logger.Debug(ctx, "sfc: GetInstance end", + v1.LogField("instanceID", id), + v1.LogField("instance", instance), + ) + + return instance, nil +} + +func (c *SFCClient) getZone(ctx context.Context, location string, includeUnavailable bool) (*sfcnodes.ZoneListResponseData, error) { + // Fetch the zones to ensure the location is valid + zones, err := c.getZones(ctx, includeUnavailable) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + if len(zones) == 0 { + return nil, errors.WrapAndTrace(fmt.Errorf("no zones available")) + } + + // Find the zone that matches the location + var zone *sfcnodes.ZoneListResponseData + for _, z := range zones { + if z.Name == location { + zone = &z + break + } + } + if zone == nil { + return nil, errors.WrapAndTrace(fmt.Errorf("zone not found in location %s", location)) + } + + return zone, nil +} + +func (c *SFCClient) ListInstances(ctx context.Context, args v1.ListInstancesArgs) ([]v1.Instance, error) { + c.logger.Debug(ctx, "sfc: ListInstances start", + v1.LogField("location", c.location), + v1.LogField("args", fmt.Sprintf("%+v", args)), + ) + + resp, err := c.client.Nodes.List(ctx, sfcnodes.NodeListParams{}) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + c.logger.Debug(ctx, "sfc: ListInstances nodes list", + v1.LogField("node count", len(resp.Data)), + ) + + zoneCache := make(map[string]*sfcnodes.ZoneListResponseData) + + var instances []v1.Instance + for _, node := range resp.Data { + // Get the zone for the node, checking the cache first + zone, ok := zoneCache[node.Zone] + if !ok { + z, err := c.getZone(ctx, node.Zone, true) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + zoneCache[node.Zone] = z + zone = z + } + + // Filter by locations + if args.Locations != nil && !args.Locations.IsAllowed(zone.Name) { + c.logger.Debug(ctx, "sfc: ListInstances node filtered out by location", + v1.LogField("nodeID", node.ID), + v1.LogField("location", zone.Name), + ) + continue + } + + // Filter by instance IDs + if len(args.InstanceIDs) > 0 && !slices.Contains(args.InstanceIDs, v1.CloudProviderInstanceID(node.ID)) { + c.logger.Debug(ctx, "sfc: ListInstances node filtered out by instance ID", + v1.LogField("nodeID", node.ID), + v1.LogField("instanceID", v1.CloudProviderInstanceID(node.ID)), + ) + continue + } + + nodeInfo, err := c.sfcNodeInfoFromNodeListResponseData(ctx, &node, zone) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + inst, err := c.sfcNodeToBrevInstance(*nodeInfo) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + instances = append(instances, *inst) + } + + c.logger.Debug(ctx, "sfc: ListInstances end", + v1.LogField("instance count", len(instances)), + ) + + return instances, nil +} + +func (c *SFCClient) TerminateInstance(ctx context.Context, id v1.CloudProviderInstanceID) error { + c.logger.Debug(ctx, "sfc: TerminateInstance start", + v1.LogField("instanceID", id), + ) + + _, err := c.client.Nodes.Release(ctx, string(id)) + if err != nil { + return errors.WrapAndTrace(err) + } + + c.logger.Debug(ctx, "sfc: TerminateInstance end", + v1.LogField("instanceID", id), + ) + + return nil +} + +type sfcNodeInfo struct { + id string + name string + createdAt time.Time + status v1.LifecycleStatus + gpuType string + sshUsername string + sshHostname string + zone *sfcnodes.ZoneListResponseData +} + +func (c *SFCClient) sfcNodeToBrevInstance(node sfcNodeInfo) (*v1.Instance, error) { + // Get the refID and name from the node name + refID, name, err := sfcNameToBrevData(node.name) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + // Get the instance type for the zone + instanceType, err := getInstanceTypeForZone(*node.zone) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + + diskSizeInt64, err := instanceType.SupportedStorage[0].SizeBytes.ByteCountInUnitInt64(v1.Gibibyte) + if err != nil { + return nil, err + } + diskSize := units.Base2Bytes(diskSizeInt64 * int64(units.Gibibyte)) + + // Create the instance + inst := &v1.Instance{ + Name: name, + CloudID: v1.CloudProviderInstanceID(node.id), + RefID: refID, + PublicDNS: node.sshHostname, + PublicIP: node.sshHostname, + SSHUser: node.sshUsername, + SSHPort: defaultPort, + CreatedAt: node.createdAt, + DiskSize: diskSize, + DiskSizeBytes: instanceType.SupportedStorage[0].SizeBytes, // TODO: this should be pulled from the node itself + Status: v1.Status{ + LifecycleStatus: node.status, + }, + InstanceTypeID: instanceType.ID, + InstanceType: instanceType.Type, + Location: node.zone.Name, + Spot: false, + Stoppable: false, + Rebootable: false, + CloudCredRefID: c.refID, // TODO: this should be pulled from the node itself + } + return inst, nil +} + +func (c *SFCClient) sfcNodeInfoFromNode(ctx context.Context, node *sfcnodes.Node, zone *sfcnodes.ZoneListResponseData) (*sfcNodeInfo, error) { + var sshHostname string + if len(node.VMs.Data) == 1 { //nolint:gocritic // ok + hostname, err := c.getSSHHostnameFromVM(ctx, node.VMs.Data[0].ID, node.VMs.Data[0].Status) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + sshHostname = hostname + } else if len(node.VMs.Data) == 0 { + sshHostname = "" + } else { + return nil, errors.WrapAndTrace(fmt.Errorf("multiple VMs found for node %s", node.ID)) + } + + return &sfcNodeInfo{ + id: node.ID, + name: node.Name, + createdAt: time.Unix(node.CreatedAt, 0), + status: sfcStatusToLifecycleStatus(fmt.Sprint(node.Status)), + gpuType: string(node.GPUType), + sshUsername: defaultSSHUsername, + sshHostname: sshHostname, + zone: zone, + }, nil +} + +func (c *SFCClient) sfcNodeInfoFromNodeListResponseData(ctx context.Context, node *sfcnodes.ListResponseNodeData, zone *sfcnodes.ZoneListResponseData) (*sfcNodeInfo, error) { + sfcNode := sfcListResponseNodeDataToNode(node) + return c.sfcNodeInfoFromNode(ctx, sfcNode, zone) +} + +// Convert the sfcnodes.ListResponseNodeData into a node *sfcnodes.Node -- these are fundamentally the same object, but they +// lack a common interface. One type is returned from a single "get" call, the other is the type of each object returned by +// a "list" call. This conversion function allows the rest of our business logic to treat these as the same type. +func sfcListResponseNodeDataToNode(node *sfcnodes.ListResponseNodeData) *sfcnodes.Node { + vms := make([]sfcnodes.NodeVMsData, len(node.VMs.Data)) + for i, vm := range node.VMs.Data { + vms[i] = sfcnodes.NodeVMsData{ //nolint:staticcheck // ok + ID: vm.ID, + CreatedAt: vm.CreatedAt, + EndAt: vm.EndAt, + Object: vm.Object, + StartAt: vm.StartAt, + Status: vm.Status, + UpdatedAt: vm.UpdatedAt, + ImageID: vm.ImageID, + JSON: vm.JSON, + } + } + + return &sfcnodes.Node{ + ID: node.ID, + GPUType: node.GPUType, + Name: node.Name, + NodeType: node.NodeType, + Object: node.Object, + Owner: node.Owner, + Status: node.Status, + CreatedAt: node.CreatedAt, + DeletedAt: node.DeletedAt, + EndAt: node.EndAt, + MaxPricePerNodeHour: node.MaxPricePerNodeHour, + ProcurementID: node.ProcurementID, + StartAt: node.StartAt, + UpdatedAt: node.UpdatedAt, + Zone: node.Zone, + JSON: node.JSON, + VMs: sfcnodes.NodeVMs{ + Data: vms, + Object: node.VMs.Object, + JSON: node.VMs.JSON, + }, + } +} + +func sfcStatusToLifecycleStatus(status string) v1.LifecycleStatus { + switch strings.ToLower(status) { + case "pending", "unspecified", "awaitingcapacity", "unknown": + return v1.LifecycleStatusPending + case "running": + return v1.LifecycleStatusRunning + case "stopped": + return v1.LifecycleStatusStopped + case "terminating": + return v1.LifecycleStatusTerminating + case "released", "destroyed", "deleted": + return v1.LifecycleStatusTerminated + case "nodefailure", "failed": + return v1.LifecycleStatusFailed + default: + return v1.LifecycleStatusPending + } +} + +func (c *SFCClient) getSSHHostnameFromVM(ctx context.Context, vmID string, vmStatus string) (string, error) { + // If the VM is not running, set the SSH username and hostname to empty strings + if strings.ToLower(vmStatus) != "running" { + return "", nil + } + + // If the VM is running, get the SSH username and hostname + sshResponse, err := c.client.VMs.SSH(ctx, sfcnodes.VMSSHParams{VMID: vmID}) + if err != nil { + return "", errors.WrapAndTrace(err) + } + + return sshResponse.SSHHostname, nil +} + +func brevDataToSFCName(refID string, name string) string { + return fmt.Sprintf("%s_%s", refID, name) +} + +func sfcNameToBrevData(name string) (string, string, error) { + parts := strings.SplitN(name, "_", 2) + if len(parts) != 2 { + return "", "", errors.WrapAndTrace(fmt.Errorf("invalid node name %s", name)) + } + return parts[0], parts[1], nil +} + +// Optional if supported: +func (c *SFCClient) RebootInstance(_ context.Context, _ v1.CloudProviderInstanceID) error { + return v1.ErrNotImplemented +} + +func (c *SFCClient) StopInstance(_ context.Context, _ v1.CloudProviderInstanceID) error { + return v1.ErrNotImplemented +} + +func (c *SFCClient) StartInstance(_ context.Context, _ v1.CloudProviderInstanceID) error { + return v1.ErrNotImplemented +} + +// Merge strategies (pass-through is acceptable baseline). +func (c *SFCClient) MergeInstanceForUpdate(_ v1.Instance, newInst v1.Instance) v1.Instance { + return newInst +} + +func (c *SFCClient) MergeInstanceTypeForUpdate(_ v1.InstanceType, newIt v1.InstanceType) v1.InstanceType { + return newIt +} diff --git a/v1/providers/sfcompute/instancetype.go b/v1/providers/sfcompute/instancetype.go new file mode 100644 index 0000000..6858e93 --- /dev/null +++ b/v1/providers/sfcompute/instancetype.go @@ -0,0 +1,269 @@ +package v1 + +import ( + "context" + "fmt" + "slices" + "strings" + "time" + + "github.com/alecthomas/units" + "github.com/bojanz/currency" + sfcnodes "github.com/sfcompute/nodes-go" + + v1 "github.com/brevdev/cloud/v1" +) + +const ( + gpuTypeH100 = "h100" + gpuTypeH200 = "h200" + + deliveryTypeVM = "VM" + interconnectInfiniband = "infiniband" + formFactorSXM5 = "sxm5" + diskTypeSSD = "ssd" +) + +var allowedZones = []string{"hayesvalley", "yerba"} + +func makeDefaultInstanceTypePrice(amount string, currencyCode string) currency.Amount { + instanceTypePrice, err := currency.NewAmount(amount, currencyCode) + if err != nil { + panic(err) + } + return instanceTypePrice +} + +func (c *SFCClient) GetInstanceTypes(ctx context.Context, args v1.GetInstanceTypeArgs) ([]v1.InstanceType, error) { + c.logger.Debug(ctx, "sfc: GetInstanceTypes start", + v1.LogField("location", c.location), + v1.LogField("args", fmt.Sprintf("%+v", args)), + ) + + // Fetch all available zones + includeUnavailable := false + zones, err := c.getZones(ctx, includeUnavailable) + if err != nil { + return nil, err + } + + c.logger.Debug(ctx, "sfc: GetInstanceTypes zones list", + v1.LogField("zone count", len(zones)), + ) + + instanceTypes := make([]v1.InstanceType, 0, len(zones)) + for _, zone := range zones { + gpuType := strings.ToLower(string(zone.HardwareType)) + + if !gpuTypeIsAllowed(gpuType) { + c.logger.Debug(ctx, "sfc: GetInstanceTypes gpu type not allowed", + v1.LogField("gpuType", gpuType), + ) + continue + } + + instanceType, err := getInstanceTypeForZone(zone) + if err != nil { + return nil, err + } + + if !v1.IsSelectedByArgs(*instanceType, args) { + c.logger.Debug(ctx, "sfc: GetInstanceTypes instance type not selected by args", + v1.LogField("instanceType", instanceType.Type), + ) + continue + } + + instanceTypes = append(instanceTypes, *instanceType) + } + + c.logger.Debug(ctx, "sfc: GetInstanceTypes end", + v1.LogField("instanceType count", len(instanceTypes)), + ) + + return instanceTypes, nil +} + +func getInstanceTypeForZone(zone sfcnodes.ZoneListResponseData) (*v1.InstanceType, error) { + gpuType := strings.ToLower(string(zone.HardwareType)) + + gpuMetadata, err := getInstanceTypeMetadata(gpuType) + if err != nil { + return nil, err + } + + ramInt64, err := gpuMetadata.memoryBytes.ByteCountInUnitInt64(v1.Gibibyte) + if err != nil { + return nil, err + } + ram := units.Base2Bytes(ramInt64 * int64(units.Gibibyte)) + + memoryInt64, err := gpuMetadata.gpuVRAM.ByteCountInUnitInt64(v1.Gibibyte) + if err != nil { + return nil, err + } + memory := units.Base2Bytes(memoryInt64 * int64(units.Gibibyte)) + + diskSizeInt64, err := gpuMetadata.diskBytes.ByteCountInUnitInt64(v1.Gibibyte) + if err != nil { + return nil, err + } + diskSize := units.Base2Bytes(diskSizeInt64 * int64(units.Gibibyte)) + + instanceType := v1.InstanceType{ + IsAvailable: true, + Type: makeInstanceTypeName(zone), + Memory: ram, + MemoryBytes: gpuMetadata.memoryBytes, + Location: zoneToLocation(zone).Name, + Stoppable: false, + Rebootable: false, + IsContainer: false, + Provider: CloudProviderID, + BasePrice: &gpuMetadata.price, + EstimatedDeployTime: &gpuMetadata.estimatedDeployTime, + SupportedGPUs: []v1.GPU{{ + Count: gpuMetadata.gpuCount, + Type: strings.ToUpper(gpuType), + Manufacturer: gpuMetadata.gpuManufacturer, + Name: strings.ToUpper(gpuType), + Memory: memory, + MemoryBytes: gpuMetadata.gpuVRAM, + NetworkDetails: gpuMetadata.formFactor, + }}, + SupportedStorage: []v1.Storage{{ + Type: diskTypeSSD, + Count: 1, + Size: diskSize, + SizeBytes: gpuMetadata.diskBytes, + }}, + SupportedArchitectures: []v1.Architecture{gpuMetadata.architecture}, + } + + instanceType.ID = v1.MakeGenericInstanceTypeID(instanceType) + + return &instanceType, nil +} + +func gpuTypeIsAllowed(gpuType string) bool { + return gpuType == gpuTypeH100 || gpuType == gpuTypeH200 +} + +func makeInstanceTypeName(zone sfcnodes.ZoneListResponseData) string { + interconnect := "" + if strings.ToLower(zone.InterconnectType) == interconnectInfiniband { + interconnect = ".ib" + } + return fmt.Sprintf("%s%s", strings.ToLower(string(zone.HardwareType)), interconnect) +} + +func (c *SFCClient) GetLocations(ctx context.Context, args v1.GetLocationsArgs) ([]v1.Location, error) { + zones, err := c.getZones(ctx, args.IncludeUnavailable) + if err != nil { + return nil, err + } + + locations := make([]v1.Location, 0, len(zones)) + for _, zone := range zones { + location := zoneToLocation(zone) + locations = append(locations, location) + } + + return locations, nil +} + +func (c *SFCClient) getZones(ctx context.Context, includeUnavailable bool) ([]sfcnodes.ZoneListResponseData, error) { + // Fetch the zones from the API + resp, err := c.client.Zones.List(ctx) + if err != nil { + return nil, err + } + + // If there are no zones, return an empty list + if resp == nil || len(resp.Data) == 0 { + return []sfcnodes.ZoneListResponseData{}, nil + } + + zones := make([]sfcnodes.ZoneListResponseData, 0, len(resp.Data)) + for _, zone := range resp.Data { + // If the zone is not allowed, skip it + if !slices.Contains(allowedZones, strings.ToLower(zone.Name)) { + continue + } + + // If the there is no available capacity, and skip it + if len(zone.AvailableCapacity) == 0 && !includeUnavailable { + continue + } + + // If the delivery type is not VM, skip it + if zone.DeliveryType != deliveryTypeVM { + continue + } + + // Add the zone to the list + zones = append(zones, zone) + } + + return zones, nil +} + +func zoneToLocation(zone sfcnodes.ZoneListResponseData) v1.Location { + return v1.Location{ + Name: zone.Name, + Description: fmt.Sprintf("sfc_%s_%s", zone.Name, string(zone.HardwareType)), + Available: true, + } +} + +// sfcInstanceTypeMetadata is a struct that contains the metadata for a given instance type. +// These values are not currently provided by the SFCompute API, so we need to hardcode them. +type sfcInstanceTypeMetadata struct { + gpuType string + formFactor string + architecture v1.Architecture + memoryBytes v1.Bytes + diskBytes v1.Bytes + gpuCount int32 + gpuManufacturer v1.Manufacturer + gpuVRAM v1.Bytes + estimatedDeployTime time.Duration + price currency.Amount +} + +func getInstanceTypeMetadata(gpuType string) (*sfcInstanceTypeMetadata, error) { + switch gpuType { + case gpuTypeH100: + return &h100InstanceTypeMetadata, nil + case gpuTypeH200: + return &h200InstanceTypeMetadata, nil + default: + return nil, fmt.Errorf("invalid GPU type: %s", gpuType) + } +} + +var h100InstanceTypeMetadata = sfcInstanceTypeMetadata{ + gpuType: gpuTypeH100, + formFactor: formFactorSXM5, + architecture: v1.ArchitectureX86_64, + memoryBytes: v1.NewBytes(960, v1.Gigabyte), + diskBytes: v1.NewBytes(1500, v1.Gigabyte), + gpuCount: 8, + gpuManufacturer: v1.ManufacturerNVIDIA, + gpuVRAM: v1.NewBytes(80, v1.Gigabyte), + estimatedDeployTime: 14 * time.Minute, + price: makeDefaultInstanceTypePrice("16.00", "USD"), +} + +var h200InstanceTypeMetadata = sfcInstanceTypeMetadata{ + gpuType: gpuTypeH200, + formFactor: formFactorSXM5, + architecture: v1.ArchitectureX86_64, + memoryBytes: v1.NewBytes(960, v1.Gigabyte), + diskBytes: v1.NewBytes(1500, v1.Gigabyte), + gpuCount: 8, + gpuManufacturer: v1.ManufacturerNVIDIA, + gpuVRAM: v1.NewBytes(141, v1.Gigabyte), + estimatedDeployTime: 14 * time.Minute, + price: makeDefaultInstanceTypePrice("24.00", "USD"), +} diff --git a/v1/providers/sfcompute/scripts/instancetype_test.go b/v1/providers/sfcompute/scripts/instancetype_test.go new file mode 100644 index 0000000..d320b3a --- /dev/null +++ b/v1/providers/sfcompute/scripts/instancetype_test.go @@ -0,0 +1,174 @@ +//go:build scripts +// +build scripts + +package scripts + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/brevdev/cloud/internal/ssh" + v1 "github.com/brevdev/cloud/v1" + "github.com/google/uuid" +) + +func TestGetInstanceTypes(t *testing.T) { + t.Parallel() + checkSkip(t) + apiKey := getAPIKey() + + credential := NewSFCCredential("validation-test", apiKey) + client, err := credential.MakeClient(context.Background(), "eu-north1") + if err != nil { + t.Fatalf("failed to make client: %v", err) + } + + locations, err := client.GetLocations(context.Background(), v1.GetLocationsArgs{ + IncludeUnavailable: true, + }) + if err != nil { + t.Fatalf("failed to get locations: %v", err) + } + + t.Logf("locations: %v", locations) + + instanceTypes, err := client.GetInstanceTypes(context.Background(), v1.GetInstanceTypeArgs{ + Locations: v1.LocationsFilter{"all"}, + }) + if err != nil { + t.Fatalf("failed to get instance types: %v", err) + } + + t.Logf("instance types: %v", instanceTypes) +} + +func TestCreateInstance(t *testing.T) { + t.Parallel() + checkSkip(t) + apiKey := getAPIKey() + + credential := NewSFCCredential("validation-test", apiKey) + client, err := credential.MakeClient(context.Background(), "eu-north1") + if err != nil { + t.Fatalf("failed to make client: %v", err) + } + + id := uuid.New().String() + + instance, err := client.CreateInstance(context.Background(), v1.CreateInstanceAttrs{ + Name: "test", + RefID: id, + PublicKey: ssh.GetTestPublicKey(), + InstanceType: "h100", + Location: "hayesvalley", + }) + if err != nil { + t.Fatalf("failed to create instance: %v", err) + } + + t.Logf("instance: %v", instance) +} + +func TestGetInstance(t *testing.T) { + t.Parallel() + checkSkip(t) + apiKey := getAPIKey() + + credential := NewSFCCredential("validation-test", apiKey) + client, err := credential.MakeClient(context.Background(), "") + if err != nil { + t.Fatalf("failed to make client: %v", err) + } + + instance, err := client.GetInstance(context.Background(), "6c7a3ade-1e59-4e04-af6e-365046995a81_test") + if err != nil { + t.Fatalf("failed to get instance: %v", err) + } + + t.Logf("instance: %v", instance) + + // status + t.Logf("status: %v", instance.Status) + + // ssh details + t.Logf("ssh details: %v,%v,%v", instance.SSHUser, instance.SSHPort, instance.PublicIP) +} + +func TestSSHInstance(t *testing.T) { + t.Parallel() + checkSkip(t) + apiKey := getAPIKey() + + credential := NewSFCCredential("validation-test", apiKey) + client, err := credential.MakeClient(context.Background(), "") + if err != nil { + t.Fatalf("failed to make client: %v", err) + } + + instance, err := client.GetInstance(context.Background(), "6c7a3ade-1e59-4e04-af6e-365046995a81_test") + if err != nil { + t.Fatalf("failed to get instance: %v", err) + } + + t.Logf("instance: %v", instance) + + // ssh details + t.Logf("ssh details: %v,%v,%v", instance.SSHUser, instance.SSHPort, instance.PublicIP) + + // ssh to instance + err = ssh.WaitForSSH(context.Background(), ssh.ConnectionConfig{ + User: "root", + HostPort: fmt.Sprintf("%s:%d", instance.PublicIP, instance.SSHPort), + PrivKey: ssh.GetTestPrivateKey(), + }, ssh.WaitForSSHOptions{ + Timeout: 10 * time.Second, + }) + if err != nil { + t.Fatalf("failed to wait for SSH: %v", err) + } + + t.Logf("SSH connection validated successfully for %s@%s:%d", instance.SSHUser, instance.PublicIP, instance.SSHPort) +} + +func TestListInstances(t *testing.T) { + t.Parallel() + checkSkip(t) + apiKey := getAPIKey() + + credential := NewSFCCredential("validation-test", apiKey) + client, err := credential.MakeClient(context.Background(), "") + if err != nil { + t.Fatalf("failed to make client: %v", err) + } + + instances, err := client.ListInstances(context.Background(), v1.ListInstancesArgs{ + TagFilters: map[string][]string{ + "dev-plane-managedBy": {"dev-plane"}, + }, + Locations: v1.All, + }) + if err != nil { + t.Fatalf("failed to list instances: %v", err) + } + + t.Logf("instances: %v", instances) +} + +func TestTerminateInstance(t *testing.T) { + t.Parallel() + checkSkip(t) + apiKey := getAPIKey() + + credential := NewSFCCredential("validation-test", apiKey) + client, err := credential.MakeClient(context.Background(), "") + if err != nil { + t.Fatalf("failed to make client: %v", err) + } + + err = client.TerminateInstance(context.Background(), "6c7a3ade-1e59-4e04-af6e-365046995a81_test") + if err != nil { + t.Fatalf("failed to terminate instance: %v", err) + } +} diff --git a/v1/providers/sfcompute/validation_test.go b/v1/providers/sfcompute/validation_test.go new file mode 100644 index 0000000..196c739 --- /dev/null +++ b/v1/providers/sfcompute/validation_test.go @@ -0,0 +1,52 @@ +package v1 + +import ( + "os" + "testing" + + "github.com/brevdev/cloud/internal/validation" + v1 "github.com/brevdev/cloud/v1" +) + +func TestValidationFunctions(t *testing.T) { + t.Parallel() + checkSkip(t) + apiKey := getAPIKey() + + config := validation.ProviderConfig{ + Credential: NewSFCCredential("validation-test", apiKey), + StableIDs: []v1.InstanceTypeID{ + "hayesvalley-noSub-h100", + "yerba-noSub-h100", + }, + } + + validation.RunValidationSuite(t, config) +} + +func TestInstanceLifecycleValidation(t *testing.T) { + t.Parallel() + checkSkip(t) + apiKey := getAPIKey() + + config := validation.ProviderConfig{ + Credential: NewSFCCredential("validation-test", apiKey), + Location: "yerba", + } + + validation.RunInstanceLifecycleValidation(t, config) +} + +func checkSkip(t *testing.T) { + apiKey := getAPIKey() + isValidationTest := os.Getenv("VALIDATION_TEST") + if apiKey == "" && isValidationTest != "" { + t.Fatal("SFCOMPUTE_API_KEY not set, but VALIDATION_TEST is set") + } else if apiKey == "" && isValidationTest == "" { + t.Skip("SFCOMPUTE_API_KEY not set, skipping sfcompute validation tests") + } +} + +func getAPIKey() string { + return os.Getenv("SFCOMPUTE_API_KEY") +}