Skip to content
100 changes: 65 additions & 35 deletions cli/azd/extensions/azure.ai.finetune/internal/cmd/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/azure/azure-dev/cli/azd/pkg/azdext"
"github.com/azure/azure-dev/cli/azd/pkg/ux"

FTYaml "azure.ai.finetune/internal/fine_tuning_yaml"
"azure.ai.finetune/internal/services"
JobWrapper "azure.ai.finetune/internal/tools"
"azure.ai.finetune/internal/utils"
Expand Down Expand Up @@ -68,15 +67,18 @@ func formatFineTunedModel(model string) string {

func newOperationSubmitCommand() *cobra.Command {
var filename string
var model string
var trainingFile string
var validationFile string
var suffix string
var seed int64
cmd := &cobra.Command{
Use: "submit",
Short: "Submit fine tuning job",
Short: "submit fine tuning job",
RunE: func(cmd *cobra.Command, args []string) error {
ctx := azdext.WithAccessToken(cmd.Context())

// Validate filename is provided
if filename == "" {
return fmt.Errorf("config file is required, use -f or --file flag")
if filename == "" && (model == "" || trainingFile == "") {
return fmt.Errorf("either config file or model and training-file parameters are required")
}

azdClient, err := azdext.NewAzdClient()
Expand All @@ -85,60 +87,88 @@ func newOperationSubmitCommand() *cobra.Command {
}
defer azdClient.Close()

// Parse and validate the YAML configuration file
color.Green("Parsing configuration file...")
config, err := FTYaml.ParseFineTuningConfig(filename)
if err != nil {
return err
// Show spinner while creating job
spinner := ux.NewSpinner(&ux.SpinnerOptions{
Text: "creating fine-tuning job...",
})
if err := spinner.Start(ctx); err != nil {
fmt.Printf("failed to start spinner: %v\n", err)
}

// Upload training file
// Parse and validate the YAML configuration file if provided
var config *models.CreateFineTuningRequest
if filename != "" {
color.Green("\nparsing configuration file...")
config, err = utils.ParseCreateFineTuningRequestConfig(filename)
if err != nil {
_ = spinner.Stop(ctx)
fmt.Println()
return err
}
} else {
config = &models.CreateFineTuningRequest{}
}

trainingFileID, err := JobWrapper.UploadFileIfLocal(ctx, azdClient, config.TrainingFile)
if err != nil {
return fmt.Errorf("failed to upload training file: %w", err)
// Override config values with command-line parameters if provided
if model != "" {
config.BaseModel = model
}
if trainingFile != "" {

// Upload validation file if provided
var validationFileID string
if config.ValidationFile != "" {
validationFileID, err = JobWrapper.UploadFileIfLocal(ctx, azdClient, config.ValidationFile)
if err != nil {
return fmt.Errorf("failed to upload validation file: %w", err)
}
config.TrainingFile = trainingFile
}
if validationFile != "" {
config.ValidationFile = &validationFile
}
if suffix != "" {
config.Suffix = &suffix
}
if seed != 0 {
config.Seed = &seed
}

// Create fine-tuning job
// Convert YAML configuration to OpenAI job parameters
jobParams, err := ConvertYAMLToJobParams(config, trainingFileID, validationFileID)
fineTuneSvc, err := services.NewFineTuningService(ctx, azdClient, nil)
if err != nil {
return fmt.Errorf("failed to convert configuration to job parameters: %w", err)
_ = spinner.Stop(ctx)
fmt.Println()
return err
}

// Submit the fine-tuning job using CreateJob from JobWrapper
job, err := JobWrapper.CreateJob(ctx, azdClient, jobParams)
job, err := fineTuneSvc.CreateFineTuningJob(ctx, config)
_ = spinner.Stop(ctx)
fmt.Println()

if err != nil {
return err
}

// Print success message
fmt.Println(strings.Repeat("=", 120))
color.Green("\nSuccessfully submitted fine-tuning Job!\n")
fmt.Printf("Job ID: %s\n", job.Id)
fmt.Printf("Model: %s\n", job.Model)
fmt.Println("\n", strings.Repeat("=", 60))
color.Green("\nsuccessfully submitted fine-tuning Job!\n")
fmt.Printf("Job ID: %s\n", job.ID)
fmt.Printf("Model: %s\n", job.BaseModel)
fmt.Printf("Status: %s\n", job.Status)
fmt.Printf("Created: %s\n", job.CreatedAt)
if job.FineTunedModel != "" {
fmt.Printf("Fine-tuned: %s\n", job.FineTunedModel)
}
fmt.Println(strings.Repeat("=", 120))

fmt.Println(strings.Repeat("=", 60))
return nil
},
}

cmd.Flags().StringVarP(&filename, "file", "f", "", "Path to the config file")

cmd.Flags().StringVarP(&filename, "file", "f", "", "Path to the config file.")
cmd.Flags().StringVarP(&model, "model", "m", "", "Base model to fine-tune. Overrides config file. Required if --file is not provided")
cmd.Flags().StringVarP(&trainingFile, "training-file", "t", "", "Training file ID or local path. Use 'local:' prefix for local paths. Required if --file is not provided")
cmd.Flags().StringVarP(&validationFile, "validation-file", "v", "", "Validation file ID or local path. Use 'local:' prefix for local paths.")
cmd.Flags().StringVarP(&suffix, "suffix", "s", "", "An optional string of up to 64 characters that will be added to your fine-tuned model name. Overrides config file.")
cmd.Flags().Int64VarP(&seed, "seed", "r", 0, "Random seed for reproducibility of the job. If a seed is not specified, one will be generated for you. Overrides config file.")

//Either config file should be provided or at least `model` & `training-file` parameters
cmd.MarkFlagFilename("file", "yaml", "yml")
cmd.MarkFlagsOneRequired("file", "model")
cmd.MarkFlagsRequiredTogether("model", "training-file")
return cmd
}

Expand Down
Loading