From 78ea28a612e26534643b069b475abd30f075d396 Mon Sep 17 00:00:00 2001 From: hkadakia Date: Tue, 7 Dec 2021 15:04:59 -0800 Subject: [PATCH 1/2] cleanup code & add unit test --- go.mod | 1 + go.sum | 2 + go/v1beta1/storage/pgsqlstore.go | 101 ++--- .../storage/pgsqlstore_functional_test.go | 335 +++++++++++++++ go/v1beta1/storage/pgsqlstore_test.go | 392 +++++------------- go/v1beta1/storage/queries.go | 16 +- 6 files changed, 495 insertions(+), 352 deletions(-) create mode 100644 go/v1beta1/storage/pgsqlstore_functional_test.go diff --git a/go.mod b/go.mod index 7ad9ea3..fcbdb39 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/grafeas/grafeas-pgsql go 1.17 require ( + github.com/DATA-DOG/go-sqlmock v1.4.1 github.com/fernet/fernet-go v0.0.0-20191111064656-eff2850e6001 github.com/golang/protobuf v1.4.2 github.com/google/uuid v1.1.1 diff --git a/go.sum b/go.sum index 7bb6e82..cdf88cf 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT github.com/Bowery/prompt v0.0.0-20190916142128-fa8279994f75/go.mod h1:4/6eNcqZ09BZ9wLK3tZOjBA1nDj+B0728nlX5YRlSmQ= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.4.1 h1:ThlnYciV1iM/V0OSF/dtkqWb6xo5qITT1TJBG1MRDJM= +github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= diff --git a/go/v1beta1/storage/pgsqlstore.go b/go/v1beta1/storage/pgsqlstore.go index 862b655..5a28f94 100644 --- a/go/v1beta1/storage/pgsqlstore.go +++ b/go/v1beta1/storage/pgsqlstore.go @@ -44,6 +44,7 @@ import ( // config.ConvertGenericConfigToSpecificType internally uses json package. type Config struct { Host string `json:"host"` + Port int `json:"port"` // DBName has to alrady exist and can be accessed by User. DBName string `json:"db_name"` User string `json:"user"` @@ -206,21 +207,14 @@ func (pg *PgSQLStore) GetProject(ctx context.Context, pID string) (*prpb.Project // ListProjects returns up to pageSize number of projects beginning at pageToken (or from // start if pageToken is the empty string). func (pg *PgSQLStore) ListProjects(ctx context.Context, filter string, pageSize int, pageToken string) ([]*prpb.Project, string, error) { - count, err := pg.count(ctx, projectCount) - if err != nil { - return nil, "", status.Error(codes.Internal, "Failed to count Projects from database") - } - var filterQuery string if filter != "" { var fs FilterSQL filterQuery = " AND " + fs.ParseFilter(filter) } - - query := fmt.Sprint(listProjects, filterQuery) - var rows *sql.Rows + query := fmt.Sprintf(listProjects, filterQuery) id := decryptInt64(pageToken, pg.paginationKey, 0) - rows, err = pg.DB.QueryContext(ctx, query, id, pageSize) + rows, err := pg.DB.QueryContext(ctx, query, id, pageSize) if err != nil { return nil, "", status.Error(codes.Internal, "Failed to list Projects from database") } @@ -234,7 +228,18 @@ func (pg *PgSQLStore) ListProjects(ctx context.Context, filter string, pageSize } projects = append(projects, &prpb.Project{Name: name}) } - if count == lastID { + if len(projects) == 0 { + return projects, "", nil + } + maxQuery := projectsMaxID + if filterQuery != "" { + maxQuery = fmt.Sprintf("%s WHERE %s", maxQuery, filterQuery) + } + maxID, err := pg.max(ctx, maxQuery) + if err != nil { + return nil, "", status.Error(codes.Internal, "Failed to query max project id from database") + } + if lastID >= maxID { return projects, "", nil } encryptedPage, err := encryptInt64(lastID, pg.paginationKey) @@ -355,21 +360,15 @@ func (pg *PgSQLStore) GetOccurrence(ctx context.Context, pID, oID string) (*pb.O // ListOccurrences returns up to pageSize number of occurrences for this project beginning // at pageToken, or from start if pageToken is the empty string. func (pg *PgSQLStore) ListOccurrences(ctx context.Context, pID, filter, pageToken string, pageSize int32) ([]*pb.Occurrence, string, error) { - count, err := pg.count(ctx, occurrenceCount, pID) - if err != nil { - return nil, "", status.Error(codes.Internal, "Failed to count Occurrences from database") - } - var filterQuery string if filter != "" { var fs FilterSQL filterQuery = " AND " + fs.ParseFilter(filter) } - query := fmt.Sprint(listOccurrences, filterQuery) - var rows *sql.Rows + query := fmt.Sprintf(listOccurrences, filterQuery) id := decryptInt64(pageToken, pg.paginationKey, 0) - rows, err = pg.DB.QueryContext(ctx, query, pID, id, pageSize) + rows, err := pg.DB.QueryContext(ctx, query, pID, id, pageSize) if err != nil { return nil, "", status.Error(codes.Internal, "Failed to list Occurrences from database") } @@ -388,12 +387,20 @@ func (pg *PgSQLStore) ListOccurrences(ctx context.Context, pID, filter, pageToke } os = append(os, &o) } - if count == lastID { + if len(os) == 0 { + return os, "", nil + } + maxQuery := fmt.Sprintf(occurrenceMaxID, filterQuery) + maxID, err := pg.max(ctx, maxQuery, pID) + if err != nil { + return nil, "", status.Error(codes.Internal, "Failed to query max occurrence id from database") + } + if lastID >= maxID { return os, "", nil } encryptedPage, err := encryptInt64(lastID, pg.paginationKey) if err != nil { - return nil, "", status.Error(codes.Internal, "Failed to paginate projects") + return nil, "", status.Error(codes.Internal, "Failed to paginate occurrences") } return os, encryptedPage, nil } @@ -435,9 +442,7 @@ func (pg *PgSQLStore) BatchCreateNotes(ctx context.Context, pID, uID string, not } else { created = append(created, note) } - } - return created, errs } @@ -521,18 +526,13 @@ func (pg *PgSQLStore) GetOccurrenceNote(ctx context.Context, pID, oID string) (* // ListNotes returns up to pageSize number of notes for this project (pID) beginning // at pageToken (or from start if pageToken is the empty string). func (pg *PgSQLStore) ListNotes(ctx context.Context, pID, filter, pageToken string, pageSize int32) ([]*pb.Note, string, error) { - count, err := pg.count(ctx, noteCount, pID) - if err != nil { - return nil, "", status.Error(codes.Internal, "Failed to count Notes from database") - } - var filterQuery string if filter != "" { var fs FilterSQL filterQuery = " AND " + fs.ParseFilter(filter) } - query := fmt.Sprint(listNotes, filterQuery) + query := fmt.Sprintf(listNotes, filterQuery) id := decryptInt64(pageToken, pg.paginationKey, 0) rows, err := pg.DB.QueryContext(ctx, query, pID, id, pageSize) if err != nil { @@ -553,27 +553,33 @@ func (pg *PgSQLStore) ListNotes(ctx context.Context, pID, filter, pageToken stri } ns = append(ns, &n) } - if count == lastID { + if len(ns) == 0 { + return ns, "", nil + } + maxQuery := fmt.Sprintf(notesMaxID, filterQuery) + maxID, err := pg.max(ctx, maxQuery, pID) + if err != nil { + return nil, "", status.Error(codes.Internal, "Failed to query max note id from database") + } + if lastID >= maxID { return ns, "", nil } encryptedPage, err := encryptInt64(lastID, pg.paginationKey) if err != nil { - return nil, "", status.Error(codes.Internal, "Failed to paginate projects") + return nil, "", status.Error(codes.Internal, "Failed to paginate notes") } return ns, encryptedPage, nil } // ListNoteOccurrences returns up to pageSize number of occcurrences on the particular note (nID) // for this project (pID) projects beginning at pageToken (or from start if pageToken is the empty string). +// TODO: implement query filter for NoteOccurrences. +// ListNoteOccurrences is not used by grafeas-client currently. func (pg *PgSQLStore) ListNoteOccurrences(ctx context.Context, pID, nID, filter, pageToken string, pageSize int32) ([]*pb.Occurrence, string, error) { // Verify that note exists if _, err := pg.GetNote(ctx, pID, nID); err != nil { return nil, "", err } - count, err := pg.count(ctx, noteOccurrencesCount, pID, nID) - if err != nil { - return nil, "", status.Error(codes.Internal, "Failed to count Occurrences from database") - } id := decryptInt64(pageToken, pg.paginationKey, 0) rows, err := pg.DB.QueryContext(ctx, listNoteOccurrences, pID, nID, id, pageSize) if err != nil { @@ -594,12 +600,19 @@ func (pg *PgSQLStore) ListNoteOccurrences(ctx context.Context, pID, nID, filter, } os = append(os, &o) } - if count == lastID { + if len(os) == 0 { + return os, "", nil + } + maxID, err := pg.max(ctx, NoteOccurrencesMaxID, pID, nID) + if err != nil { + return nil, "", status.Error(codes.Internal, "Failed to query max NoteOccurrences from database") + } + if lastID >= maxID { return os, "", nil } encryptedPage, err := encryptInt64(lastID, pg.paginationKey) if err != nil { - return nil, "", status.Error(codes.Internal, "Failed to paginate projects") + return nil, "", status.Error(codes.Internal, "Failed to paginate note occurrences") } return os, encryptedPage, nil } @@ -609,16 +622,8 @@ func (pg *PgSQLStore) GetVulnerabilityOccurrencesSummary(ctx context.Context, pr return &pb.VulnerabilityOccurrencesSummary{}, nil } -// CreateSourceString generates DB source path. -func CreateSourceString(user, password, host, dbName, SSLMode string) string { - if user == "" { - return fmt.Sprintf("postgres://%s/%s?sslmode=%s", host, dbName, SSLMode) - } - return fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=%s", user, password, host, dbName, SSLMode) -} - -// count returns the total number of entries for the specified query (assuming SELECT(*) is used) -func (pg *PgSQLStore) count(ctx context.Context, query string, args ...interface{}) (int64, error) { +// max returns the max ID of entries for the specified query (assuming SELECT(*) is used) +func (pg *PgSQLStore) max(ctx context.Context, query string, args ...interface{}) (int64, error) { row := pg.DB.QueryRowContext(ctx, query, args...) var count int64 err := row.Scan(&count) @@ -628,7 +633,7 @@ func (pg *PgSQLStore) count(ctx context.Context, query string, args ...interface return count, err } -// Encrypt int64 using provided key +// encryptInt64 encrypts int64 using provided key func encryptInt64(v int64, key string) (string, error) { k, err := fernet.DecodeKey(key) if err != nil { @@ -641,7 +646,7 @@ func encryptInt64(v int64, key string) (string, error) { return string(bytes), nil } -// Decrypts encrypted int64 using provided key. Returns defaultValue if decryption fails. +// decryptInt64 decrypts encrypted int64 using provided key. Returns defaultValue if decryption fails. func decryptInt64(encrypted string, key string, defaultValue int64) int64 { k, err := fernet.DecodeKey(key) if err != nil { diff --git a/go/v1beta1/storage/pgsqlstore_functional_test.go b/go/v1beta1/storage/pgsqlstore_functional_test.go new file mode 100644 index 0000000..4f614a0 --- /dev/null +++ b/go/v1beta1/storage/pgsqlstore_functional_test.go @@ -0,0 +1,335 @@ +// Copyright 2019 The Grafeas Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build functional +// +build functional + +// Following tests are forked from https://github.com/grafeas/grafeas/blob/master/go/v1beta1/storage/pgsqlstore_test.go, +// and they require a mock postgres instance in the environment. +// We won't run these tests in SD for now. +// TODO: migrate to go-dog func tests. + +package storage + +import ( + "database/sql" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "testing" + + "github.com/grafeas/grafeas/go/config" + grafeas "github.com/grafeas/grafeas/go/v1beta1/api" + "github.com/grafeas/grafeas/go/v1beta1/project" + "github.com/grafeas/grafeas/go/v1beta1/storage" +) + +type testPgHelper struct { + pgDataPath string + pgBinPath string + startedPg bool + pgConfig *config.PgSQLConfig +} + +var ( + //Unfortunately, not a good way to pass this information around to tests except via a globally scoped var + pgsqlstoreTestPgConfig *testPgHelper +) + +func startupPostgres(pgData *testPgHelper) error { + //Create a test database instance directory + if pgDataPath, err := ioutil.TempDir("", "pg-data-*"); err != nil { + return err + } else { + pgData.pgDataPath = filepath.ToSlash(pgDataPath) + } + + //Make password file + passwordTempFile, err := ioutil.TempFile("", "pgpassword-*") + if err != nil { + return err + } + defer os.Remove(passwordTempFile.Name()) + + if _, err = io.WriteString(passwordTempFile, pgData.pgConfig.Password); err != nil { + return err + } + + if err := passwordTempFile.Sync(); err != nil { + return err + } + + port, err := findAvailablePort() + if err != nil { + return err + } + pgData.pgConfig.Host = fmt.Sprintf("127.0.0.1:%d", port) + + //Init db + pgCtl := filepath.Join(pgData.pgBinPath, "pg_ctl") + fmt.Fprintln(os.Stderr, "testing: intializing test postgres instance under", pgData.pgDataPath) + pgCtlInitDBOptions := fmt.Sprintf("--username %s --pwfile %s", pgData.pgConfig.User, passwordTempFile.Name()) + cmd := exec.Command(pgCtl, "--pgdata", pgData.pgDataPath, "-o", pgCtlInitDBOptions, "initdb") + if err := cmd.Run(); err != nil { + return err + } + + //Start postgres + fmt.Fprintln(os.Stderr, "testing: starting test postgres instance on port", port) + pgCtlStartOptions := fmt.Sprintf("-p %d", port) + cmd = exec.Command(pgCtl, "--pgdata", pgData.pgDataPath, "-o", pgCtlStartOptions, "start") + if err := cmd.Run(); err != nil { + return err + } + + pgData.startedPg = true + + return nil +} + +func findAvailablePort() (availablePort int, err error) { + for availablePort = 5432; availablePort < 6000; availablePort++ { + l, err := net.Listen("tcp", fmt.Sprintf(":%d", availablePort)) + defer l.Close() + if err == nil { + return availablePort, nil + } + } + + return -1, fmt.Errorf("Unable to find an open port") +} + +func isPostgresRunning(config *config.PgSQLConfig) bool { + source := storage.CreateSourceString(config.User, config.Password, config.Host, "postgres", config.SSLMode) + db, err := sql.Open("postgres", source) + if err != nil { + return false + } + defer db.Close() + + if db.Ping() != nil { + return false + } + return true +} + +func getPostgresBinPathFromSystemPath() (binPath string, err error) { + cmd := exec.Command("which", "pg_ctl") + output, err := cmd.Output() + if output != nil && err == nil { + binPath = filepath.ToSlash(filepath.Dir(string(output))) + } + + //Deal with "which" Linux-style output on Windows, a bit of a corner case + regex := regexp.MustCompile("^/([a-z])/(.*)$") + regexMatches := regex.FindStringSubmatch(binPath) + if runtime.GOOS == "windows" && regexMatches != nil && len(regexMatches) == 3 { + binPath = fmt.Sprintf("%s:/%s", regexMatches[1], regexMatches[2]) + } + + return +} + +func setup() (pgData *testPgHelper, err error) { + pgConfig := &config.PgSQLConfig{ + Host: "127.0.0.1:5432", + User: "postgres", + Password: "password", + SSLMode: "disable", + } + + pgData = &testPgHelper{ + startedPg: false, + pgConfig: pgConfig, + } + + //See if postgres is already available and running + if isPostgresRunning(pgConfig) { + return + } + + //Check for a global installation + if pgData.pgBinPath, err = getPostgresBinPathFromSystemPath(); err != nil { + err = fmt.Errorf("Unable to find a running Postgres instance or Postgres binaries necessary for testing on the system PATH: %v", err) + return + } + + //Startup postgres + if err = startupPostgres(pgData); err != nil { + return + } + + return pgData, nil +} + +func stopPostgres(pgData *testPgHelper) error { + if pgData != nil && pgData.startedPg { + //Stop postgres + pgCtl := filepath.Join(pgData.pgBinPath, "pg_ctl") + + fmt.Fprintln(os.Stderr, "testing: stopping test postgres instance") + cmd := exec.Command(pgCtl, "--pgdata", pgData.pgDataPath, "stop") + if err := cmd.Run(); err != nil { + return err + } + + //Cleanup + if err := os.RemoveAll(pgData.pgDataPath); err != nil { + return err + } + } + + return nil +} + +func teardown(pgData *testPgHelper) error { + return stopPostgres(pgData) +} + +func dropDatabase(t *testing.T, config *config.PgSQLConfig) { + t.Helper() + // Open database + source := storage.CreateSourceString(config.User, config.Password, config.Host, "postgres", config.SSLMode) + db, err := sql.Open("postgres", source) + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + // Kill opened connection + if _, err := db.Exec(` + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE datname = $1`, config.DbName); err != nil { + t.Fatalf("Failed to drop database: %v", err) + } + // Drop database + if _, err := db.Exec("DROP DATABASE " + config.DbName); err != nil { + t.Fatalf("Failed to drop database: %v", err) + } +} + +func TestMain(m *testing.M) { + var err error + pgsqlstoreTestPgConfig, err = setup() + if err != nil { + log.Fatal(err) + } + + exitVal := m.Run() + + if err := teardown(pgsqlstoreTestPgConfig); err != nil { + log.Fatal(err) + } + + // os.Exit() does not respect defer statements + os.Exit(exitVal) +} + +func TestBetaPgSQLStore(t *testing.T) { + createPgSQLStore := func(t *testing.T) (grafeas.Storage, project.Storage, func()) { + t.Helper() + config := &config.PgSQLConfig{ + Host: pgsqlstoreTestPgConfig.pgConfig.Host, + DbName: "test_db", + User: pgsqlstoreTestPgConfig.pgConfig.User, + Password: pgsqlstoreTestPgConfig.pgConfig.Password, + SSLMode: pgsqlstoreTestPgConfig.pgConfig.SSLMode, + PaginationKey: "XxoPtCUzrUv4JV5dS+yQ+MdW7yLEJnRMwigVY/bpgtQ=", + } + pg, err := storage.NewPgSQLStore(config) + if err != nil { + t.Errorf("Error creating PgSQLStore, %s", err) + } + var g grafeas.Storage = pg + var gp project.Storage = pg + return g, gp, func() { dropDatabase(t, config); pg.Close() } + } + + storage.DoTestStorage(t, createPgSQLStore) +} + +func TestPgSQLStoreWithUserAsEnv(t *testing.T) { + createPgSQLStore := func(t *testing.T) (grafeas.Storage, project.Storage, func()) { + t.Helper() + config := &config.PgSQLConfig{ + Host: pgsqlstoreTestPgConfig.pgConfig.Host, + DbName: "test_db", + User: "", + Password: "", + SSLMode: pgsqlstoreTestPgConfig.pgConfig.SSLMode, + PaginationKey: "XxoPtCUzrUv4JV5dS+yQ+MdW7yLEJnRMwigVY/bpgtQ=", + } + _ = os.Setenv("PGUSER", pgsqlstoreTestPgConfig.pgConfig.User) + _ = os.Setenv("PGPASSWORD", pgsqlstoreTestPgConfig.pgConfig.Password) + pg, err := storage.NewPgSQLStore(config) + if err != nil { + t.Errorf("Error creating PgSQLStore, %s", err) + } + var g grafeas.Storage = pg + var gp project.Storage = pg + return g, gp, func() { dropDatabase(t, config); pg.Close() } + } + + storage.DoTestStorage(t, createPgSQLStore) +} + +func TestBetaPgSQLStoreWithNoPaginationKey(t *testing.T) { + createPgSQLStore := func(t *testing.T) (grafeas.Storage, project.Storage, func()) { + t.Helper() + config := &config.PgSQLConfig{ + Host: pgsqlstoreTestPgConfig.pgConfig.Host, + DbName: "test_db", + User: pgsqlstoreTestPgConfig.pgConfig.User, + Password: pgsqlstoreTestPgConfig.pgConfig.Password, + SSLMode: pgsqlstoreTestPgConfig.pgConfig.SSLMode, + PaginationKey: "", + } + pg, err := storage.NewPgSQLStore(config) + if err != nil { + t.Errorf("Error creating PgSQLStore, %s", err) + } + var g grafeas.Storage = pg + var gp project.Storage = pg + return g, gp, func() { dropDatabase(t, config); pg.Close() } + } + + storage.DoTestStorage(t, createPgSQLStore) +} + +func TestBetaPgSQLStoreWithInvalidPaginationKey(t *testing.T) { + config := &config.PgSQLConfig{ + Host: pgsqlstoreTestPgConfig.pgConfig.Host, + DbName: "test_db", + User: pgsqlstoreTestPgConfig.pgConfig.User, + Password: pgsqlstoreTestPgConfig.pgConfig.Password, + SSLMode: pgsqlstoreTestPgConfig.pgConfig.SSLMode, + PaginationKey: "INVALID_VALUE", + } + pg, err := storage.NewPgSQLStore(config) + if pg != nil { + pg.Close() + } + if err == nil { + t.Errorf("expected error for invalid pagination key; got none") + } + if err.Error() != "invalid pagination key; must be 256-bit URL-safe base64" { + t.Errorf("expected error message about invalid pagination key; got: %s", err.Error()) + } +} diff --git a/go/v1beta1/storage/pgsqlstore_test.go b/go/v1beta1/storage/pgsqlstore_test.go index 21559a3..910dedf 100644 --- a/go/v1beta1/storage/pgsqlstore_test.go +++ b/go/v1beta1/storage/pgsqlstore_test.go @@ -1,22 +1,15 @@ package storage import ( - "database/sql" "fmt" - "io" - "io/ioutil" - "log" - "net" - "os" - "os/exec" - "path/filepath" - "regexp" - "runtime" + "reflect" "testing" + "time" - grafeas "github.com/grafeas/grafeas/go/v1beta1/api" - "github.com/grafeas/grafeas/go/v1beta1/project" - "github.com/grafeas/grafeas/go/v1beta1/storage" + "github.com/DATA-DOG/go-sqlmock" + "github.com/grafeas/grafeas/go/name" + prpb "github.com/grafeas/grafeas/proto/v1beta1/project_go_proto" + "golang.org/x/net/context" ) const ( @@ -25,294 +18,99 @@ const ( paginationKey = "nQi0NzMjerFtlMnbylnWzMrIlNCsuyzeq8LnBEkgxrk=" // go get -v github.com/fernet/fernet-go/cmd/fernet-keygen ; fernet-keygen ) -type testPgHelper struct { - pgDataPath string - pgBinPath string - startedPg bool - pgConfig *Config -} - -var ( - //Unfortunately, not a good way to pass this information around to tests except via a globally scoped var - pgsqlstoreTestPgConfig *testPgHelper -) - -func startupPostgres(pgData *testPgHelper) error { - //Create a test database instance directory - if pgDataPath, err := ioutil.TempDir("", "pg-data-*"); err != nil { - return err - } else { - pgData.pgDataPath = filepath.ToSlash(pgDataPath) - } - - //Make password file - passwordTempFile, err := ioutil.TempFile("", "pgpassword-*") - if err != nil { - return err - } - defer os.Remove(passwordTempFile.Name()) - - if _, err = io.WriteString(passwordTempFile, pgData.pgConfig.Password); err != nil { - return err - } - - if err := passwordTempFile.Sync(); err != nil { - return err - } - - port, err := findAvailablePort() - if err != nil { - return err - } - pgData.pgConfig.Host = fmt.Sprintf("127.0.0.1:%d", port) - - //Init db - pgCtl := filepath.Join(pgData.pgBinPath, "pg_ctl") - fmt.Fprintln(os.Stderr, "testing: initializing test postgres instance under", pgData.pgDataPath) - pgCtlInitDBOptions := fmt.Sprintf("--username %s --pwfile %s", pgData.pgConfig.User, passwordTempFile.Name()) - cmd := exec.Command(pgCtl, "--pgdata", pgData.pgDataPath, "-o", pgCtlInitDBOptions, "initdb") - if err := cmd.Run(); err != nil { - return err - } - - //Start postgres - fmt.Fprintln(os.Stderr, "testing: starting test postgres instance on port", port) - pgCtlStartOptions := fmt.Sprintf("-p %d", port) - cmd = exec.Command(pgCtl, "--pgdata", pgData.pgDataPath, "-o", pgCtlStartOptions, "start") - if err := cmd.Run(); err != nil { - return err - } - - pgData.startedPg = true - - return nil -} - -func findAvailablePort() (availablePort int, err error) { - for availablePort = 5432; availablePort < 6000; availablePort++ { - l, err := net.Listen("tcp", fmt.Sprintf(":%d", availablePort)) - l.Close() - if err == nil { - return availablePort, nil - } - } - - return -1, fmt.Errorf("unable to find an open port") -} - -func isPostgresRunning(config *Config) bool { - source := CreateSourceString(config.User, config.Password, config.Host, "postgres", config.SSLMode) - db, err := sql.Open("postgres", source) - if err != nil { - return false - } - defer db.Close() - - if db.Ping() != nil { - return false - } - return true -} - -func getPostgresBinPathFromSystemPath() (binPath string, err error) { - cmd := exec.Command("which", "pg_ctl") - output, err := cmd.Output() - if output != nil && err == nil { - binPath = filepath.ToSlash(filepath.Dir(string(output))) - } - - //Deal with "which" Linux-style output on Windows, a bit of a corner case - regex := regexp.MustCompile("^/([a-z])/(.*)$") - regexMatches := regex.FindStringSubmatch(binPath) - if runtime.GOOS == "windows" && regexMatches != nil && len(regexMatches) == 3 { - binPath = fmt.Sprintf("%s:/%s", regexMatches[1], regexMatches[2]) - } - - return -} - -func setup() (pgData *testPgHelper, err error) { - pgConfig := &Config{ - Host: "127.0.0.1:5432", - User: "postgres", - Password: "password", - SSLMode: "disable", - } - - pgData = &testPgHelper{ - startedPg: false, - pgConfig: pgConfig, - } - - //See if postgres is already available and running - if isPostgresRunning(pgConfig) { - return - } - - //Check for a global installation - if pgData.pgBinPath, err = getPostgresBinPathFromSystemPath(); err != nil { - err = fmt.Errorf("unable to find a running Postgres instance or Postgres binaries necessary for testing on the system PATH: %v", err) - return - } - - //Startup postgres - if err = startupPostgres(pgData); err != nil { - return - } - - return pgData, nil -} - -func stopPostgres(pgData *testPgHelper) error { - if pgData != nil && pgData.startedPg { - //Stop postgres - pgCtl := filepath.Join(pgData.pgBinPath, "pg_ctl") - - fmt.Fprintln(os.Stderr, "testing: stopping test postgres instance") - cmd := exec.Command(pgCtl, "--pgdata", pgData.pgDataPath, "stop") - if err := cmd.Run(); err != nil { - return err - } - - //Cleanup - if err := os.RemoveAll(pgData.pgDataPath); err != nil { - return err +func genTestDataProjects() ([]*prpb.Project, []string, error) { + var prjs []*prpb.Project + var prjsData []string + for i := 1; i <= 5; i++ { + s := name.FormatProject(fmt.Sprintf("projects/p%d", i)) + p := &prpb.Project{ + Name: s, } + prjs = append(prjs, p) + prjsData = append(prjsData, string(s)) } - - return nil -} - -func teardown(pgData *testPgHelper) error { - return stopPostgres(pgData) + return prjs, prjsData, nil } -func dropDatabase(t *testing.T, config *Config) { - t.Helper() - // Open database - source := CreateSourceString(config.User, config.Password, config.Host, "postgres", config.SSLMode) - db, err := sql.Open("postgres", source) +func TestStore_ListProjects(t *testing.T) { + projects, projectsData, err := genTestDataProjects() if err != nil { - t.Fatalf("Failed to open database: %v", err) - } - // Kill opened connection - if _, err := db.Exec(` - SELECT pg_terminate_backend(pid) - FROM pg_stat_activity - WHERE datname = $1`, config.DBName); err != nil { - t.Fatalf("Failed to drop database: %v", err) - } - // Drop database - if _, err := db.Exec("DROP DATABASE " + config.DBName); err != nil { - t.Fatalf("Failed to drop database: %v", err) - } -} - -func TestMain(m *testing.M) { - var err error - pgsqlstoreTestPgConfig, err = setup() - if err != nil { - log.Fatal(err) - } - - exitVal := m.Run() - - if err := teardown(pgsqlstoreTestPgConfig); err != nil { - log.Fatal(err) - } - - // os.Exit() does not respect defer statements - os.Exit(exitVal) -} - -func TestBetaPgSQLStore(t *testing.T) { - createPgSQLStore := func(t *testing.T) (grafeas.Storage, project.Storage, func()) { - t.Helper() - config := &Config{ - Host: pgsqlstoreTestPgConfig.pgConfig.Host, - DBName: "test_db", - User: pgsqlstoreTestPgConfig.pgConfig.User, - Password: pgsqlstoreTestPgConfig.pgConfig.Password, - SSLMode: pgsqlstoreTestPgConfig.pgConfig.SSLMode, - PaginationKey: paginationKey, - } - pg, err := NewPgSQLStore(config) - if err != nil { - t.Errorf("Error creating PgSQLStore, %s", err) - } - var g grafeas.Storage = pg - var gp project.Storage = pg - return g, gp, func() { dropDatabase(t, config); pg.Close() } - } - - storage.DoTestStorage(t, createPgSQLStore) -} - -func TestPgSQLStoreWithUserAsEnv(t *testing.T) { - createPgSQLStore := func(t *testing.T) (grafeas.Storage, project.Storage, func()) { - t.Helper() - config := &Config{ - Host: pgsqlstoreTestPgConfig.pgConfig.Host, - DBName: "test_db", - User: "", - Password: "", - SSLMode: pgsqlstoreTestPgConfig.pgConfig.SSLMode, - PaginationKey: paginationKey, - } - _ = os.Setenv("PGUSER", pgsqlstoreTestPgConfig.pgConfig.User) - _ = os.Setenv("PGPASSWORD", pgsqlstoreTestPgConfig.pgConfig.Password) - pg, err := NewPgSQLStore(config) - if err != nil { - t.Errorf("Error creating PgSQLStore, %s", err) - } - var g grafeas.Storage = pg - var gp project.Storage = pg - return g, gp, func() { dropDatabase(t, config); pg.Close() } - } - - storage.DoTestStorage(t, createPgSQLStore) -} - -func TestBetaPgSQLStoreWithNoPaginationKey(t *testing.T) { - createPgSQLStore := func(t *testing.T) (grafeas.Storage, project.Storage, func()) { - t.Helper() - config := &Config{ - Host: pgsqlstoreTestPgConfig.pgConfig.Host, - DBName: "test_db", - User: pgsqlstoreTestPgConfig.pgConfig.User, - Password: pgsqlstoreTestPgConfig.pgConfig.Password, - SSLMode: pgsqlstoreTestPgConfig.pgConfig.SSLMode, - PaginationKey: "", - } - pg, err := NewPgSQLStore(config) - if err != nil { - t.Errorf("Error creating PgSQLStore, %s", err) - } - var g grafeas.Storage = pg - var gp project.Storage = pg - return g, gp, func() { dropDatabase(t, config); pg.Close() } - } - - storage.DoTestStorage(t, createPgSQLStore) -} - -func TestBetaPgSQLStoreWithInvalidPaginationKey(t *testing.T) { - config := &Config{ - Host: pgsqlstoreTestPgConfig.pgConfig.Host, - DBName: "test_db", - User: pgsqlstoreTestPgConfig.pgConfig.User, - Password: pgsqlstoreTestPgConfig.pgConfig.Password, - SSLMode: pgsqlstoreTestPgConfig.pgConfig.SSLMode, - PaginationKey: "INVALID_VALUE", - } - pg, err := NewPgSQLStore(config) - if pg != nil { - pg.Close() - } - if err == nil { - t.Errorf("expected error for invalid pagination key; got none") - } - if err.Error() != "invalid pagination key; must be 256-bit URL-safe base64" { - t.Errorf("expected error message about invalid pagination key; got: %s", err.Error()) + t.Fatalf("failed to genTestDataProjects, err: %v", err) + } + + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + tests := []struct { + name string + getStore func(t *testing.T) (*PgSQLStore, func()) + filter string + pageToken string + pageSize int + want []*prpb.Project + wantDecryptedID int64 + wantErr bool + }{ + { + name: "happy path", + getStore: func(t *testing.T) (*PgSQLStore, func()) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + rows := sqlmock.NewRows([]string{"id", "data"}) + for i, o := range projectsData { + rows = rows.AddRow(i+1, o) // index id starts from 1 + } + mock.ExpectQuery("SELECT id, name FROM projects"). + WillReturnRows(rows) + mock.ExpectQuery(`SELECT MAX\(id\) FROM projects`). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(len(projectsData)))) + s := &PgSQLStore{DB: db} + return s, func() { db.Close() } + }, + want: projects, + }, + { + name: "pagination", + getStore: func(t *testing.T) (*PgSQLStore, func()) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + rows := sqlmock.NewRows([]string{"id", "data"}) + for i := 0; i < 2; i++ { + rows = rows.AddRow(i+1, projectsData[i]) // index id starts from 1 + } + mock.ExpectQuery("SELECT id, name FROM projects"). + WillReturnRows(rows) + mock.ExpectQuery(`SELECT MAX\(id\) FROM projects`). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(len(projectsData)))) + s := &PgSQLStore{DB: db, paginationKey: paginationKey} + return s, func() { db.Close() } + }, + want: projects[0:2], + wantDecryptedID: 2, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, cancel := tt.getStore(t) + defer cancel() + got, nextToken, err := s.ListProjects(ctx, tt.filter, tt.pageSize, tt.pageToken) + if (err != nil) != tt.wantErr { + t.Errorf("ListProjects() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ListProjects() got = %v, want %v", got, tt.want) + } + decryptedTokenID := decryptInt64(nextToken, s.paginationKey, 0) + if decryptedTokenID != tt.wantDecryptedID { + t.Errorf("ListProjects() got1 = %v, want %v", nextToken, tt.wantDecryptedID) + } + }) } } diff --git a/go/v1beta1/storage/queries.go b/go/v1beta1/storage/queries.go index 6604711..66dd0d0 100644 --- a/go/v1beta1/storage/queries.go +++ b/go/v1beta1/storage/queries.go @@ -46,23 +46,25 @@ const ( insertProject = `INSERT INTO projects(name) VALUES ($1)` projectExists = `SELECT EXISTS (SELECT 1 FROM projects WHERE name = $1)` deleteProject = `DELETE FROM projects WHERE name = $1` - listProjects = `SELECT id, name FROM projects WHERE id > $1 LIMIT $2` - projectCount = `SELECT COUNT(*) FROM projects` + // "ORDER BY id" is required because the default select order of PostgreSQL is not guaranteed. + listProjects = `SELECT id, name FROM projects WHERE %s id > $1 ORDER BY id LIMIT $2` + projectsMaxID = `SELECT MAX(id) FROM projects` insertOccurrence = `INSERT INTO occurrences(project_name, occurrence_name, note_id, data) VALUES ($1, $2, (SELECT id FROM notes WHERE project_name = $3 AND note_name = $4), $5)` searchOccurrence = `SELECT data FROM occurrences WHERE project_name = $1 AND occurrence_name = $2` updateOccurrence = `UPDATE occurrences SET data = $1 WHERE project_name = $2 AND occurrence_name = $3` deleteOccurrence = `DELETE FROM occurrences WHERE project_name = $1 AND occurrence_name = $2` - listOccurrences = `SELECT id, data FROM occurrences WHERE project_name = $1 AND id > $2 LIMIT $3` - occurrenceCount = `SELECT COUNT(*) FROM occurrences WHERE project_name = $1` + // "ORDER BY id" is required because the default select order of PostgreSQL is not guaranteed. + listOccurrences = `SELECT id, data FROM occurrences WHERE project_name = $1 %s AND id > $2 ORDER BY id LIMIT $3` + occurrenceMaxID = `SELECT MAX(id) FROM occurrences WHERE project_name = $1 %s` insertNote = `INSERT INTO notes(project_name, note_name, data) VALUES ($1, $2, $3)` searchNote = `SELECT data FROM notes WHERE project_name = $1 AND note_name = $2` updateNote = `UPDATE notes SET data = $1 WHERE project_name = $2 AND note_name = $3` deleteNote = `DELETE FROM notes WHERE project_name = $1 AND note_name = $2` - listNotes = `SELECT id, data FROM notes WHERE project_name = $1 AND id > $2 LIMIT $3` - noteCount = `SELECT COUNT(*) FROM notes WHERE project_name = $1` + listNotes = `SELECT id, data FROM notes WHERE project_name = $1 %s AND id > $2 ORDER BY id LIMIT $3` + notesMaxID = `SELECT MAX(id) FROM notes WHERE project_name = $1 %s` listNoteOccurrences = `SELECT o.id, o.data FROM occurrences as o, notes as n WHERE n.id = o.note_id AND n.project_name = $1 @@ -70,7 +72,7 @@ const ( AND o.id > $3 LIMIT $4` - noteOccurrencesCount = `SELECT COUNT(*) FROM occurrences as o, notes as n + NoteOccurrencesMaxID = `SELECT MAX(o.id) FROM occurrences as o, notes as n WHERE n.id = o.note_id AND n.project_name = $1 AND n.note_name = $2` From 6df2f6d7ec9c8bc05640f23d55954dd715aa1646 Mon Sep 17 00:00:00 2001 From: hkadakia Date: Wed, 8 Dec 2021 15:06:07 -0800 Subject: [PATCH 2/2] remove todo and create an issue instead --- go/v1beta1/storage/pgsqlstore.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/go/v1beta1/storage/pgsqlstore.go b/go/v1beta1/storage/pgsqlstore.go index 5a28f94..6e8f6a1 100644 --- a/go/v1beta1/storage/pgsqlstore.go +++ b/go/v1beta1/storage/pgsqlstore.go @@ -573,8 +573,6 @@ func (pg *PgSQLStore) ListNotes(ctx context.Context, pID, filter, pageToken stri // ListNoteOccurrences returns up to pageSize number of occcurrences on the particular note (nID) // for this project (pID) projects beginning at pageToken (or from start if pageToken is the empty string). -// TODO: implement query filter for NoteOccurrences. -// ListNoteOccurrences is not used by grafeas-client currently. func (pg *PgSQLStore) ListNoteOccurrences(ctx context.Context, pID, nID, filter, pageToken string, pageSize int32) ([]*pb.Occurrence, string, error) { // Verify that note exists if _, err := pg.GetNote(ctx, pID, nID); err != nil {