Skip to content

Commit

Permalink
Merge pull request #14 from hkadakia/ut
Browse files Browse the repository at this point in the history
cleanup code & add unit test
  • Loading branch information
aysylu committed Dec 9, 2021
2 parents d479fe0 + 6df2f6d commit 497ecf5
Show file tree
Hide file tree
Showing 6 changed files with 493 additions and 352 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/grafeas/grafeas-pgsql
go 1.17

require (
github.com/DATA-DOG/go-sqlmock v1.4.1
github.com/fernet/fernet-go v0.0.0-20191111064656-eff2850e6001
github.com/golang/protobuf v1.4.2
github.com/google/uuid v1.1.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT
github.com/Bowery/prompt v0.0.0-20190916142128-fa8279994f75/go.mod h1:4/6eNcqZ09BZ9wLK3tZOjBA1nDj+B0728nlX5YRlSmQ=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.4.1 h1:ThlnYciV1iM/V0OSF/dtkqWb6xo5qITT1TJBG1MRDJM=
github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
Expand Down
99 changes: 51 additions & 48 deletions go/v1beta1/storage/pgsqlstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
// config.ConvertGenericConfigToSpecificType internally uses json package.
type Config struct {
Host string `json:"host"`
Port int `json:"port"`
// DBName has to alrady exist and can be accessed by User.
DBName string `json:"db_name"`
User string `json:"user"`
Expand Down Expand Up @@ -206,21 +207,14 @@ func (pg *PgSQLStore) GetProject(ctx context.Context, pID string) (*prpb.Project
// ListProjects returns up to pageSize number of projects beginning at pageToken (or from
// start if pageToken is the empty string).
func (pg *PgSQLStore) ListProjects(ctx context.Context, filter string, pageSize int, pageToken string) ([]*prpb.Project, string, error) {
count, err := pg.count(ctx, projectCount)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to count Projects from database")
}

var filterQuery string
if filter != "" {
var fs FilterSQL
filterQuery = " AND " + fs.ParseFilter(filter)
}

query := fmt.Sprint(listProjects, filterQuery)
var rows *sql.Rows
query := fmt.Sprintf(listProjects, filterQuery)
id := decryptInt64(pageToken, pg.paginationKey, 0)
rows, err = pg.DB.QueryContext(ctx, query, id, pageSize)
rows, err := pg.DB.QueryContext(ctx, query, id, pageSize)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to list Projects from database")
}
Expand All @@ -234,7 +228,18 @@ func (pg *PgSQLStore) ListProjects(ctx context.Context, filter string, pageSize
}
projects = append(projects, &prpb.Project{Name: name})
}
if count == lastID {
if len(projects) == 0 {
return projects, "", nil
}
maxQuery := projectsMaxID
if filterQuery != "" {
maxQuery = fmt.Sprintf("%s WHERE %s", maxQuery, filterQuery)
}
maxID, err := pg.max(ctx, maxQuery)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to query max project id from database")
}
if lastID >= maxID {
return projects, "", nil
}
encryptedPage, err := encryptInt64(lastID, pg.paginationKey)
Expand Down Expand Up @@ -355,21 +360,15 @@ func (pg *PgSQLStore) GetOccurrence(ctx context.Context, pID, oID string) (*pb.O
// ListOccurrences returns up to pageSize number of occurrences for this project beginning
// at pageToken, or from start if pageToken is the empty string.
func (pg *PgSQLStore) ListOccurrences(ctx context.Context, pID, filter, pageToken string, pageSize int32) ([]*pb.Occurrence, string, error) {
count, err := pg.count(ctx, occurrenceCount, pID)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to count Occurrences from database")
}

var filterQuery string
if filter != "" {
var fs FilterSQL
filterQuery = " AND " + fs.ParseFilter(filter)
}

query := fmt.Sprint(listOccurrences, filterQuery)
var rows *sql.Rows
query := fmt.Sprintf(listOccurrences, filterQuery)
id := decryptInt64(pageToken, pg.paginationKey, 0)
rows, err = pg.DB.QueryContext(ctx, query, pID, id, pageSize)
rows, err := pg.DB.QueryContext(ctx, query, pID, id, pageSize)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to list Occurrences from database")
}
Expand All @@ -388,12 +387,20 @@ func (pg *PgSQLStore) ListOccurrences(ctx context.Context, pID, filter, pageToke
}
os = append(os, &o)
}
if count == lastID {
if len(os) == 0 {
return os, "", nil
}
maxQuery := fmt.Sprintf(occurrenceMaxID, filterQuery)
maxID, err := pg.max(ctx, maxQuery, pID)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to query max occurrence id from database")
}
if lastID >= maxID {
return os, "", nil
}
encryptedPage, err := encryptInt64(lastID, pg.paginationKey)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to paginate projects")
return nil, "", status.Error(codes.Internal, "Failed to paginate occurrences")
}
return os, encryptedPage, nil
}
Expand Down Expand Up @@ -435,9 +442,7 @@ func (pg *PgSQLStore) BatchCreateNotes(ctx context.Context, pID, uID string, not
} else {
created = append(created, note)
}

}

