diff --git a/internal/cli/dataset.go b/internal/cli/dataset.go
index bfb7b8b..6a8b098 100644
--- a/internal/cli/dataset.go
+++ b/internal/cli/dataset.go
@@ -300,6 +300,10 @@ func runDatasetPush(ctx context.Context, out, errOut io.Writer, a runDatasetPush
 	// leaves Prompter nil and skips straight to the flag-only path.
 	if a.Interactive && a.Prompter != nil {
 		if err := runInteractive(a.Printer, a.Prompter, &a, a.CategorySet); err != nil {
+			if errors.Is(err, errInteractiveCancelled) {
+				a.Printer.Infof("Cancelled — nothing was pushed.")
+				return nil
+			}
 			return &exitError{code: 3, err: fmt.Errorf("interactive setup: %w", err)}
 		}
 	}
diff --git a/internal/cli/interactive.go b/internal/cli/interactive.go
index 95357a6..aaff559 100644
--- a/internal/cli/interactive.go
+++ b/internal/cli/interactive.go
@@ -1,9 +1,14 @@
 package cli
 
 import (
+	"errors"
+	"fmt"
 	"os"
+	"strconv"
+	"strings"
 
 	"github.com/AlecAivazis/survey/v2"
+	"github.com/AlecAivazis/survey/v2/terminal"
 	"golang.org/x/term"
 
 	"github.com/tracebloc/cli/internal/push"
@@ -38,6 +43,13 @@ type prompter interface {
 	// Select asks the user to pick one of options; def is the
 	// pre-highlighted choice.
 	Select(label, help string, options []string, def string) (string, error)
+	// Confirm asks a yes/no question; def is the answer on a bare Enter.
+	Confirm(label string, def bool) (bool, error)
 }
+
+// errInteractiveCancelled is returned when the user declines the
+// confirm prompt or hits Ctrl-C. It's control flow, not a failure:
+// runDatasetPush maps it to a clean exit (0) with a "Cancelled" note.
+var errInteractiveCancelled = errors.New("cancelled by user")
 
 // surveyPrompter is the production prompter, backed by
@@ -57,7 +69,7 @@ func (surveyPrompter) Input(label, help, def string, validate func(string) error
 		}))
 	}
 	if err := survey.AskOne(q, &ans, opts...); err != nil {
-		return "", err
+		return "", mapErr(err)
 	}
 	return ans, nil
 }
@@ -66,11 +78,30 @@ func (surveyPrompter) Select(label, help string, options []string, def string) (
 	var ans string
 	q := &survey.Select{Message: label, Help: help, Options: options, Default: def}
 	if err := survey.AskOne(q, &ans); err != nil {
-		return "", err
+		return "", mapErr(err)
+	}
+	return ans, nil
+}
+
+func (surveyPrompter) Confirm(label string, def bool) (bool, error) {
+	ans := def
+	if err := survey.AskOne(&survey.Confirm{Message: label, Default: def}, &ans); err != nil {
+		return false, mapErr(err)
 	}
 	return ans, nil
 }
 
+// mapErr translates survey's Ctrl-C (terminal.InterruptErr) into our
+// errInteractiveCancelled, so the rest of the code never imports survey
+// internals to recognize a cancellation — the prompter seam stays
+// leak-free.
+func mapErr(err error) error {
+	if errors.Is(err, terminal.InterruptErr) {
+		return errInteractiveCancelled
+	}
+	return err
+}
+
 // isInteractiveTTY reports whether we can run a guided prompt flow:
 // both stdin (we read answers) and stdout (we draw prompts) must be a
 // real terminal. Piped input, redirected output, or CI all fail this
@@ -90,6 +121,7 @@ func isInteractiveTTY() bool {
 func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categorySet bool) error {
 	p.PromptHeader("Let's set up your dataset push")
 	p.Hintf("Press Enter to accept a default; Ctrl-C to cancel.")
+	prompted := false
 
 	if a.LocalPath == "" {
 		ans, err := pr.Input("Path to your dataset directory", "e.g. ./my-data", "", nil)
@@ -97,6 +129,7 @@ func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categoryS
 			return err
 		}
 		a.LocalPath = ans
+		prompted = true
 	}
 
 	if !categorySet {
@@ -106,6 +139,7 @@ func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categoryS
 			return err
 		}
 		a.Spec.Category = ans
