Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions internal/cli/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ func runDatasetPush(ctx context.Context, out, errOut io.Writer, a runDatasetPush
// leaves Prompter nil and skips straight to the flag-only path.
if a.Interactive && a.Prompter != nil {
if err := runInteractive(a.Printer, a.Prompter, &a, a.CategorySet); err != nil {
if errors.Is(err, errInteractiveCancelled) {
a.Printer.Infof("Cancelled — nothing was pushed.")
return nil
}
return &exitError{code: 3, err: fmt.Errorf("interactive setup: %w", err)}
}
}
Expand Down
179 changes: 177 additions & 2 deletions internal/cli/interactive.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package cli

import (
"errors"
"fmt"
"os"
"strconv"
"strings"

"github.com/AlecAivazis/survey/v2"
"github.com/AlecAivazis/survey/v2/terminal"
"golang.org/x/term"

"github.com/tracebloc/cli/internal/push"
Expand Down Expand Up @@ -31,13 +36,20 @@ var promptCategories = []string{
// returns scripted answers, so the prompt-mapping logic is unit-
// testable without a pseudo-terminal — the same trick kubernetes.Interface
// uses to let cluster code run against a fake clientset.
// errInteractiveCancelled is returned when the user declines the
// confirm prompt or hits Ctrl-C. It's control flow, not a failure:
// runDatasetPush maps it to a clean exit (0) with a "Cancelled" note.
var errInteractiveCancelled = errors.New("cancelled by user")

type prompter interface {
// Input asks for free text. def pre-fills the answer; validate, if
// non-nil, rejects bad input and re-prompts.
Input(label, help, def string, validate func(string) error) (string, error)
// Select asks the user to pick one of options; def is the
// pre-highlighted choice.
Select(label, help string, options []string, def string) (string, error)
// Confirm asks a yes/no question; def is the answer on a bare Enter.
Confirm(label string, def bool) (bool, error)
}

// surveyPrompter is the production prompter, backed by
Expand All @@ -57,7 +69,7 @@ func (surveyPrompter) Input(label, help, def string, validate func(string) error
}))
}
if err := survey.AskOne(q, &ans, opts...); err != nil {
return "", err
return "", mapErr(err)
}
return ans, nil
}
Expand All @@ -66,11 +78,30 @@ func (surveyPrompter) Select(label, help string, options []string, def string) (
var ans string
q := &survey.Select{Message: label, Help: help, Options: options, Default: def}
if err := survey.AskOne(q, &ans); err != nil {
return "", err
return "", mapErr(err)
}
return ans, nil
}

func (surveyPrompter) Confirm(label string, def bool) (bool, error) {
ans := def
if err := survey.AskOne(&survey.Confirm{Message: label, Default: def}, &ans); err != nil {
return false, mapErr(err)
}
return ans, nil
}

// mapErr translates survey's Ctrl-C (terminal.InterruptErr) into our
// errInteractiveCancelled, so the rest of the code never imports survey
// internals to recognize a cancellation — the prompter seam stays
// leak-free.
func mapErr(err error) error {
if errors.Is(err, terminal.InterruptErr) {
return errInteractiveCancelled
}
return err
}

// isInteractiveTTY reports whether we can run a guided prompt flow:
// both stdin (we read answers) and stdout (we draw prompts) must be a
// real terminal. Piped input, redirected output, or CI all fail this
Expand All @@ -90,13 +121,15 @@ func isInteractiveTTY() bool {
func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categorySet bool) error {
p.PromptHeader("Let's set up your dataset push")
p.Hintf("Press Enter to accept a default; Ctrl-C to cancel.")
prompted := false

if a.LocalPath == "" {
ans, err := pr.Input("Path to your dataset directory", "e.g. ./my-data", "", nil)
if err != nil {
return err
}
a.LocalPath = ans
prompted = true
}

if !categorySet {
Expand All @@ -106,6 +139,7 @@ func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categoryS
return err
}
a.Spec.Category = ans
prompted = true
}

if a.Spec.Table == "" {
Expand All @@ -116,6 +150,7 @@ func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categoryS
return err
}
a.Spec.Table = ans
prompted = true
}

if a.Spec.Intent == "" {
Expand All @@ -125,6 +160,7 @@ func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categoryS
return err
}
a.Spec.Intent = ans
prompted = true
}

// masked_language_modeling is self-supervised — no label column.
Expand All @@ -135,7 +171,146 @@ func runInteractive(p *ui.Printer, pr prompter, a *runDatasetPushArgs, categoryS
return err
}
a.Spec.LabelColumn = ans
prompted = true
}