return created, errs
}

Expand Down Expand Up @@ -521,18 +526,13 @@ func (pg *PgSQLStore) GetOccurrenceNote(ctx context.Context, pID, oID string) (*
// ListNotes returns up to pageSize number of notes for this project (pID) beginning
// at pageToken (or from start if pageToken is the empty string).
func (pg *PgSQLStore) ListNotes(ctx context.Context, pID, filter, pageToken string, pageSize int32) ([]*pb.Note, string, error) {
count, err := pg.count(ctx, noteCount, pID)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to count Notes from database")
}

var filterQuery string
if filter != "" {
var fs FilterSQL
filterQuery = " AND " + fs.ParseFilter(filter)
}

query := fmt.Sprint(listNotes, filterQuery)
query := fmt.Sprintf(listNotes, filterQuery)
id := decryptInt64(pageToken, pg.paginationKey, 0)
rows, err := pg.DB.QueryContext(ctx, query, pID, id, pageSize)
if err != nil {
Expand All @@ -553,12 +553,20 @@ func (pg *PgSQLStore) ListNotes(ctx context.Context, pID, filter, pageToken stri
}
ns = append(ns, &n)
}
if count == lastID {
if len(ns) == 0 {
return ns, "", nil
}
maxQuery := fmt.Sprintf(notesMaxID, filterQuery)
maxID, err := pg.max(ctx, maxQuery, pID)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to query max note id from database")
}
if lastID >= maxID {
return ns, "", nil
}
encryptedPage, err := encryptInt64(lastID, pg.paginationKey)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to paginate projects")
return nil, "", status.Error(codes.Internal, "Failed to paginate notes")
}
return ns, encryptedPage, nil
}
Expand All @@ -570,10 +578,6 @@ func (pg *PgSQLStore) ListNoteOccurrences(ctx context.Context, pID, nID, filter,
if _, err := pg.GetNote(ctx, pID, nID); err != nil {
return nil, "", err
}
count, err := pg.count(ctx, noteOccurrencesCount, pID, nID)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to count Occurrences from database")
}
id := decryptInt64(pageToken, pg.paginationKey, 0)
rows, err := pg.DB.QueryContext(ctx, listNoteOccurrences, pID, nID, id, pageSize)
if err != nil {
Expand All @@ -594,12 +598,19 @@ func (pg *PgSQLStore) ListNoteOccurrences(ctx context.Context, pID, nID, filter,
}
os = append(os, &o)
}
if count == lastID {
if len(os) == 0 {
return os, "", nil
}
maxID, err := pg.max(ctx, NoteOccurrencesMaxID, pID, nID)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to query max NoteOccurrences from database")
}
if lastID >= maxID {
return os, "", nil
}
encryptedPage, err := encryptInt64(lastID, pg.paginationKey)
if err != nil {
return nil, "", status.Error(codes.Internal, "Failed to paginate projects")
return nil, "", status.Error(codes.Internal, "Failed to paginate note occurrences")
}
return os, encryptedPage, nil
}
Expand All @@ -609,16 +620,8 @@ func (pg *PgSQLStore) GetVulnerabilityOccurrencesSummary(ctx context.Context, pr
return &pb.VulnerabilityOccurrencesSummary{}, nil
}

// CreateSourceString generates DB source path.
func CreateSourceString(user, password, host, dbName, SSLMode string) string {
if user == "" {
return fmt.Sprintf("postgres://%s/%s?sslmode=%s", host, dbName, SSLMode)
}
return fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=%s", user, password, host, dbName, SSLMode)
}

// count returns the total number of entries for the specified query (assuming SELECT(*) is used)
func (pg *PgSQLStore) count(ctx context.Context, query string, args ...interface{}) (int64, error) {
// max returns the max ID of entries for the specified query (assuming SELECT(*) is used)
func (pg *PgSQLStore) max(ctx context.Context, query string, args ...interface{}) (int64, error) {
row := pg.DB.QueryRowContext(ctx, query, args...)
var count int64
err := row.Scan(&count)
Expand All @@ -628,7 +631,7 @@ func (pg *PgSQLStore) count(ctx context.Context, query string, args ...interface
return count, err
}

// Encrypt int64 using provided key
// encryptInt64 encrypts int64 using provided key
func encryptInt64(v int64, key string) (string, error) {
k, err := fernet.DecodeKey(key)
if err != nil {
Expand All @@ -641,7 +644,7 @@ func encryptInt64(v int64, key string) (string, error) {
return string(bytes), nil
}

// Decrypts encrypted int64 using provided key. Returns defaultValue if decryption fails.
// decryptInt64 decrypts encrypted int64 using provided key. Returns defaultValue if decryption fails.
func decryptInt64(encrypted string, key string, defaultValue int64) int64 {
k, err := fernet.DecodeKey(key)
if err != nil {
Expand Down
Loading

0 comments on commit 497ecf5

Please sign in to comment.