+		prompted = true
 	}
 
 	if a.Spec.Table == "" {
@@ -116,6 +150,7 @@ func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categoryS
 			return err
 		}
 		a.Spec.Table = ans
+		prompted = true
 	}
 
 	if a.Spec.Intent == "" {
@@ -125,6 +160,7 @@ func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categoryS
 			return err
 		}
 		a.Spec.Intent = ans
+		prompted = true
 	}
 
 	// masked_language_modeling is self-supervised — no label column.
@@ -135,7 +171,151 @@ func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categoryS
 			return err
 		}
 		a.Spec.LabelColumn = ans
+		prompted = true
 	}
 
+	cp, err := promptCategorySpecific(pr, a)
+	if err != nil {
+		return err
+	}
+	prompted = prompted || cp
+
+	// Confirm only when we actually prompted something — a push that's
+	// fully specified by flags (on a TTY) isn't nagged with a confirm.
+	if prompted {
+		renderReview(p, a)
+		ok, err := pr.Confirm("Proceed with the push?", true)
+		if err != nil {
+			return err
+		}
+		if !ok {
+			return errInteractiveCancelled
+		}
+	}
 	return nil
 }
+
+// promptCategorySpecific prompts for the inputs a particular category
+// needs beyond the core fields, filling only the gaps. Returns whether
+// it prompted anything (so the caller knows to show the confirm).
+func promptCategorySpecific(pr prompter, a *runDatasetPushArgs) (bool, error) {
+	cat := a.Spec.Category
+	prompted := false
+	switch {
+	case push.IsImage(cat):
+		if cat == "keypoint_detection" && a.Spec.NumberOfKeypoints <= 0 {
+			ans, err := pr.Input("Number of keypoints per sample",
+				"e.g. 17 for COCO pose", "", validatePositiveInt)
+			if err != nil {
+				return prompted, err
+			}
+			// validatePositiveInt should already have vetted ans, but a
+			// prompter that skips validation must not silently store 0.
+			n, err := strconv.Atoi(strings.TrimSpace(ans))
+			if err != nil || n <= 0 {
+				return prompted, fmt.Errorf("number of keypoints %q: must be a positive integer", ans)
+			}
+			a.Spec.NumberOfKeypoints = n
+			prompted = true
+		}
+		if a.TargetSizeFlag == "" {
+			ans, err := pr.Input("Image resolution as WxH (blank = auto-detect from the first image)",
+				"all images must share it; the ingestor validates, it doesn't resize", "",
+				validateOptionalTargetSize)
+			if err != nil {
+				return prompted, err
+			}
+			a.TargetSizeFlag = strings.TrimSpace(ans)
+			prompted = true
+		}
+	case push.IsTabular(cat):
+		if a.SchemaFlag == "" {
+			ans, err := pr.Input("Column schema as col:TYPE,... (blank = infer from the CSV)",
+				"e.g. age:INT,price:FLOAT", "", validateOptionalSchema)
+			if err != nil {
+				return prompted, err
+			}
+			a.SchemaFlag = strings.TrimSpace(ans)
+			prompted = true
+		}
+		if push.IsRegressionClass(cat) && a.Spec.LabelPolicy == "" {
+			ans, err := pr.Select("Label policy",
+				"bucket bins the target before it leaves the cluster",
+				[]string{"bucket", "passthrough"}, "bucket")
+			if err != nil {
+				return prompted, err
+			}
+			a.Spec.LabelPolicy = ans
+			prompted = true
+		}
+		if cat == "time_to_event_prediction" && a.Spec.TimeColumn == "" {
+			ans, err := pr.Input("Time column", "the duration/time column name", "time", nil)
+			if err != nil {
+				return prompted, err
+			}
+			a.Spec.TimeColumn = strings.TrimSpace(ans)
+			prompted = true
+		}
+	}
+	return prompted, nil
+}
+
+// renderReview prints the assembled push inputs before the confirm
+// prompt, so the user sees exactly what's about to happen.
+func renderReview(p *ui.Printer, a *runDatasetPushArgs) {
+	p.Section("Review")
+	p.Field("path", a.LocalPath)
+	p.Field("category", a.Spec.Category)
+	p.Field("table", a.Spec.Table)
+	p.Field("intent", a.Spec.Intent)
+	if a.Spec.LabelColumn != "" {
+		p.Field("label column", a.Spec.LabelColumn)
+	}
+	if a.Spec.NumberOfKeypoints > 0 {
+		p.Field("keypoints", strconv.Itoa(a.Spec.NumberOfKeypoints))
+	}
+	switch {
+	case a.TargetSizeFlag != "":
+		p.Field("resolution", a.TargetSizeFlag)
+	case push.IsImage(a.Spec.Category): // image push with the size left blank
+		p.Field("resolution", "auto-detect")
+	}
+	switch {
+	case a.SchemaFlag != "":
+		p.Field("schema", a.SchemaFlag)
+	case push.IsTabular(a.Spec.Category): // tabular push with the schema left blank
+		p.Field("schema", "infer from CSV")
+	}
+	if a.Spec.LabelPolicy != "" {
+		p.Field("label policy", a.Spec.LabelPolicy)
+	}
+	if a.Spec.TimeColumn != "" {
+		p.Field("time column", a.Spec.TimeColumn)
+	}
+}
+
+// validatePositiveInt accepts a string that, once surrounding whitespace is trimmed, parses to an int > 0.
+func validatePositiveInt(s string) error {
+	if n, err := strconv.Atoi(strings.TrimSpace(s)); err != nil || n <= 0 {
+		return fmt.Errorf("must be a positive integer")
+	}
+	return nil
+}
+
+// validateOptionalTargetSize accepts "" (auto-detect) or a valid WxH; it checks the trimmed text, which is what the prompt flow stores.
+func validateOptionalTargetSize(s string) error {
+	if s = strings.TrimSpace(s); s == "" {
+		return nil
+	}
+	_, _, err := push.ParseTargetSize(s) // s is already trimmed here
+	return err
+}
+
+// validateOptionalSchema accepts "" (infer) or a valid col:TYPE,... set.
+func validateOptionalSchema(s string) error {
+	if s = strings.TrimSpace(s); s == "" {
+		return nil
+	}
+	_, err := push.ParseSchema(s) // s is trimmed: validate exactly what the prompt flow stores
+	return err
+}
diff --git a/internal/cli/interactive_test.go b/internal/cli/interactive_test.go
index 81deec9..262e950 100644
--- a/internal/cli/interactive_test.go
+++ b/internal/cli/interactive_test.go
@@ -2,6 +2,7 @@ package cli
 
 import (
 	"bytes"
+	"errors"
 	"testing"
 
 	"github.com/tracebloc/cli/internal/push"
@@ -15,6 +16,7 @@ import (
 type fakePrompter struct {
 	answers map[string]string
 	asked   []string
+	confirm *bool // nil → return the prompt's default (true)
 }
 
 func (f *fakePrompter) answer(label, def string) string {
@@ -39,6 +41,13 @@ func (f *fakePrompter) Select(label, _ /*help*/ string, _ []string, def string)
 	return f.answer(label, def), nil
 }
 
+func (f *fakePrompter) Confirm(_ string, def bool) (bool, error) {
+	if f.confirm != nil {
+		return *f.confirm, nil
+	}
+	return def, nil
+}
+
 func discardPrinter() *ui.Printer { return ui.New(&bytes.Buffer{}) }
 
 // TestRunInteractive_FillsAllWhenEmpty: a bare invocation prompts for
@@ -77,10 +86,12 @@ func TestRunInteractive_FillsAllWhenEmpty(t *testing.T) {
 // explicit --category) mean nothing is prompted.
 func TestRunInteractive_SkipsProvidedValues(t *testing.T) {
 	f := &fakePrompter{answers: map[string]string{}}
+	// text_classification has no category-specific prompts, so with all
+	// core fields set + an explicit --category, nothing is asked.
 	a := &runDatasetPushArgs{
 		LocalPath: "./data",
 		Spec: push.SpecArgs{
-			Category: "image_classification", Table: "t", Intent: "train", LabelColumn: "label",
+			Category: "text_classification", Table: "t", Intent: "train", LabelColumn: "label",
 		},
 	}
 	if err := runInteractive(discardPrinter(), f, a, true /*categorySet*/); err != nil {
@@ -91,6 +102,62 @@ func TestRunInteractive_SkipsProvidedValues(t *testing.T) {
 	}
 }
 
+// TestRunInteractive_Keypoint prompts for the required keypoint count;
+// the optional resolution left blank means auto-detect.
+func TestRunInteractive_Keypoint(t *testing.T) {
+	f := &fakePrompter{answers: map[string]string{"Number of keypoints per sample": "17"}}
+	a := &runDatasetPushArgs{
+		LocalPath: "./kp",
+		Spec:      push.SpecArgs{Category: "keypoint_detection", Table: "kp_train", Intent: "train", LabelColumn: "image_label"},
+	}
+	if err := runInteractive(discardPrinter(), f, a, true /*categorySet*/); err != nil {
+		t.Fatalf("runInteractive: %v", err)
+	}
+	if a.Spec.NumberOfKeypoints != 17 {
+		t.Errorf("NumberOfKeypoints = %d, want 17", a.Spec.NumberOfKeypoints)
+	}
+	if a.TargetSizeFlag != "" {
+		t.Errorf("TargetSizeFlag = %q, want empty (auto-detect)", a.TargetSizeFlag)
+	}
+}
+
+// TestRunInteractive_TabularRegression prompts for the label policy
+// (regression-class) and leaves the schema to inference.
+func TestRunInteractive_TabularRegression(t *testing.T) {
+	f := &fakePrompter{answers: map[string]string{"Label policy": "passthrough"}}
+	a := &runDatasetPushArgs{
+		LocalPath: "./tab",
+		Spec:      push.SpecArgs{Category: "tabular_regression", Table: "reg_train", Intent: "train", LabelColumn: "Target"},
+	}
+	if err := runInteractive(discardPrinter(), f, a, true /*categorySet*/); err != nil {
+		t.Fatalf("runInteractive: %v", err)
+	}
+	if a.Spec.LabelPolicy != "passthrough" {
+		t.Errorf("LabelPolicy = %q, want passthrough", a.Spec.LabelPolicy)
+	}
+	if a.SchemaFlag != "" {
+		t.Errorf("SchemaFlag = %q, want empty (infer)", a.SchemaFlag)
+	}
+}
+
+// TestRunInteractive_Cancel: declining the confirm returns the
+// cancellation sentinel — a clean abort, not a failure.
+func TestRunInteractive_Cancel(t *testing.T) {
+	no := false
+	f := &fakePrompter{
+		answers: map[string]string{"Path to your dataset directory": "./x"},
+		confirm: &no,
+	}
+	// path is prompted (→ prompted=true → a confirm is shown); the core fields are
+	// pre-set, and the image-resolution prompt has no scripted answer (presumably the fake's default).
+	a := &runDatasetPushArgs{Spec: push.SpecArgs{
+		Category: "image_classification", Table: "t", Intent: "train", LabelColumn: "label",
+	}}
+	if err := runInteractive(discardPrinter(), f, a, true /*categorySet*/); !errors.Is(err, errInteractiveCancelled) {
+		t.Fatalf("err = %v, want errInteractiveCancelled", err)
+	}
+}
+
 // TestRunInteractive_MLMSkipsLabel: masked_language_modeling has no
 // label column, so it must not be prompted.
 func TestRunInteractive_MLMSkipsLabel(t *testing.T) {