Skip to content

Commit

Permalink
Improve error checks for the "-extra" flag.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674381032
Change-Id: I051ea87047195ac85d73e6cde35191310b20275f
  • Loading branch information
Sax Authors authored and copybara-github committed Sep 13, 2024
1 parent 9603088 commit 1ee6482
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 24 deletions.
1 change: 1 addition & 0 deletions saxml/bin/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ go_library(
"//saxml/common:addr",
"//saxml/common:cell",
"//saxml/common:config",
"//saxml/common:errors",
"//saxml/common:naming",
"//saxml/common:watchable",
"//saxml/common/platform:env",
Expand Down
10 changes: 8 additions & 2 deletions saxml/bin/saxutil_cmd_am.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package saxcommand

import (
"context"
"fmt"
"os"
"strconv"

Expand Down Expand Up @@ -48,7 +49,7 @@ func (*RecognizeCmd) Usage() string {

// SetFlags sets flags for AudioToTextCmd.
func (c *RecognizeCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.extra, "extra", "", "Extra arguments for Recognize().")
f.StringVar(&c.extra, "extra", "", "Extra arguments for Recognize(),"+ExtraInputsHelp)
}

// Execute executes AudioToTextCmd.
Expand All @@ -57,6 +58,11 @@ func (c *RecognizeCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any
log.Errorf("Provide model and audio path.")
return subcommands.ExitUsageError
}
extra, err := ExtraInputs(c.extra)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse extra inputs: %v", err)
return subcommands.ExitFailure
}

m, err := sax.Open(f.Args()[0])
if err != nil {
Expand All @@ -69,7 +75,7 @@ func (c *RecognizeCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any
filePath := f.Args()[1]
var contents []byte
contents = readFile(filePath)
results, err := am.Recognize(ctx, contents, ExtraInputs(c.extra)...)
results, err := am.Recognize(ctx, contents, extra...)
if err != nil {
log.Errorf("Failed to transrbie audio (%s) due to %v", filePath, err)
return subcommands.ExitFailure
Expand Down
14 changes: 10 additions & 4 deletions saxml/bin/saxutil_cmd_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ import (
"flag"
log "github.com/golang/glog"
"saxml/client/go/sax"
"saxml/common/errors"
)

// ExtraInputsHelp is a help message for the "-extra" flag.
const ExtraInputsHelp = `, in format "key0:value0,key1:value1,...". For example "temperature:0.2,nlayers:32".`

var cmdTimeout = flag.Duration(
"sax_timeout",
60*time.Second,
Expand Down Expand Up @@ -92,14 +96,16 @@ func formatFloat(val float64) string {
}

// ExtraInputs creates a list of options setters from a string in the form of "a:0.5,b:1.2,c:'/foo/bar'".
func ExtraInputs(extra string) []sax.ModelOptionSetter {
func ExtraInputs(extra string) ([]sax.ModelOptionSetter, error) {
if extra == "" {
return nil, nil
}
extraFields := strings.Split(extra, ",")
options := []sax.ModelOptionSetter{}
for _, option := range extraFields {
kv := strings.Split(option, ":")
if len(kv) != 2 {
log.V(1).Infof("Cannot get k-v pair by splitting %s with ':'\n", option)
continue
return nil, fmt.Errorf("key-value pair for an extra input must be separated by ':', but found '%s': %w", extra, errors.ErrInvalidArgument)
}

key, val := kv[0], kv[1]
Expand All @@ -122,5 +128,5 @@ func ExtraInputs(extra string) []sax.ModelOptionSetter {
options = append(options, sax.WithExtraInputString(key, val))
}
log.V(1).Infof("options %v", options)
return options
return options, nil
}
47 changes: 37 additions & 10 deletions saxml/bin/saxutil_cmd_lm.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,20 @@ Usage: saxutil lm.generate [-n=<N>] [-stream] [-terse] [-extra=<extra>] <ModelID

// SetFlags sets the flags for GenerateCmd.
func (c *GenerateCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.extra, "extra", "", "extra arguments for Generate().")
f.StringVar(&c.extra, "extra", "", "extra arguments for Generate()"+ExtraInputsHelp)
f.StringVar(&c.proxy, "proxy", "", "SAX Proxy address, e.g., sax.server.lm.lmservice-prod.blade.gslb.googleprod.com")
f.BoolVar(&c.stream, "stream", false, "stream responses")
f.BoolVar(&c.terse, "terse", false, "print generated texts one line per result, descending by score")
f.IntVar(&c.maxOutputs, "n", 0, "maximum number of generated texts to output or zero for all")
}

func (c *GenerateCmd) streamingGenerate(ctx context.Context, query string, lm *sax.LanguageModel) subcommands.ExitStatus {
chanStreamResults := lm.GenerateStream(ctx, query, ExtraInputs(c.extra)...)
extra, err := ExtraInputs(c.extra)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse extra inputs: %v", err)
return subcommands.ExitFailure
}
chanStreamResults := lm.GenerateStream(ctx, query, extra...)
var accumulatedResults []string
var allScores [][]float64
var lastScore [][]float64
Expand Down Expand Up @@ -177,6 +182,12 @@ func (c *GenerateCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any)
query = string(readStdin())
}

extra, err := ExtraInputs(c.extra)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse extra inputs: %v", err)
return subcommands.ExitFailure
}

ctx, cancel := context.WithTimeout(ctx, *cmdTimeout)
defer cancel()

Expand All @@ -194,7 +205,7 @@ func (c *GenerateCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any)
}

// Non-streaming generate.
generates, err := lm.Generate(ctx, query, ExtraInputs(c.extra)...)
generates, err := lm.Generate(ctx, query, extra...)
if err != nil {
log.Errorf("Failed to generate query: %v", err)
return subcommands.ExitFailure
Expand Down Expand Up @@ -246,7 +257,7 @@ func (*ScoreCmd) Usage() string {

// SetFlags sets flags for ScoreCmd.
func (c *ScoreCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.extra, "extra", "", "extra arguments for Score().")
f.StringVar(&c.extra, "extra", "", "extra arguments for Score()"+ExtraInputsHelp)
f.StringVar(&c.proxy, "proxy", "", "SAX Proxy address, e.g., sax.server.lm.lmservice-prod.blade.gslb.googleprod.com")
}

Expand All @@ -272,7 +283,14 @@ func (c *ScoreCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any) su
defer cancel()

suffixes := f.Args()[2:]
logPs, err := lm.Score(ctx, f.Args()[1], suffixes, ExtraInputs(c.extra)...)

extra, err := ExtraInputs(c.extra)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse extra inputs: %v", err)
return subcommands.ExitFailure
}

logPs, err := lm.Score(ctx, f.Args()[1], suffixes, extra...)
if err != nil {
log.Errorf("Failed to score prefix/suffix: %v", err)
return subcommands.ExitFailure
Expand Down Expand Up @@ -312,7 +330,7 @@ func (*EmbedTextCmd) Usage() string {

// SetFlags sets flags for EmbedTextCmd.
func (c *EmbedTextCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.extra, "extra", "", "Extra arguments for Embed().")
f.StringVar(&c.extra, "extra", "", "Extra arguments for Embed()"+ExtraInputsHelp)
f.StringVar(&c.proxy, "proxy", "", "Sax Proxy address, e.g., sax.server.lm.lmservice-prod.blade.gslb.googleprod.com")
}

Expand All @@ -333,7 +351,12 @@ func (c *EmbedTextCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any
defer cancel()
text := f.Args()[1]

results, err := lm.Embed(ctx, text, ExtraInputs(c.extra)...)
extra, err := ExtraInputs(c.extra)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse extra inputs: %v", err)
return subcommands.ExitFailure
}
results, err := lm.Embed(ctx, text, extra...)
if err != nil {
log.Errorf("Failed to embed text (%s) due to %v", text, err)
return subcommands.ExitFailure
Expand Down Expand Up @@ -370,7 +393,7 @@ func (*GradientCmd) Usage() string {

// SetFlags sets flags for ScoreCmd.
func (c *GradientCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.extra, "extra", "", "extra arguments for Gradient().")
f.StringVar(&c.extra, "extra", "", "extra arguments for Gradient()"+ExtraInputsHelp)
f.StringVar(&c.proxy, "proxy", "", "SAX Proxy address, e.g., sax.server.lm.lmservice-prod.blade.gslb.googleprod.com")
}

Expand All @@ -391,11 +414,15 @@ func (c *GradientCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any)
log.Errorf("Failed to create language model: %v", err)
return subcommands.ExitFailure
}

extra, err := ExtraInputs(c.extra)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse extra inputs: %v", err)
return subcommands.ExitFailure
}
ctx, cancel := context.WithTimeout(ctx, *cmdTimeout)
defer cancel()

scores, gradients, err := lm.Gradient(ctx, f.Args()[1], f.Args()[2], ExtraInputs(c.extra)...)
scores, gradients, err := lm.Gradient(ctx, f.Args()[1], f.Args()[2], extra...)
if err != nil {
log.Errorf("Failed to get the gradient of prefix/suffix: %v", err)
return subcommands.ExitFailure
Expand Down
36 changes: 28 additions & 8 deletions saxml/bin/saxutil_cmd_vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (*ClassifyCmd) Usage() string {

// SetFlags sets flags for ClassifyCmd.
func (c *ClassifyCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.extra, "extra", "", "Extra arguments for Classify().")
f.StringVar(&c.extra, "extra", "", "Extra arguments for Classify(),"+ExtraInputsHelp)
}

// Execute executes ClassifyCmd.
Expand All @@ -62,6 +62,11 @@ func (c *ClassifyCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any)
log.Errorf("Provide model and image path for classify")
return subcommands.ExitUsageError
}
extra, err := ExtraInputs(c.extra)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse extra inputs: %v", err)
return subcommands.ExitFailure
}

m, err := sax.Open(f.Args()[0])
if err != nil {
Expand All @@ -79,7 +84,7 @@ func (c *ClassifyCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any)
contents = readFile(imagePath)
}

results, err := vm.Classify(ctx, contents, ExtraInputs(c.extra)...)
results, err := vm.Classify(ctx, contents, extra...)
if err != nil {
log.Errorf("Failed to classify image (%s) due to %v", imagePath, err)
return subcommands.ExitFailure
Expand Down Expand Up @@ -131,7 +136,7 @@ func (*TextToImageCmd) Usage() string {

// SetFlags sets flags for TextToImageCmd.
func (c *TextToImageCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.extra, "extra", "", "Extra arguments for TextToImage().")
f.StringVar(&c.extra, "extra", "", "Extra arguments for TextToImage(),"+ExtraInputsHelp)
}

// Execute executes TextToImageCmd.
Expand All @@ -140,6 +145,11 @@ func (c *TextToImageCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...a
log.Errorf("Provide model ID, text and output directory for text_to_image.")
return subcommands.ExitUsageError
}
extra, err := ExtraInputs(c.extra)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse extra inputs: %v", err)
return subcommands.ExitFailure
}

m, err := sax.Open(f.Args()[0])
if err != nil {
Expand All @@ -151,7 +161,7 @@ func (c *TextToImageCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...a
defer cancel()
text := f.Args()[1]

results, err := vm.TextToImage(ctx, text, ExtraInputs(c.extra)...)
results, err := vm.TextToImage(ctx, text, extra...)
if err != nil {
log.Errorf("Failed to generate images for (%s) due to [%v].", text, err)
return subcommands.ExitFailure
Expand Down Expand Up @@ -214,7 +224,7 @@ func (*EmbedImageCmd) Usage() string {

// SetFlags sets flags for EmbedImageCmd.
func (c *EmbedImageCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.extra, "extra", "", "Extra arguments for Embed().")
f.StringVar(&c.extra, "extra", "", "Extra arguments for Embed(),"+ExtraInputsHelp)
}

// Execute executes EmbedImageCmd.
Expand All @@ -223,6 +233,11 @@ func (c *EmbedImageCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...an
log.Errorf("Provide model and image path for embed")
return subcommands.ExitUsageError
}
extra, err := ExtraInputs(c.extra)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse extra inputs: %v", err)
return subcommands.ExitFailure
}

m, err := sax.Open(f.Args()[0])
if err != nil {
Expand All @@ -240,7 +255,7 @@ func (c *EmbedImageCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...an
contents = readFile(imagePath)
}

results, err := vm.Embed(ctx, contents, ExtraInputs(c.extra)...)
results, err := vm.Embed(ctx, contents, extra...)
if err != nil {
log.Errorf("Failed to embed image (%s) due to %v", imagePath, err)
return subcommands.ExitFailure
Expand Down Expand Up @@ -286,7 +301,7 @@ func (*DetectCmd) Usage() string {

// SetFlags sets flags for DetectCmd.
func (c *DetectCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&c.extra, "extra", "", "Extra arguments for Detect().")
f.StringVar(&c.extra, "extra", "", "Extra arguments for Detect(),"+ExtraInputsHelp)
}

// Execute executes DetectCmd.
Expand All @@ -295,6 +310,11 @@ func (c *DetectCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any) s
log.Errorf("Provide model and image path for detect")
return subcommands.ExitUsageError
}
extra, err := ExtraInputs(c.extra)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse extra inputs: %v", err)
return subcommands.ExitFailure
}

m, err := sax.Open(f.Args()[0])
if err != nil {
Expand All @@ -315,7 +335,7 @@ func (c *DetectCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...any) s
text := f.Args()[2:]
// Placeholder boxes for command line API. Boxes input not supported at the moment.
var boxes = []sax.BoundingBox{}
results, err := vm.Detect(ctx, contents, text, boxes, ExtraInputs(c.extra)...)
results, err := vm.Detect(ctx, contents, text, boxes, extra...)
if err != nil {
log.Errorf("Failed to detect objects in image (%s) due to %v", imagePath, err)
return subcommands.ExitFailure
Expand Down

0 comments on commit 1ee6482

Please sign in to comment.