Skip to content

Commit

Permalink
feat: redirect to OIDC providers only once in registration flows
Browse files Browse the repository at this point in the history
test(e2e): ensure there is only one OIDC redirect

Co-authored-by: Jakub Fijałkowski <[email protected]>
  • Loading branch information
2 people authored and David-Wobrock committed Sep 13, 2024
1 parent 2c7ff3c commit 04d1348
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
38 changes: 38 additions & 0 deletions selfservice/strategy/oidc/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config{

type MetadataType string

type OIDCProviderData struct {
Provider string `json:"provider"`
Tokens *identity.CredentialsOIDCEncryptedTokens `json:"tokens"`
Claims Claims `json:"claims"`
}

type VerifiedAddress struct {
Value string `json:"value"`
Via identity.VerifiableAddressType `json:"via"`
Expand All @@ -53,6 +59,8 @@ const (

PublicMetadata MetadataType = "identity.metadata_public"
AdminMetadata MetadataType = "identity.metadata_admin"

InternalContextKeyProviderData = "provider_data"
)

func (s *Strategy) RegisterRegistrationRoutes(r *x.RouterPublic) {
Expand Down Expand Up @@ -216,6 +224,25 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
return errors.WithStack(flow.ErrCompletedByStrategy)
}

if oidcProviderData := gjson.GetBytes(f.InternalContext, flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData)); oidcProviderData.IsObject() {
var providerData OIDCProviderData
if err := json.Unmarshal([]byte(oidcProviderData.Raw), &providerData); err != nil {
return s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to be an object but got: %s", err)))
}
if pid != providerData.Provider {
return s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to have matching provider but got: %s", providerData.Provider)))
}
_, err = s.processRegistration(ctx, w, r, f, providerData.Tokens, &providerData.Claims, provider, &AuthCodeContainer{
FlowID: f.ID.String(),
Traits: p.Traits,
TransientPayload: f.TransientPayload,
})
if err != nil {
return s.handleError(ctx, w, r, f, pid, nil, err)
}
return errors.WithStack(flow.ErrCompletedByStrategy)
}

state, pkce, err := s.GenerateState(ctx, provider, f.ID)
if err != nil {
return s.handleError(ctx, w, r, f, pid, nil, err)
Expand Down Expand Up @@ -312,6 +339,13 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite
return nil, nil
}

providerDataKey := flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData)
if hasOIDCProviderData := gjson.GetBytes(rf.InternalContext, providerDataKey).IsObject(); !hasOIDCProviderData {
if internalContext, err := sjson.SetBytes(rf.InternalContext, providerDataKey, &OIDCProviderData{Provider: provider.Config().ID, Tokens: token, Claims: *claims}); err == nil {
rf.InternalContext = internalContext
}
}

fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(ctx)), fetcher.WithCache(jsonnetCache, 60*time.Minute))
jsonnetMapperSnippet, err := fetch.FetchContext(ctx, provider.Config().Mapper)
if err != nil {
Expand Down Expand Up @@ -350,6 +384,10 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite
return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, i.Traits, err)
}

if internalContext, err := sjson.DeleteBytes(rf.InternalContext, providerDataKey); err == nil {
rf.InternalContext = internalContext
}

return nil, nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,52 @@ context("Social Sign Up Successes", () => {
})
})

it("should redirect to oidc provider only once", () => {
const email = gen.email()

cy.registerOidc({
app,
email,
expectSession: false,
route: registration,
})

cy.get(appPrefix(app) + '[name="traits.email"]').should(
"have.value",
email,
)

cy.get('[name="traits.consent"][type="checkbox"]')
.siblings("label")
.click()
cy.get('[name="traits.newsletter"][type="checkbox"]')
.siblings("label")
.click()
cy.get('[name="traits.website"]').type(website)

cy.intercept("GET", "http://*/oauth2/auth*", {
forceNetworkError: true,
}).as("additionalRedirect")

cy.triggerOidc(app)

cy.get("@additionalRedirect").should("not.exist")

cy.location("pathname").should((loc) => {
expect(loc).to.be.oneOf([
"/welcome",
"/",
"/sessions",
"/verification",
])
})

cy.getSession().should((session) => {
shouldSession(email)(session)
expect(session.identity.traits.consent).to.equal(true)
})
})

it("should pass transient_payload to webhook", () => {
testFlowWebhook(
(hooks) =>
Expand Down

0 comments on commit 04d1348

Please sign in to comment.