From 2ba598959a223d3d21e0353daa4050485c39f7ca Mon Sep 17 00:00:00 2001 From: Stephanie Lamb Date: Wed, 18 Sep 2024 16:46:02 -0500 Subject: [PATCH] ran prepare-for-codereview --- cmd/api/src/api/v2/audit_test.go | 208 +++++++++--------- cmd/api/src/api/v2/helpers.go | 18 +- cmd/api/src/bootstrap/initializer.go | 2 +- cmd/api/src/cmd/bhapi/main.go | 6 +- cmd/api/src/database/parameters_test.go | 16 ++ cmd/api/src/model/appcfg/parameter_test.go | 16 ++ .../utils/validation/duration_validator.go | 16 ++ .../validation/duration_validator_test.go | 16 ++ packages/go/analysis/ad/esc4.go | 2 +- 9 files changed, 182 insertions(+), 118 deletions(-) diff --git a/cmd/api/src/api/v2/audit_test.go b/cmd/api/src/api/v2/audit_test.go index bb060835f..58eb0051b 100644 --- a/cmd/api/src/api/v2/audit_test.go +++ b/cmd/api/src/api/v2/audit_test.go @@ -243,122 +243,122 @@ func TestResources_ListAuditLogs_Filtered(t *testing.T) { } func TestResources_ListAuditLogs_SkipAndOffset(t *testing.T) { - var ( - mockCtrl = gomock.NewController(t) - mockDB = mocks.NewMockDatabase(mockCtrl) - resources = v2.Resources{DB: mockDB} - ) - defer mockCtrl.Finish() - - mockDB.EXPECT().ListAuditLogs(gomock.Any(), gomock.Any(), gomock.Any(), 10, gomock.Any(), "", model.SQLFilter{}).Return(model.AuditLogs{}, 1000, nil) - - endpoint := "/api/v2/audit" - - if req, err := http.NewRequest("GET", endpoint, nil); err != nil { - t.Fatal(err) - } else { - q := url.Values{} - q.Add("skip", "10") - q.Add("offset", "20") - - req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) - req.URL.RawQuery = q.Encode() - - router := mux.NewRouter() - router.HandleFunc(endpoint, resources.ListAuditLogs).Methods("GET") - - response := httptest.NewRecorder() - router.ServeHTTP(response, req) - require.Equal(t, http.StatusOK, response.Code) - } + var ( + mockCtrl = gomock.NewController(t) + mockDB = mocks.NewMockDatabase(mockCtrl) + resources = v2.Resources{DB: mockDB} + ) + defer mockCtrl.Finish() + + mockDB.EXPECT().ListAuditLogs(gomock.Any(), gomock.Any(), gomock.Any(), 10, gomock.Any(), "", model.SQLFilter{}).Return(model.AuditLogs{}, 1000, nil) + + endpoint := "/api/v2/audit" + + if req, err := http.NewRequest("GET", endpoint, nil); err != nil { + t.Fatal(err) + } else { + q := url.Values{} + q.Add("skip", "10") + q.Add("offset", "20") + + req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) + req.URL.RawQuery = q.Encode() + + router := mux.NewRouter() + router.HandleFunc(endpoint, resources.ListAuditLogs).Methods("GET") + + response := httptest.NewRecorder() + router.ServeHTTP(response, req) + require.Equal(t, http.StatusOK, response.Code) + } } func TestResources_ListAuditLogs_OnlyOffset(t *testing.T) { - var ( - mockCtrl = gomock.NewController(t) - mockDB = mocks.NewMockDatabase(mockCtrl) - resources = v2.Resources{DB: mockDB} - ) - defer mockCtrl.Finish() + var ( + mockCtrl = gomock.NewController(t) + mockDB = mocks.NewMockDatabase(mockCtrl) + resources = v2.Resources{DB: mockDB} + ) + defer mockCtrl.Finish() - mockDB.EXPECT().ListAuditLogs(gomock.Any(), gomock.Any(), gomock.Any(), 20, gomock.Any(), "", model.SQLFilter{}).Return(model.AuditLogs{}, 1000, nil) + mockDB.EXPECT().ListAuditLogs(gomock.Any(), gomock.Any(), gomock.Any(), 20, gomock.Any(), "", model.SQLFilter{}).Return(model.AuditLogs{}, 1000, nil) - endpoint := "/api/v2/audit" + endpoint := "/api/v2/audit" - if req, err := http.NewRequest("GET", endpoint, nil); err != nil { - t.Fatal(err) - } else { - q := url.Values{} - q.Add("offset", "20") + if req, err := http.NewRequest("GET", endpoint, nil); err != nil { + t.Fatal(err) + } else { + q := url.Values{} + q.Add("offset", "20") - req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) - req.URL.RawQuery = q.Encode() + req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) + req.URL.RawQuery = q.Encode() - router := mux.NewRouter() - router.HandleFunc(endpoint, resources.ListAuditLogs).Methods("GET") + router := mux.NewRouter() + router.HandleFunc(endpoint, resources.ListAuditLogs).Methods("GET") - response := httptest.NewRecorder() - router.ServeHTTP(response, req) - require.Equal(t, http.StatusOK, response.Code) - } + response := httptest.NewRecorder() + router.ServeHTTP(response, req) + require.Equal(t, http.StatusOK, response.Code) + } } func TestResources_ListAuditLogs_OnlySkip(t *testing.T) { - var ( - mockCtrl = gomock.NewController(t) - mockDB = mocks.NewMockDatabase(mockCtrl) - resources = v2.Resources{DB: mockDB} - ) - defer mockCtrl.Finish() - - // Expect skip to be 5 (from "skip" parameter) - mockDB.EXPECT().ListAuditLogs(gomock.Any(), gomock.Any(), gomock.Any(), 5, gomock.Any(), gomock.Any(), gomock.Any()).Return(model.AuditLogs{}, 1000, nil) - - endpoint := "/api/v2/audit" - - if req, err := http.NewRequest("GET", endpoint, nil); err != nil { - t.Fatal(err) - } else { - q := url.Values{} - q.Add("skip", "5") - - req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) - req.URL.RawQuery = q.Encode() - - router := mux.NewRouter() - router.HandleFunc(endpoint, resources.ListAuditLogs).Methods("GET") - - response := httptest.NewRecorder() - router.ServeHTTP(response, req) - require.Equal(t, http.StatusOK, response.Code) - } + var ( + mockCtrl = gomock.NewController(t) + mockDB = mocks.NewMockDatabase(mockCtrl) + resources = v2.Resources{DB: mockDB} + ) + defer mockCtrl.Finish() + + // Expect skip to be 5 (from "skip" parameter) + mockDB.EXPECT().ListAuditLogs(gomock.Any(), gomock.Any(), gomock.Any(), 5, gomock.Any(), gomock.Any(), gomock.Any()).Return(model.AuditLogs{}, 1000, nil) + + endpoint := "/api/v2/audit" + + if req, err := http.NewRequest("GET", endpoint, nil); err != nil { + t.Fatal(err) + } else { + q := url.Values{} + q.Add("skip", "5") + + req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) + req.URL.RawQuery = q.Encode() + + router := mux.NewRouter() + router.HandleFunc(endpoint, resources.ListAuditLogs).Methods("GET") + + response := httptest.NewRecorder() + router.ServeHTTP(response, req) + require.Equal(t, http.StatusOK, response.Code) + } } func TestResources_ListAuditLogs_InvalidSkip(t *testing.T) { - var ( - mockCtrl = gomock.NewController(t) - mockDB = mocks.NewMockDatabase(mockCtrl) - resources = v2.Resources{DB: mockDB} - ) - defer mockCtrl.Finish() - - endpoint := "/api/v2/audit" - - if req, err := http.NewRequest("GET", endpoint, nil); err != nil { - t.Fatal(err) - } else { - q := url.Values{} - q.Add("skip", "invalid") - - req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) - req.URL.RawQuery = q.Encode() - - router := mux.NewRouter() - router.HandleFunc(endpoint, resources.ListAuditLogs).Methods("GET") - - response := httptest.NewRecorder() - router.ServeHTTP(response, req) - require.Equal(t, http.StatusBadRequest, response.Code) - require.Contains(t, response.Body.String(), "query parameter \\\"skip\\\" is malformed") - } -} \ No newline at end of file + var ( + mockCtrl = gomock.NewController(t) + mockDB = mocks.NewMockDatabase(mockCtrl) + resources = v2.Resources{DB: mockDB} + ) + defer mockCtrl.Finish() + + endpoint := "/api/v2/audit" + + if req, err := http.NewRequest("GET", endpoint, nil); err != nil { + t.Fatal(err) + } else { + q := url.Values{} + q.Add("skip", "invalid") + + req.Header.Set(headers.ContentType.String(), mediatypes.ApplicationJson.String()) + req.URL.RawQuery = q.Encode() + + router := mux.NewRouter() + router.HandleFunc(endpoint, resources.ListAuditLogs).Methods("GET") + + response := httptest.NewRecorder() + router.ServeHTTP(response, req) + require.Equal(t, http.StatusBadRequest, response.Code) + require.Contains(t, response.Body.String(), "query parameter \\\"skip\\\" is malformed") + } +} diff --git a/cmd/api/src/api/v2/helpers.go b/cmd/api/src/api/v2/helpers.go index bd5daec72..360f702b8 100644 --- a/cmd/api/src/api/v2/helpers.go +++ b/cmd/api/src/api/v2/helpers.go @@ -58,15 +58,15 @@ func ParseSkipQueryParameter(params url.Values, defaultValue int) (int, error) { } func ParseSkipQueryParameterWithKey(params url.Values, key string, defaultValue int) (int, error) { - if param := params.Get(key); param == "" { - return defaultValue, nil - } else if skip, err := strconv.Atoi(param); err != nil { - return 0, fmt.Errorf("error converting %s value %v to int: %v", key, param, err) - } else if skip < 0 { - return 0, fmt.Errorf(utils.ErrorInvalidSkip, skip) - } else { - return skip, nil - } + if param := params.Get(key); param == "" { + return defaultValue, nil + } else if skip, err := strconv.Atoi(param); err != nil { + return 0, fmt.Errorf("error converting %s value %v to int: %v", key, param, err) + } else if skip < 0 { + return 0, fmt.Errorf(utils.ErrorInvalidSkip, skip) + } else { + return skip, nil + } } func ParseLimitQueryParameter(params url.Values, defaultValue int) (int, error) { diff --git a/cmd/api/src/bootstrap/initializer.go b/cmd/api/src/bootstrap/initializer.go index bc3c8e94e..47f669d3b 100644 --- a/cmd/api/src/bootstrap/initializer.go +++ b/cmd/api/src/bootstrap/initializer.go @@ -39,7 +39,7 @@ type Initializer[DBType database.Database, GraphType graph.Database] struct { Configuration config.Configuration PreMigrationDaemons InitializerLogic[DBType, GraphType] Entrypoint InitializerLogic[DBType, GraphType] - DBConnector DatabaseConstructor[DBType, GraphType] + DBConnector DatabaseConstructor[DBType, GraphType] } func (s Initializer[DBType, GraphType]) Launch(parentCtx context.Context, handleSignals bool) error { diff --git a/cmd/api/src/cmd/bhapi/main.go b/cmd/api/src/cmd/bhapi/main.go index b69f7e497..cd67a8e19 100644 --- a/cmd/api/src/cmd/bhapi/main.go +++ b/cmd/api/src/cmd/bhapi/main.go @@ -64,10 +64,10 @@ func main() { log.Fatalf("Unable to read configuration %s: %v", configFilePath, err) } else { initializer := bootstrap.Initializer[*database.BloodhoundDB, *graph.DatabaseSwitch]{ - Configuration: cfg, - DBConnector: services.ConnectDatabases, + Configuration: cfg, + DBConnector: services.ConnectDatabases, PreMigrationDaemons: services.PreMigrationDaemons, - Entrypoint: services.Entrypoint, + Entrypoint: services.Entrypoint, } if err := initializer.Launch(context.Background(), true); err != nil { diff --git a/cmd/api/src/database/parameters_test.go b/cmd/api/src/database/parameters_test.go index ea239a132..c4b7db2b5 100644 --- a/cmd/api/src/database/parameters_test.go +++ b/cmd/api/src/database/parameters_test.go @@ -1,3 +1,19 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + package database_test import ( diff --git a/cmd/api/src/model/appcfg/parameter_test.go b/cmd/api/src/model/appcfg/parameter_test.go index c9b0085e4..0abd42440 100644 --- a/cmd/api/src/model/appcfg/parameter_test.go +++ b/cmd/api/src/model/appcfg/parameter_test.go @@ -1,3 +1,19 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + package appcfg_test import ( diff --git a/cmd/api/src/utils/validation/duration_validator.go b/cmd/api/src/utils/validation/duration_validator.go index eea3f8ed4..2ff8b8bde 100644 --- a/cmd/api/src/utils/validation/duration_validator.go +++ b/cmd/api/src/utils/validation/duration_validator.go @@ -1,3 +1,19 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + package validation import ( diff --git a/cmd/api/src/utils/validation/duration_validator_test.go b/cmd/api/src/utils/validation/duration_validator_test.go index 24bf049ba..111f035e4 100644 --- a/cmd/api/src/utils/validation/duration_validator_test.go +++ b/cmd/api/src/utils/validation/duration_validator_test.go @@ -1,3 +1,19 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + package validation_test import ( diff --git a/packages/go/analysis/ad/esc4.go b/packages/go/analysis/ad/esc4.go index a78c0cb36..9bbf4bd53 100644 --- a/packages/go/analysis/ad/esc4.go +++ b/packages/go/analysis/ad/esc4.go @@ -53,7 +53,7 @@ func PostADCSESC4(ctx context.Context, tx graph.Transaction, outC chan<- analysi } else { var ( - enterpriseCAEnrollers = cache.GetEnterpriseCAEnrollers(enterpriseCA.ID) + enterpriseCAEnrollers = cache.GetEnterpriseCAEnrollers(enterpriseCA.ID) certTemplateControllers = cache.GetCertTemplateControllers(certTemplate.ID) )