Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

context propagation: pkg/database/config #3246

Merged
merged 2 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/clipapi/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
return fmt.Errorf("unable to get PAPI permissions: %w", err)
}

lastTimestampStr, err := db.GetConfigItem(apiserver.PapiPullKey)
lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey)

Check warning on line 77 in cmd/crowdsec-cli/clipapi/papi.go

View check run for this annotation

Codecov / codecov/patch

cmd/crowdsec-cli/clipapi/papi.go#L77

Added line #L77 was not covered by tests
if err != nil {
lastTimestampStr = ptr.Of("never")
}
Expand Down
28 changes: 14 additions & 14 deletions pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@
// we receive a list of decisions and links for blocklist and we need to create a list of alerts :
// one alert for "community blocklist"
// one alert per list we're subscribed to
func (a *apic) PullTop(forcePull bool) error {
func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
var err error

// A mutex with TryLock would be a bit simpler
Expand Down Expand Up @@ -655,7 +655,7 @@

log.Infof("Starting community-blocklist update")

data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup})
data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup})
if err != nil {
return fmt.Errorf("get stream: %w", err)
}
Expand Down Expand Up @@ -700,17 +700,17 @@
}

// update blocklists
if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil {
if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil {
return fmt.Errorf("while updating blocklists: %w", err)
}

return nil
}

// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error {
func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error {
addCounters, _ := makeAddAndDeleteCounters()
if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{
if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{
Blocklists: []*modelscapi.BlocklistLink{blocklist},
}, addCounters, forcePull); err != nil {
return fmt.Errorf("while pulling blocklist: %w", err)
Expand Down Expand Up @@ -820,7 +820,7 @@
return false, nil
}

func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
if blocklist.Scope == nil {
log.Warningf("blocklist has no scope")
return nil
Expand Down Expand Up @@ -848,13 +848,13 @@
)

if !forcePull {
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName)

Check warning on line 851 in pkg/apiserver/apic.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/apic.go#L851

Added line #L851 was not covered by tests
if err != nil {
return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
}

decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(ctx, blocklist, lastPullTimestamp)
if err != nil {
return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err)
}
Expand All @@ -869,7 +869,7 @@
return nil
}

err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
if err != nil {
return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
Expand All @@ -892,7 +892,7 @@
return nil
}

func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
if links == nil {
return nil
}
Expand All @@ -908,7 +908,7 @@
}

for _, blocklist := range links.Blocklists {
if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil {
if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil {
return err
}
}
Expand All @@ -931,7 +931,7 @@
}
}

func (a *apic) Pull() error {
func (a *apic) Pull(ctx context.Context) error {
defer trace.CatchPanic("lapi/pullFromAPIC")

toldOnce := false
Expand All @@ -955,7 +955,7 @@
time.Sleep(1 * time.Second)
}

if err := a.PullTop(false); err != nil {
if err := a.PullTop(ctx, false); err != nil {
log.Errorf("capi pull top: %s", err)
}

Expand All @@ -967,7 +967,7 @@
case <-ticker.C:
ticker.Reset(a.pullInterval)

if err := a.PullTop(false); err != nil {
if err := a.PullTop(ctx, false); err != nil {
log.Errorf("capi pull top: %s", err)
continue
}
Expand Down
24 changes: 15 additions & 9 deletions pkg/apiserver/apic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
}

func TestAPICWhitelists(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
// one whitelist on IP, one on CIDR
api.whitelists = &csconfig.CapiWhitelist{}
Expand Down Expand Up @@ -685,7 +686,7 @@ func TestAPICWhitelists(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)

assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing
Expand Down Expand Up @@ -736,6 +737,7 @@ func TestAPICWhitelists(t *testing.T) {
}

func TestAPICPullTop(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
api.dbClient.Ent.Decision.Create().
SetOrigin(types.CAPIOrigin).
Expand Down Expand Up @@ -826,7 +828,7 @@ func TestAPICPullTop(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)

assertTotalDecisionCount(t, api.dbClient, 5)
Expand Down Expand Up @@ -860,6 +862,7 @@ func TestAPICPullTop(t *testing.T) {
}

func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
ctx := context.Background()
// no decision in db, no last modified parameter.
api := getAPIC(t)

Expand Down Expand Up @@ -913,11 +916,11 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)

blocklistConfigItemName := "blocklist:blocklist1:last_pull"
lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
require.NoError(t, err)
assert.NotEqual(t, "", *lastPullTimestamp)

Expand All @@ -927,14 +930,15 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
return httpmock.NewStringResponse(304, ""), nil
})

err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
require.NoError(t, err)
assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp)
}

func TestAPICPullTopBLCacheForceCall(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)

httpmock.Activate()
Expand Down Expand Up @@ -1005,11 +1009,12 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)
}

func TestAPICPullBlocklistCall(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)

