diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 91a0a8273f7..e5821e4c1e2 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -38,7 +38,7 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, } for _, bouncer := range bouncers { - dbMetrics, err := a.dbClient.GetBouncerUsageMetricsByName(bouncer.Name) + dbMetrics, err := a.dbClient.GetBouncerUsageMetricsByName(ctx, bouncer.Name) if err != nil { log.Errorf("unable to get bouncer usage metrics: %s", err) continue @@ -81,7 +81,7 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, } for _, lp := range lps { - dbMetrics, err := a.dbClient.GetLPUsageMetricsByMachineID(lp.MachineId) + dbMetrics, err := a.dbClient.GetLPUsageMetricsByMachineID(ctx, lp.MachineId) if err != nil { log.Errorf("unable to get LP usage metrics: %s", err) continue @@ -181,8 +181,8 @@ func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, return allMetrics, metricsIds, nil } -func (a *apic) MarkUsageMetricsAsSent(ids []int) error { - return a.dbClient.MarkUsageMetricsAsSent(ids) +func (a *apic) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { + return a.dbClient.MarkUsageMetricsAsSent(ctx, ids) } func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { @@ -379,7 +379,7 @@ func (a *apic) SendUsageMetrics() { } } - err = a.MarkUsageMetricsAsSent(metricsId) + err = a.MarkUsageMetricsAsSent(ctx, metricsId) if err != nil { log.Errorf("unable to mark usage metrics as sent: %s", err) continue diff --git a/pkg/apiserver/usage_metrics_test.go b/pkg/apiserver/usage_metrics_test.go index 41dd0ccdc2c..019de5fb970 100644 --- a/pkg/apiserver/usage_metrics_test.go +++ b/pkg/apiserver/usage_metrics_test.go @@ -13,6 +13,8 @@ import ( ) func TestLPMetrics(t *testing.T) { + ctx := context.Background() + tests := []struct { name string body string @@ -198,7 +200,7 @@ func TestLPMetrics(t *testing.T) { assert.Contains(t, w.Body.String(), tt.expectedResponse) machine, _ := dbClient.QueryMachineByID("test") - metrics, _ := dbClient.GetLPUsageMetricsByMachineID("test") + metrics, _ := dbClient.GetLPUsageMetricsByMachineID(ctx, "test") assert.Len(t, metrics, tt.expectedMetricsCount) assert.Equal(t, tt.expectedOSName, machine.Osname) @@ -214,6 +216,8 @@ func TestLPMetrics(t *testing.T) { } func TestRCMetrics(t *testing.T) { + ctx := context.Background() + tests := []struct { name string body string @@ -368,7 +372,7 @@ func TestRCMetrics(t *testing.T) { assert.Contains(t, w.Body.String(), tt.expectedResponse) bouncer, _ := dbClient.SelectBouncerByName("test") - metrics, _ := dbClient.GetBouncerUsageMetricsByName("test") + metrics, _ := dbClient.GetBouncerUsageMetricsByName(ctx, "test") assert.Len(t, metrics, tt.expectedMetricsCount) assert.Equal(t, tt.expectedOSName, bouncer.Osname) diff --git a/pkg/database/metrics.go b/pkg/database/metrics.go index 1619fcc923b..99ba90c80b8 100644 --- a/pkg/database/metrics.go +++ b/pkg/database/metrics.go @@ -25,14 +25,14 @@ func (c *Client) CreateMetric(ctx context.Context, generatedType metric.Generate return metric, nil } -func (c *Client) GetLPUsageMetricsByMachineID(machineId string) ([]*ent.Metric, error) { +func (c *Client) GetLPUsageMetricsByMachineID(ctx context.Context, machineId string) ([]*ent.Metric, error) { metrics, err := c.Ent.Metric.Query(). Where( metric.GeneratedTypeEQ(metric.GeneratedTypeLP), metric.GeneratedByEQ(machineId), metric.PushedAtIsNil(), ). - All(c.CTX) + All(ctx) if err != nil { c.Log.Warningf("GetLPUsageMetricsByOrigin: %s", err) return nil, fmt.Errorf("getting LP usage metrics by origin %s: %w", machineId, err) @@ -41,14 +41,14 @@ func (c *Client) GetLPUsageMetricsByMachineID(machineId string) ([]*ent.Metric, return metrics, nil } -func (c *Client) GetBouncerUsageMetricsByName(bouncerName string) ([]*ent.Metric, error) { +func (c *Client) GetBouncerUsageMetricsByName(ctx context.Context, bouncerName string) ([]*ent.Metric, error) { metrics, err := c.Ent.Metric.Query(). Where( metric.GeneratedTypeEQ(metric.GeneratedTypeRC), metric.GeneratedByEQ(bouncerName), metric.PushedAtIsNil(), ). - All(c.CTX) + All(ctx) if err != nil { c.Log.Warningf("GetBouncerUsageMetricsByName: %s", err) return nil, fmt.Errorf("getting bouncer usage metrics by name %s: %w", bouncerName, err) @@ -57,11 +57,11 @@ func (c *Client) GetBouncerUsageMetricsByName(bouncerName string) ([]*ent.Metric return metrics, nil } -func (c *Client) MarkUsageMetricsAsSent(ids []int) error { +func (c *Client) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { _, err := c.Ent.Metric.Update(). Where(metric.IDIn(ids...)). SetPushedAt(time.Now().UTC()). - Save(c.CTX) + Save(ctx) if err != nil { c.Log.Warningf("MarkUsageMetricsAsSent: %s", err) return fmt.Errorf("marking usage metrics as sent: %w", err)