From de1346d4601fabe71316b91287fb39fe38962765 Mon Sep 17 00:00:00 2001 From: marco Date: Fri, 13 Sep 2024 10:30:53 +0200 Subject: [PATCH] context propagation: pass context to NewAPIC() --- cmd/crowdsec-cli/clipapi/papi.go | 12 +++++++----- pkg/apiserver/apic.go | 4 ++-- pkg/apiserver/apic_test.go | 4 +++- pkg/apiserver/apiserver.go | 2 +- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go index 747b8c01b9b..c0f08157f31 100644 --- a/cmd/crowdsec-cli/clipapi/papi.go +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -59,7 +59,7 @@ func (cli *cliPapi) NewCommand() *cobra.Command { func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Client) error { cfg := cli.cfg() - apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) if err != nil { return fmt.Errorf("unable to initialize API client: %w", err) } @@ -118,11 +118,11 @@ func (cli *cliPapi) newStatusCmd() *cobra.Command { return cmd } -func (cli *cliPapi) sync(out io.Writer, db *database.Client) error { +func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client) error { cfg := cli.cfg() t := tomb.Tomb{} - apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) if err != nil { return fmt.Errorf("unable to initialize API client: %w", err) } @@ -159,12 +159,14 @@ func (cli *cliPapi) newSyncCmd() *cobra.Command { DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { cfg := cli.cfg() - db, err := require.DBClient(cmd.Context(), cfg.DbConfig) + ctx := cmd.Context() + + db, err := require.DBClient(ctx, cfg.DbConfig) if err != nil { return err } - return cli.sync(color.Output, db) + return cli.sync(ctx, color.Output, db) }, } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 73061637ad9..3ed2e12ea54 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -174,7 +174,7 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool) return signal } -func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { +func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { var err error ret := &apic{ @@ -237,7 +237,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con return ret, fmt.Errorf("get scenario in db: %w", err) } - authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{ MachineID: &config.Credentials.Login, Password: &password, Scenarios: scenarios, diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 51887006ad4..328d5c4ae09 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -230,6 +230,8 @@ func TestNewAPIC(t *testing.T) { }, } + ctx := context.Background() + for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { setConfig() @@ -246,7 +248,7 @@ func TestNewAPIC(t *testing.T) { ), )) tc.action() - _, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) + _, err := NewAPIC(ctx, testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 42dcb219379..8bf406e0a79 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -249,7 +249,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { if config.OnlineClient != nil && config.OnlineClient.Credentials != nil { log.Printf("Loading CAPI manager") - apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) + apiClient, err = NewAPIC(ctx, config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) if err != nil { return nil, err }