From 7464429bc8ccc5d920b7533ce5c7f268715cc34d Mon Sep 17 00:00:00 2001 From: David Orchard Date: Thu, 19 Sep 2024 15:34:30 -0700 Subject: [PATCH] add unit tests --- core/capabilities/webapi/trigger.go | 25 ++- core/capabilities/webapi/trigger_test.go | 150 ++++++++++++++++++ .../gateway/web_api_trigger/invoke_trigger.go | 22 +-- 3 files changed, 172 insertions(+), 25 deletions(-) create mode 100644 core/capabilities/webapi/trigger_test.go diff --git a/core/capabilities/webapi/trigger.go b/core/capabilities/webapi/trigger.go index 07812fb533..e9247ff3b7 100644 --- a/core/capabilities/webapi/trigger.go +++ b/core/capabilities/webapi/trigger.go @@ -19,7 +19,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/workflow" - "github.com/smartcontractkit/chainlink/v2/core/services/job" ) const defaultSendChannelBufferSize = 1000 @@ -50,6 +49,7 @@ type triggerConnectorHandler struct { mu sync.Mutex // Will this have to get pulled into a store to have the topic and workflow ID? registeredWorkflows map[string]chan capabilities.TriggerResponse + allowedSendersMap map[string]bool signerKey *ecdsa.PrivateKey rateLimiter *common.RateLimiter } @@ -61,7 +61,7 @@ var _ services.Service = &triggerConnectorHandler{} // Once connected to a Gateway, each connector handler periodically sends metadata messages containing aggregated // config for all registered workflow specs using web-trigger. -func NewTrigger(config TriggerConfig, registry core.CapabilitiesRegistry, connector connector.GatewayConnector, signerKey *ecdsa.PrivateKey, lggr logger.Logger) (job.ServiceCtx, error) { +func NewTrigger(config TriggerConfig, registry core.CapabilitiesRegistry, connector connector.GatewayConnector, signerKey *ecdsa.PrivateKey, lggr logger.Logger) (*triggerConnectorHandler, error) { // TODO (CAPPL-22, CAPPL-24): // - decode config // - create an implementation of the capability API and add it to the Registry @@ -73,13 +73,18 @@ func NewTrigger(config TriggerConfig, registry core.CapabilitiesRegistry, connec if err != nil { return nil, err } + allowedSendersMap := map[string]bool{} + for _, k := range config.AllowedSenders { + allowedSendersMap[k.String()] = true + } handler := &triggerConnectorHandler{ - config: config, - connector: connector, - signerKey: signerKey, - rateLimiter: rateLimiter, - lggr: lggr.Named("WorkflowConnectorHandler"), + allowedSendersMap: allowedSendersMap, + config: config, + connector: connector, + signerKey: signerKey, + rateLimiter: rateLimiter, + lggr: lggr.Named("WorkflowConnectorHandler"), } return handler, nil @@ -131,7 +136,10 @@ func (h *triggerConnectorHandler) HandleGatewayMessage(ctx context.Context, gate h.lggr.Errorw("request rate-limited") return } - // TODO: apply allowlist + if !h.allowedSendersMap[sender.String()] { + h.lggr.Errorw("Unauthorized Sender") + return + } h.lggr.Debugw("handling gateway request", "id", gatewayID, "method", body.Method, "sender", sender) var payload TriggerRequestPayload err := json.Unmarshal(body.Payload, &payload) @@ -142,6 +150,7 @@ func (h *triggerConnectorHandler) HandleGatewayMessage(ctx context.Context, gate switch body.Method { case workflow.MethodWebAPITrigger: h.lggr.Debugw("added MethodWebAPITrigger message", "payload", string(body.Payload)) + // TODO: Is the staleness check supposed to be in the gateway? currentTime := time.Now() // TODO: check against h.config.MaxAllowedMessageAgeSec if currentTime.Unix()-3000 > payload.Timestamp { diff --git a/core/capabilities/webapi/trigger_test.go b/core/capabilities/webapi/trigger_test.go new file mode 100644 index 0000000000..87222c1913 --- /dev/null +++ b/core/capabilities/webapi/trigger_test.go @@ -0,0 +1,150 @@ +package webapi + +import ( + "encoding/json" + "flag" + "testing" + + ethCommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + registrymock "github.com/smartcontractkit/chainlink-common/pkg/types/core/mocks" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" + corelogger "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" + gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" +) + +const ( + workflowID1 = "15c631d295ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0" + workflowExecutionID1 = "95ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0abbadeed" + owner1 = "0x00000000000000000000000000000000000000aa" +) + +type testHarness struct { + registry *registrymock.CapabilitiesRegistry + connector *gcmocks.GatewayConnector + lggr logger.Logger + config TriggerConfig + trigger *triggerConnectorHandler +} + +func setup(t *testing.T) testHarness { + privateKey, _ := testutils.NewPrivateKeyAndAddress(t) + registry := registrymock.NewCapabilitiesRegistry(t) + connector := gcmocks.NewGatewayConnector(t) + lggr := corelogger.TestLogger(t) + config := TriggerConfig{ + RateLimiter: common.RateLimiterConfig{ + GlobalRPS: 100.0, + GlobalBurst: 100, + PerSenderRPS: 100.0, + PerSenderBurst: 100, + }, + AllowedSenders: []ethCommon.Address{ethCommon.HexToAddress("a")}, + } + trigger, err := NewTrigger(config, registry, connector, privateKey, lggr) + require.NoError(t, err) + + return testHarness{ + registry: registry, + connector: connector, + lggr: lggr, + config: config, + trigger: trigger, + } +} + +func gatewayRequest(t *testing.T) *api.Message { + // TODO: are flags like this ok? this is how the upload_workflow test script does it + privateKey := flag.String("private_key", "65456ffb8af4a2b93959256a8e04f6f2fe0943579fb3c9c3350593aabb89023f", "Private key to sign the message with") + messageID := flag.String("id", "12345", "Request ID") + methodName := flag.String("method", "web_trigger", "Method name") + donID := flag.String("don_id", "workflow_don_1", "DON ID") + + flag.Parse() + key, err := crypto.HexToECDSA(*privateKey) + require.NoError(t, err) + + payload := `{ + trigger_id: "web-trigger@1.0.0", + trigger_event_id: "action_1234567890", + timestamp: 1234567890, + topics: ["daily_price_update"], + params: { + bid: "101", + ask: "102" + } + } +` + payloadJSON := []byte(payload) + msg := &api.Message{ + Body: api.MessageBody{ + MessageId: *messageID, + Method: *methodName, + DonId: *donID, + Payload: json.RawMessage(payloadJSON), + }, + } + err = msg.Sign(key) + require.NoError(t, err) + + return msg +} + +func TestCapability_Execute(t *testing.T) { + th := setup(t) + ctx := testutils.Context(t) + + t.Run("happy case", func(t *testing.T) { + triggerReq := capabilities.TriggerRegistrationRequest{ + Metadata: capabilities.RequestMetadata{ + WorkflowID: workflowID1, + WorkflowOwner: owner1, + }, + } + _, err := th.trigger.RegisterTrigger(ctx, triggerReq) + require.NoError(t, err) + + gatewayRequest := gatewayRequest(t) + + th.connector.On("SendToGateway", mock.Anything, mock.Anything).Return(nil).Once() + + // TODO: verify SendToGateway called + th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) + + // TODO: verify message sent to trigger channel + }) + + // TODO: allowedSenders fail + // TODO: rateLimit fail + // TODO: empty allowedSenders + // TODO: missing required parameters + // TODO: invalid message + // TODO: other edge cases? empty topics? + // TODO: Test duplicate messages, ie PENDING returned. + // TODO: Test message sent to multiple trigger channels +} + +func TestRegisterUnregister(t *testing.T) { + th := setup(t) + ctx := testutils.Context(t) + + triggerReq := capabilities.TriggerRegistrationRequest{ + Metadata: capabilities.RequestMetadata{ + WorkflowID: workflowID1, + WorkflowOwner: owner1, + }, + } + _, err := th.trigger.RegisterTrigger(ctx, triggerReq) + require.NoError(t, err) + + err = th.trigger.UnregisterTrigger(ctx, triggerReq) + require.NoError(t, err) +} diff --git a/core/scripts/gateway/web_api_trigger/invoke_trigger.go b/core/scripts/gateway/web_api_trigger/invoke_trigger.go index 3eeeca00d9..a1ff0a1ce0 100644 --- a/core/scripts/gateway/web_api_trigger/invoke_trigger.go +++ b/core/scripts/gateway/web_api_trigger/invoke_trigger.go @@ -55,8 +55,6 @@ func main() { messageID := flag.String("id", "12345", "Request ID") methodName := flag.String("method", "web_trigger", "Method name") donID := flag.String("don_id", "workflow_don_1", "DON ID") - // workflowSpec := flag.String("workflow_spec", "[my spec abcd]", "Workflow Spec") - // payloadJSON := []byte("{\"spec\": \"" + *workflowSpec + "\"}") flag.Parse() @@ -81,21 +79,11 @@ func main() { trigger_id: "web-trigger@1.0.0", trigger_event_id: "action_1234567890", timestamp: 1234567890, - sub-events: [ - { - topics: ["daily_price_update"], - params: { - bid: "101", - ask: "102" - } - }, - { - topics: ["daily_message", "summary"], - params: { - message: "all good!", - } - }, - ] + topics: ["daily_price_update"], + params: { + bid: "101", + ask: "102" + } } ` payloadJSON := []byte(payload)