cp, err := promptCategorySpecific(pr, a)
if err != nil {
return err
}
prompted = prompted || cp

// Confirm only when we actually prompted something — a push that's
// fully specified by flags (on a TTY) isn't nagged with a confirm.
if prompted {
renderReview(p, a)
ok, err := pr.Confirm("Proceed with the push?", true)
if err != nil {
return err
}
if !ok {
return errInteractiveCancelled
}
}
return nil
}

// promptCategorySpecific prompts for the inputs a particular category
// needs beyond the core fields, filling only the gaps. Returns whether
// it prompted anything (so the caller knows to show the confirm).
func promptCategorySpecific(pr prompter, a *runDatasetPushArgs) (bool, error) {
cat := a.Spec.Category
prompted := false
switch {
case push.IsImage(cat):
if cat == "keypoint_detection" && a.Spec.NumberOfKeypoints <= 0 {
ans, err := pr.Input("Number of keypoints per sample",
"e.g. 17 for COCO pose", "", validatePositiveInt)
if err != nil {
return prompted, err
}
n, _ := strconv.Atoi(strings.TrimSpace(ans))
a.Spec.NumberOfKeypoints = n
prompted = true
}
if a.TargetSizeFlag == "" {
ans, err := pr.Input("Image resolution as WxH (blank = auto-detect from the first image)",
"all images must share it; the ingestor validates, it doesn't resize", "",
validateOptionalTargetSize)
if err != nil {
return prompted, err
}
a.TargetSizeFlag = strings.TrimSpace(ans)
prompted = true
}
case push.IsTabular(cat):
if a.SchemaFlag == "" {
ans, err := pr.Input("Column schema as col:TYPE,... (blank = infer from the CSV)",
"e.g. age:INT,price:FLOAT", "", validateOptionalSchema)
if err != nil {
return prompted, err
}
a.SchemaFlag = strings.TrimSpace(ans)
prompted = true
}
if push.IsRegressionClass(cat) && a.Spec.LabelPolicy == "" {
ans, err := pr.Select("Label policy",
"bucket bins the target before it leaves the cluster",
[]string{"bucket", "passthrough"}, "bucket")
if err != nil {
return prompted, err
}
a.Spec.LabelPolicy = ans
prompted = true
}
if cat == "time_to_event_prediction" && a.Spec.TimeColumn == "" {
ans, err := pr.Input("Time column", "the duration/time column name", "time", nil)
if err != nil {
return prompted, err
}
a.Spec.TimeColumn = strings.TrimSpace(ans)
prompted = true
}
}
return prompted, nil
}

// renderReview prints the assembled push inputs before the confirm
// prompt, so the user sees exactly what's about to happen.
func renderReview(p *ui.Printer, a *runDatasetPushArgs) {
p.Section("Review")
p.Field("path", a.LocalPath)
p.Field("category", a.Spec.Category)
p.Field("table", a.Spec.Table)
p.Field("intent", a.Spec.Intent)
if a.Spec.LabelColumn != "" {
p.Field("label column", a.Spec.LabelColumn)
}
if a.Spec.NumberOfKeypoints > 0 {
p.Field("keypoints", strconv.Itoa(a.Spec.NumberOfKeypoints))
}
switch {
case a.TargetSizeFlag != "":
p.Field("resolution", a.TargetSizeFlag)
case push.IsImage(a.Spec.Category):
p.Field("resolution", "auto-detect")
}
switch {
case a.SchemaFlag != "":
p.Field("schema", a.SchemaFlag)
case push.IsTabular(a.Spec.Category):
p.Field("schema", "infer from CSV")
}
if a.Spec.LabelPolicy != "" {
p.Field("label policy", a.Spec.LabelPolicy)
}
if a.Spec.TimeColumn != "" {
p.Field("time column", a.Spec.TimeColumn)
}
}

// validatePositiveInt accepts a string that parses to an int > 0.
func validatePositiveInt(s string) error {
if n, err := strconv.Atoi(strings.TrimSpace(s)); err != nil || n <= 0 {
return fmt.Errorf("must be a positive integer")
}
return nil
}

// validateOptionalTargetSize accepts "" (auto-detect) or a valid WxH.
func validateOptionalTargetSize(s string) error {
if strings.TrimSpace(s) == "" {
return nil
}
_, _, err := push.ParseTargetSize(s)
return err
}

// validateOptionalSchema accepts "" (infer) or a valid col:TYPE,... set.
func validateOptionalSchema(s string) error {
if strings.TrimSpace(s) == "" {
return nil
}
_, err := push.ParseSchema(s)
return err
}
69 changes: 68 additions & 1 deletion internal/cli/interactive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cli

