Skip to content

Commit

Permalink
context propagation: pkg/database/alerts (#3252)
Browse files Browse the repository at this point in the history
* alerts
* drop CTX from dbclient
* lint
* pkg/database/alerts: context.TODO()
* cscli: context.Background() -> cmd.Context()
  • Loading branch information
mmetc committed Sep 24, 2024
1 parent 1133afe commit 3945a99
Show file tree
Hide file tree
Showing 21 changed files with 141 additions and 123 deletions.
4 changes: 2 additions & 2 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ linters-settings:

maintidx:
# raise this after refactoring
under: 16
under: 15

misspell:
locale: US
Expand Down Expand Up @@ -118,7 +118,7 @@ linters-settings:
arguments: [6]
- name: function-length
# lower this after refactoring
arguments: [110, 235]
arguments: [110, 237]
- name: get-return
disabled: true
- name: increment-decrement
Expand Down
20 changes: 10 additions & 10 deletions cmd/crowdsec-cli/clialert/alerts.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func (cli *cliAlerts) NewCommand() *cobra.Command {
return cmd
}

func (cli *cliAlerts) list(alertListFilter apiclient.AlertsListOpts, limit *int, contained *bool, printMachine bool) error {
func (cli *cliAlerts) list(ctx context.Context, alertListFilter apiclient.AlertsListOpts, limit *int, contained *bool, printMachine bool) error {
var err error

*alertListFilter.ScopeEquals, err = SanitizeScope(*alertListFilter.ScopeEquals, *alertListFilter.IPEquals, *alertListFilter.RangeEquals)
Expand Down Expand Up @@ -311,7 +311,7 @@ func (cli *cliAlerts) list(alertListFilter apiclient.AlertsListOpts, limit *int,
alertListFilter.Contains = new(bool)
}

alerts, _, err := cli.client.Alerts.List(context.Background(), alertListFilter)
alerts, _, err := cli.client.Alerts.List(ctx, alertListFilter)
if err != nil {
return fmt.Errorf("unable to list alerts: %w", err)
}
Expand Down Expand Up @@ -354,7 +354,7 @@ cscli alerts list --type ban`,
Long: `List alerts with optional filters`,
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.list(alertListFilter, limit, contained, printMachine)
return cli.list(cmd.Context(), alertListFilter, limit, contained, printMachine)
},
}

Expand All @@ -377,7 +377,7 @@ cscli alerts list --type ban`,
return cmd
}

func (cli *cliAlerts) delete(delFilter apiclient.AlertsDeleteOpts, activeDecision *bool, deleteAll bool, delAlertByID string, contained *bool) error {
func (cli *cliAlerts) delete(ctx context.Context, delFilter apiclient.AlertsDeleteOpts, activeDecision *bool, deleteAll bool, delAlertByID string, contained *bool) error {
var err error

if !deleteAll {
Expand Down Expand Up @@ -423,12 +423,12 @@ func (cli *cliAlerts) delete(delFilter apiclient.AlertsDeleteOpts, activeDecisio

var alerts *models.DeleteAlertsResponse
if delAlertByID == "" {
alerts, _, err = cli.client.Alerts.Delete(context.Background(), delFilter)
alerts, _, err = cli.client.Alerts.Delete(ctx, delFilter)
if err != nil {
return fmt.Errorf("unable to delete alerts: %w", err)
}
} else {
alerts, _, err = cli.client.Alerts.DeleteOne(context.Background(), delAlertByID)
alerts, _, err = cli.client.Alerts.DeleteOne(ctx, delAlertByID)
if err != nil {
return fmt.Errorf("unable to delete alert: %w", err)
}
Expand Down Expand Up @@ -480,7 +480,7 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`,
return nil
},
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.delete(delFilter, activeDecision, deleteAll, delAlertByID, contained)
return cli.delete(cmd.Context(), delFilter, activeDecision, deleteAll, delAlertByID, contained)
},
}

Expand All @@ -498,7 +498,7 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`,
return cmd
}

func (cli *cliAlerts) inspect(details bool, alertIDs ...string) error {
func (cli *cliAlerts) inspect(ctx context.Context, details bool, alertIDs ...string) error {
cfg := cli.cfg()

for _, alertID := range alertIDs {
Expand All @@ -507,7 +507,7 @@ func (cli *cliAlerts) inspect(details bool, alertIDs ...string) error {
return fmt.Errorf("bad alert id %s", alertID)
}

alert, _, err := cli.client.Alerts.GetByID(context.Background(), id)
alert, _, err := cli.client.Alerts.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("can't find alert with id %s: %w", alertID, err)
}
Expand Down Expand Up @@ -551,7 +551,7 @@ func (cli *cliAlerts) newInspectCmd() *cobra.Command {
_ = cmd.Help()
return errors.New("missing alert_id")
}
return cli.inspect(details, args...)
return cli.inspect(cmd.Context(), details, args...)
},
}

Expand Down
8 changes: 4 additions & 4 deletions cmd/crowdsec-cli/cliconsole/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (cli *cliConsole) NewCommand() *cobra.Command {
return cmd
}

func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []string, opts []string) error {
func (cli *cliConsole) enroll(ctx context.Context, key string, name string, overwrite bool, tags []string, opts []string) error {
cfg := cli.cfg()
password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password)

Expand Down Expand Up @@ -127,7 +127,7 @@ func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []st
VersionPrefix: "v3",
})

resp, err := c.Auth.EnrollWatcher(context.Background(), key, name, tags, overwrite)
resp, err := c.Auth.EnrollWatcher(ctx, key, name, tags, overwrite)
if err != nil {
return fmt.Errorf("could not enroll instance: %w", err)
}
Expand Down Expand Up @@ -173,8 +173,8 @@ After running this command your will need to validate the enrollment in the weba
valid options are : %s,all (see 'cscli console status' for details)`, strings.Join(csconfig.CONSOLE_CONFIGS, ",")),
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, args []string) error {
return cli.enroll(args[0], name, overwrite, tags, opts)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.enroll(cmd.Context(), args[0], name, overwrite, tags, opts)
},
}

Expand Down
22 changes: 11 additions & 11 deletions cmd/crowdsec-cli/clidecision/decisions.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (cli *cliDecisions) NewCommand() *cobra.Command {
return cmd
}

func (cli *cliDecisions) list(filter apiclient.AlertsListOpts, NoSimu *bool, contained *bool, printMachine bool) error {
func (cli *cliDecisions) list(ctx context.Context, filter apiclient.AlertsListOpts, NoSimu *bool, contained *bool, printMachine bool) error {
var err error

*filter.ScopeEquals, err = clialert.SanitizeScope(*filter.ScopeEquals, *filter.IPEquals, *filter.RangeEquals)
Expand Down Expand Up @@ -249,7 +249,7 @@ func (cli *cliDecisions) list(filter apiclient.AlertsListOpts, NoSimu *bool, con
filter.Contains = new(bool)
}

alerts, _, err := cli.client.Alerts.List(context.Background(), filter)
alerts, _, err := cli.client.Alerts.List(ctx, filter)
if err != nil {
return fmt.Errorf("unable to retrieve decisions: %w", err)
}
Expand Down Expand Up @@ -293,7 +293,7 @@ cscli decisions list --origin lists --scenario list_name
Args: cobra.ExactArgs(0),
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.list(filter, NoSimu, contained, printMachine)
return cli.list(cmd.Context(), filter, NoSimu, contained, printMachine)
},
}

Expand All @@ -317,7 +317,7 @@ cscli decisions list --origin lists --scenario list_name
return cmd
}

func (cli *cliDecisions) add(addIP, addRange, addDuration, addValue, addScope, addReason, addType string) error {
func (cli *cliDecisions) add(ctx context.Context, addIP, addRange, addDuration, addValue, addScope, addReason, addType string) error {
alerts := models.AddAlertsRequest{}
origin := types.CscliOrigin
capacity := int32(0)
Expand Down Expand Up @@ -386,7 +386,7 @@ func (cli *cliDecisions) add(addIP, addRange, addDuration, addValue, addScope, a
}
alerts = append(alerts, &alert)

_, _, err = cli.client.Alerts.Add(context.Background(), alerts)
_, _, err = cli.client.Alerts.Add(ctx, alerts)
if err != nil {
return err
}
Expand Down Expand Up @@ -419,7 +419,7 @@ cscli decisions add --scope username --value foobar
Args: cobra.ExactArgs(0),
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.add(addIP, addRange, addDuration, addValue, addScope, addReason, addType)
return cli.add(cmd.Context(), addIP, addRange, addDuration, addValue, addScope, addReason, addType)
},
}

Expand All @@ -436,7 +436,7 @@ cscli decisions add --scope username --value foobar
return cmd
}

func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDecisionID string, contained *bool) error {
func (cli *cliDecisions) delete(ctx context.Context, delFilter apiclient.DecisionsDeleteOpts, delDecisionID string, contained *bool) error {
var err error

/*take care of shorthand options*/
Expand Down Expand Up @@ -480,7 +480,7 @@ func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDeci
var decisions *models.DeleteDecisionResponse

if delDecisionID == "" {
decisions, _, err = cli.client.Decisions.Delete(context.Background(), delFilter)
decisions, _, err = cli.client.Decisions.Delete(ctx, delFilter)
if err != nil {
return fmt.Errorf("unable to delete decisions: %w", err)
}
Expand All @@ -489,7 +489,7 @@ func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDeci
return fmt.Errorf("id '%s' is not an integer: %w", delDecisionID, err)
}

decisions, _, err = cli.client.Decisions.DeleteOne(context.Background(), delDecisionID)
decisions, _, err = cli.client.Decisions.DeleteOne(ctx, delDecisionID)
if err != nil {
return fmt.Errorf("unable to delete decision: %w", err)
}
Expand Down Expand Up @@ -543,8 +543,8 @@ cscli decisions delete --origin lists --scenario list_name

return nil
},
RunE: func(_ *cobra.Command, _ []string) error {
return cli.delete(delFilter, delDecisionID, contained)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.delete(cmd.Context(), delFilter, delDecisionID, contained)
},
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/clilapi/lapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func queryLAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login
Scenarios: itemsForAPI,
}

_, _, err = client.Auth.AuthenticateWatcher(context.Background(), t)
_, _, err = client.Auth.AuthenticateWatcher(ctx, t)
if err != nil {
return false, err
}
Expand Down
8 changes: 4 additions & 4 deletions cmd/crowdsec-cli/clinotifications/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,9 +368,9 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
`,
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
PreRunE: func(_ *cobra.Command, args []string) error {
PreRunE: func(cmd *cobra.Command, args []string) error {
var err error
alert, err = cli.fetchAlertFromArgString(args[0])
alert, err = cli.fetchAlertFromArgString(cmd.Context(), args[0])
if err != nil {
return err
}
Expand Down Expand Up @@ -447,7 +447,7 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
return cmd
}

func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Alert, error) {
func (cli *cliNotifications) fetchAlertFromArgString(ctx context.Context, toParse string) (*models.Alert, error) {
cfg := cli.cfg()

id, err := strconv.Atoi(toParse)
Expand All @@ -470,7 +470,7 @@ func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Al
return nil, fmt.Errorf("error creating the client for the API: %w", err)
}

alert, _, err := client.Alerts.GetByID(context.Background(), id)
alert, _, err := client.Alerts.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("can't find alert with id %d: %w", id, err)
}
Expand Down
10 changes: 6 additions & 4 deletions pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,13 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
}
}

func (a *apic) CAPIPullIsOld() (bool, error) {
func (a *apic) CAPIPullIsOld(ctx context.Context) (bool, error) {
/*only pull community blocklist if it's older than 1h30 */
alerts := a.dbClient.Ent.Alert.Query()
alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID)))
alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert

count, err := alerts.Count(a.dbClient.CTX)
count, err := alerts.Count(ctx)
if err != nil {
return false, fmt.Errorf("while looking for CAPI alert: %w", err)
}
Expand Down Expand Up @@ -634,7 +634,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
}