httpmock.Activate()
Expand All @@ -1032,7 +1037,7 @@ func TestAPICPullBlocklistCall(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullBlocklist(&modelscapi.BlocklistLink{
err = api.PullBlocklist(ctx, &modelscapi.BlocklistLink{
URL: ptr.Of("http://api.crowdsec.net/blocklist1"),
Name: ptr.Of("blocklist1"),
Scope: ptr.Of("Ip"),
Expand Down Expand Up @@ -1134,6 +1139,7 @@ func TestAPICPush(t *testing.T) {
}

func TestAPICPull(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
tests := []struct {
name string
Expand Down Expand Up @@ -1204,7 +1210,7 @@ func TestAPICPull(t *testing.T) {
go func() {
logrus.SetOutput(&buf)

if err := api.Pull(); err != nil {
if err := api.Pull(ctx); err != nil {
panic(err)
}
}()
Expand Down
18 changes: 10 additions & 8 deletions pkg/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,17 +310,17 @@
return nil
}

func (s *APIServer) apicPull() error {
if err := s.apic.Pull(); err != nil {
func (s *APIServer) apicPull(ctx context.Context) error {
if err := s.apic.Pull(ctx); err != nil {
log.Errorf("capi pull: %s", err)
return err
}

return nil
}

func (s *APIServer) papiPull() error {
if err := s.papi.Pull(); err != nil {
func (s *APIServer) papiPull(ctx context.Context) error {
if err := s.papi.Pull(ctx); err != nil {

Check warning on line 323 in pkg/apiserver/apiserver.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/apiserver.go#L322-L323

Added lines #L322 - L323 were not covered by tests
log.Errorf("papi pull: %s", err)
return err
}
Expand All @@ -337,16 +337,16 @@
return nil
}

func (s *APIServer) initAPIC() {
func (s *APIServer) initAPIC(ctx context.Context) {
s.apic.pushTomb.Go(s.apicPush)
s.apic.pullTomb.Go(s.apicPull)
s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) })

// csConfig.API.Server.ConsoleConfig.ShareCustomScenarios
if s.apic.apiClient.IsEnrolled() {
if s.consoleConfig.IsPAPIEnabled() {
if s.papi.URL != "" {
log.Info("Starting PAPI decision receiver")
s.papi.pullTomb.Go(s.papiPull)
s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) })

Check warning on line 349 in pkg/apiserver/apiserver.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/apiserver.go#L349

Added line #L349 was not covered by tests
s.papi.syncTomb.Go(s.papiSync)
} else {
log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.")
Expand Down Expand Up @@ -381,8 +381,10 @@
TLSConfig: tlsCfg,
}

ctx := context.TODO()

Check warning on line 385 in pkg/apiserver/apiserver.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/apiserver.go#L385

Added line #L385 was not covered by tests
if s.apic != nil {
s.initAPIC()
s.initAPIC(ctx)
}

s.httpServerTomb.Go(func() error {
Expand Down
8 changes: 4 additions & 4 deletions pkg/apiserver/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,13 @@
}

// PullPAPI is the long polling client for real-time decisions from PAPI
func (p *Papi) Pull() error {
func (p *Papi) Pull(ctx context.Context) error {

Check warning on line 233 in pkg/apiserver/papi.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi.go#L233

Added line #L233 was not covered by tests
defer trace.CatchPanic("lapi/PullPAPI")
p.Logger.Infof("Starting Polling API Pull")

lastTimestamp := time.Time{}

lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey)
lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey)

Check warning on line 239 in pkg/apiserver/papi.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi.go#L239

Added line #L239 was not covered by tests
if err != nil {
p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err)
}
Expand All @@ -248,7 +248,7 @@
return fmt.Errorf("failed to serialize last timestamp: %w", err)
}

if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil {

Check warning on line 251 in pkg/apiserver/papi.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi.go#L251

Added line #L251 was not covered by tests
p.Logger.Errorf("error setting papi pull last key: %s", err)
} else {
p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime))
Expand Down Expand Up @@ -277,7 +277,7 @@
continue
}

if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil {

Check warning on line 280 in pkg/apiserver/papi.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi.go#L280

Added line #L280 was not covered by tests
return fmt.Errorf("failed to update last timestamp: %w", err)
}

Expand Down
7 changes: 5 additions & 2 deletions pkg/apiserver/papi_cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apiserver

import (
"context"
"encoding/json"
"fmt"
"time"
Expand Down Expand Up @@ -215,17 +216,19 @@
return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
}

ctx := context.TODO()

Check warning on line 220 in pkg/apiserver/papi_cmd.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi_cmd.go#L219-L220

Added lines #L219 - L220 were not covered by tests
if forcePullMsg.Blocklist == nil {
p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists")

err = p.apic.PullTop(true)
err = p.apic.PullTop(ctx, true)

Check warning on line 224 in pkg/apiserver/papi_cmd.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi_cmd.go#L224

Added line #L224 was not covered by tests
if err != nil {
return fmt.Errorf("failed to force pull operation: %w", err)
}
} else {
p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name)

err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{
err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{

Check warning on line 231 in pkg/apiserver/papi_cmd.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi_cmd.go#L231

Added line #L231 was not covered by tests
Name: &forcePullMsg.Blocklist.Name,
URL: &forcePullMsg.Blocklist.Url,
Remediation: &forcePullMsg.Blocklist.Remediation,
Expand Down
12 changes: 6 additions & 6 deletions pkg/database/config.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package database

import (
"context"
"github.com/pkg/errors"

"github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
)

func (c *Client) GetConfigItem(key string) (*string, error) {
result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(c.CTX)
func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) {
result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx)
if err != nil && ent.IsNotFound(err) {
return nil, nil
}
Expand All @@ -19,11 +20,10 @@ func (c *Client) GetConfigItem(key string) (*string, error) {
return &result.Value, nil
}

func (c *Client) SetConfigItem(key string, value string) error {

nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(c.CTX)
func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error {
nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx)
if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { //not found, create
err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(c.CTX)
err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx)
if err != nil {
return errors.Wrapf(QueryFail, "insert config item: %s", err)
}
Expand Down