import (
"bytes"
"errors"
"testing"

"github.com/tracebloc/cli/internal/push"
Expand All @@ -15,6 +16,7 @@ import (
type fakePrompter struct {
answers map[string]string
asked []string
confirm *bool // nil → return the prompt's default (true)
}

func (f *fakePrompter) answer(label, def string) string {
Expand All @@ -39,6 +41,13 @@ func (f *fakePrompter) Select(label, _ /*help*/ string, _ []string, def string)
return f.answer(label, def), nil
}

func (f *fakePrompter) Confirm(_ string, def bool) (bool, error) {
if f.confirm != nil {
return *f.confirm, nil
}
return def, nil
}

func discardPrinter() *ui.Printer { return ui.New(&bytes.Buffer{}) }

// TestRunInteractive_FillsAllWhenEmpty: a bare invocation prompts for
Expand Down Expand Up @@ -77,10 +86,12 @@ func TestRunInteractive_FillsAllWhenEmpty(t *testing.T) {
// explicit --category) mean nothing is prompted.
func TestRunInteractive_SkipsProvidedValues(t *testing.T) {
f := &fakePrompter{answers: map[string]string{}}
// text_classification has no category-specific prompts, so with all
// core fields set + an explicit --category, nothing is asked.
a := &runDatasetPushArgs{
LocalPath: "./data",
Spec: push.SpecArgs{
Category: "image_classification", Table: "t", Intent: "train", LabelColumn: "label",
Category: "text_classification", Table: "t", Intent: "train", LabelColumn: "label",
},
}
if err := runInteractive(discardPrinter(), f, a, true /*categorySet*/); err != nil {
Expand All @@ -91,6 +102,62 @@ func TestRunInteractive_SkipsProvidedValues(t *testing.T) {
}
}

// TestRunInteractive_Keypoint prompts for the required keypoint count;
// the optional resolution left blank means auto-detect.
func TestRunInteractive_Keypoint(t *testing.T) {
f := &fakePrompter{answers: map[string]string{"Number of keypoints per sample": "17"}}
a := &runDatasetPushArgs{
LocalPath: "./kp",
Spec: push.SpecArgs{Category: "keypoint_detection", Table: "kp_train", Intent: "train", LabelColumn: "image_label"},
}
if err := runInteractive(discardPrinter(), f, a, true); err != nil {
t.Fatalf("runInteractive: %v", err)
}
if a.Spec.NumberOfKeypoints != 17 {
t.Errorf("NumberOfKeypoints = %d, want 17", a.Spec.NumberOfKeypoints)
}
if a.TargetSizeFlag != "" {
t.Errorf("TargetSizeFlag = %q, want empty (auto-detect)", a.TargetSizeFlag)
}
}

// TestRunInteractive_TabularRegression prompts for the label policy
// (regression-class) and leaves the schema to inference.
func TestRunInteractive_TabularRegression(t *testing.T) {
f := &fakePrompter{answers: map[string]string{"Label policy": "passthrough"}}
a := &runDatasetPushArgs{
LocalPath: "./tab",
Spec: push.SpecArgs{Category: "tabular_regression", Table: "reg_train", Intent: "train", LabelColumn: "Target"},
}
if err := runInteractive(discardPrinter(), f, a, true); err != nil {
t.Fatalf("runInteractive: %v", err)
}
if a.Spec.LabelPolicy != "passthrough" {
t.Errorf("LabelPolicy = %q, want passthrough", a.Spec.LabelPolicy)
}
if a.SchemaFlag != "" {
t.Errorf("SchemaFlag = %q, want empty (infer)", a.SchemaFlag)
}
}

// TestRunInteractive_Cancel: declining the confirm returns the
// cancellation sentinel — a clean abort, not a failure.
func TestRunInteractive_Cancel(t *testing.T) {
no := false
f := &fakePrompter{
answers: map[string]string{"Path to your dataset directory": "./x"},
confirm: &no,
}
// path is prompted (→ prompted=true → a confirm is shown); the rest
// is pre-set so we reach the confirm cleanly.
a := &runDatasetPushArgs{Spec: push.SpecArgs{
Category: "image_classification", Table: "t", Intent: "train", LabelColumn: "label",
}}
if err := runInteractive(discardPrinter(), f, a, true); !errors.Is(err, errInteractiveCancelled) {
t.Fatalf("err = %v, want errInteractiveCancelled", err)
}
}

// TestRunInteractive_MLMSkipsLabel: masked_language_modeling has no
// label column, so it must not be prompted.
func TestRunInteractive_MLMSkipsLabel(t *testing.T) {
Expand Down
Loading