if !forcePull {
if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil {
if lastPullIsOld, err := a.CAPIPullIsOld(ctx); err != nil {
return err
} else if !lastPullIsOld {
return nil
Expand Down Expand Up @@ -769,6 +769,8 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis
}

func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error {
ctx := context.TODO()

for _, alert := range alertsFromCapi {
setAlertScenario(alert, addCounters, deleteCounters)
log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions))
Expand All @@ -777,7 +779,7 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string
log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist")
}

alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert)
alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(ctx, alert)
if err != nil {
return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err)
}
Expand Down
10 changes: 6 additions & 4 deletions pkg/apiserver/apic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) {
func TestAPICCAPIPullIsOld(t *testing.T) {
api := getAPIC(t)

isOld, err := api.CAPIPullIsOld()
ctx := context.Background()

isOld, err := api.CAPIPullIsOld(ctx)
require.NoError(t, err)
assert.True(t, isOld)

Expand All @@ -124,17 +126,17 @@ func TestAPICCAPIPullIsOld(t *testing.T) {
SetScope("Country").
SetValue("Blah").
SetOrigin(types.CAPIOrigin).
SaveX(context.Background())
SaveX(ctx)

api.dbClient.Ent.Alert.Create().
SetCreatedAt(time.Now()).
SetScenario("crowdsec/test").
AddDecisions(
decision,
).
SaveX(context.Background())
SaveX(ctx)

isOld, err = api.CAPIPullIsOld()
isOld, err = api.CAPIPullIsOld(ctx)
require.NoError(t, err)

assert.False(t, isOld)
Expand Down
Loading

0 comments on commit 3945a99

Please sign in to comment.