Skip to content

Commit

Permalink
fix: broken JSON round-tripping for custom claims
Browse files Browse the repository at this point in the history
Adding custom claims with numerical types (think JavaScript Number) previously did not
round-trip through Hydra correctly. For example, passing UNIX timestamps in custom claims
would end up as floating points in exponential notation in the final token. That, in turn,
confused or broke downstream consumers of the token, including Kratos.

Ref go-jose/go-jose#144
  • Loading branch information
alnr committed Aug 13, 2024
1 parent 23c7464 commit a28aed9
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 32 deletions.
42 changes: 29 additions & 13 deletions consent/strategy_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ func TestStrategyLoginConsentNext(t *testing.T) {

subject := "aeneas-rekkas"
c := createDefaultClient(t)
now := 1723546027 // Unix timestamps must round-trip through Hydra without converting to floats or similar
testhelpers.NewLoginConsentUI(t, reg.Config(),
acceptLoginHandler(t, subject, &hydra.AcceptOAuth2LoginRequest{
Remember: pointerx.Bool(true),
Expand All @@ -297,8 +298,14 @@ func TestStrategyLoginConsentNext(t *testing.T) {
Remember: pointerx.Bool(true),
GrantScope: []string{"openid"},
Session: &hydra.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]interface{}{"foo": "bar"},
IdToken: map[string]interface{}{"bar": "baz"},
AccessToken: map[string]interface{}{
"foo": "bar",
"ts1": now,
},
IdToken: map[string]interface{}{
"bar": "baz",
"ts2": now,
},
},
}))

Expand All @@ -314,12 +321,14 @@ func TestStrategyLoginConsentNext(t *testing.T) {
require.NoError(t, err)

claims := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS)
assert.Equal(t, "bar", claims.Get("ext.foo").String(), "%s", claims.Raw)
assert.Equalf(t, `"bar"`, claims.Get("ext.foo").Raw, "%s", claims.Raw) // Raw rather than .Int() or .Value() to verify the exact JSON payload
assert.Equalf(t, "1723546027", claims.Get("ext.ts1").Raw, "%s", claims.Raw) // must round-trip as integer

idClaims := testhelpers.DecodeIDToken(t, token)
assert.Equal(t, "baz", idClaims.Get("bar").String(), "%s", idClaims.Raw)
assert.Equalf(t, `"baz"`, idClaims.Get("bar").Raw, "%s", idClaims.Raw) // Raw rather than .Int() or .Value() to verify the exact JSON payload
assert.Equalf(t, "1723546027", idClaims.Get("ts2").Raw, "%s", idClaims.Raw) // must round-trip as integer
sid = idClaims.Get("sid").String()
assert.NotNil(t, sid)
assert.NotEmpty(t, sid)
}

t.Run("perform first flow", run)
Expand All @@ -334,21 +343,28 @@ func TestStrategyLoginConsentNext(t *testing.T) {
assert.Empty(t, pointerx.StringR(res.Client.ClientSecret))
return hydra.AcceptOAuth2LoginRequest{
Subject: subject,
Context: map[string]interface{}{"foo": "bar"},
Context: map[string]interface{}{"xyz": "abc"},
}
}),
checkAndAcceptConsentHandler(t, adminClient, func(t *testing.T, res *hydra.OAuth2ConsentRequest, err error) hydra.AcceptOAuth2ConsentRequest {
checkAndAcceptConsentHandler(t, adminClient, func(t *testing.T, req *hydra.OAuth2ConsentRequest, err error) hydra.AcceptOAuth2ConsentRequest {
require.NoError(t, err)
assert.True(t, *res.Skip)
assert.Equal(t, sid, *res.LoginSessionId)
assert.Equal(t, subject, *res.Subject)
assert.Empty(t, pointerx.StringR(res.Client.ClientSecret))
assert.True(t, *req.Skip)
assert.Equal(t, sid, *req.LoginSessionId)
assert.Equal(t, subject, *req.Subject)
assert.Empty(t, pointerx.StringR(req.Client.ClientSecret))
assert.Equal(t, map[string]interface{}{"xyz": "abc"}, req.Context)
return hydra.AcceptOAuth2ConsentRequest{
Remember: pointerx.Bool(true),
GrantScope: []string{"openid"},
Session: &hydra.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]interface{}{"foo": "bar"},
IdToken: map[string]interface{}{"bar": "baz"},
AccessToken: map[string]interface{}{
"foo": "bar",
"ts1": now,
},
IdToken: map[string]interface{}{
"bar": "baz",
"ts2": now,
},
},
}
}))
Expand Down
3 changes: 2 additions & 1 deletion oauth2/.snapshots/TestUnmarshalSession-v1.11.8.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"amr": [],
"c_hash": "",
"ext": {
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d"
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"timestamp": 1723546027
}
},
"headers": {
Expand Down
3 changes: 2 additions & 1 deletion oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"amr": [],
"c_hash": "",
"ext": {
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d"
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"timestamp": 1723546027
}
},
"headers": {
Expand Down
3 changes: 2 additions & 1 deletion oauth2/fixtures/v1.11.8-session.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"AuthenticationMethodsReferences": [],
"CodeHash": "",
"Extra": {
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d"
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"timestamp": 1723546027
}
},
"Headers": {
Expand Down
3 changes: 2 additions & 1 deletion oauth2/fixtures/v1.11.9-session.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"amr": [],
"c_hash": "",
"ext": {
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d"
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"timestamp": 1723546027
}
},
"headers": {
Expand Down
28 changes: 15 additions & 13 deletions oauth2/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,21 @@
package oauth2

import (
"bytes"
"context"
"encoding/json"
"time"

jjson "github.com/go-jose/go-jose/v3/json"
"github.com/mohae/deepcopy"
"github.com/pkg/errors"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

"github.com/mohae/deepcopy"

"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/flow"

"github.com/ory/x/logrusx"
"github.com/ory/x/stringslice"
)
Expand Down Expand Up @@ -60,33 +59,33 @@ func NewSessionWithCustomClaims(ctx context.Context, p *config.DefaultProvider,
}

func (s *Session) GetJWTClaims() jwt.JWTClaimsContainer {
//a slice of claims that are reserved and should not be overridden
var reservedClaims = []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti", "client_id", "scp", "ext"}
// a slice of claims that are reserved and should not be overridden
reservedClaims := []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti", "client_id", "scp", "ext"}

//remove any reserved claims from the custom claims
// remove any reserved claims from the custom claims
allowedClaimsFromConfigWithoutReserved := stringslice.Filter(s.AllowedTopLevelClaims, func(s string) bool {
return stringslice.Has(reservedClaims, s)
})

//our new extra map which will be added to the jwt
var topLevelExtraWithMirrorExt = map[string]interface{}{}
// our new extra map which will be added to the jwt
topLevelExtraWithMirrorExt := map[string]interface{}{}

//setting every allowed claim top level in jwt with respective value
// setting every allowed claim top level in jwt with respective value
for _, allowedClaim := range allowedClaimsFromConfigWithoutReserved {
if cl, ok := s.Extra[allowedClaim]; ok {
topLevelExtraWithMirrorExt[allowedClaim] = cl
}
}

//for every other claim that was already reserved and for mirroring, add original extra under "ext"
// for every other claim that was already reserved and for mirroring, add original extra under "ext"
if s.MirrorTopLevelClaims {
topLevelExtraWithMirrorExt["ext"] = s.Extra
}

claims := &jwt.JWTClaims{
Subject: s.Subject,
Issuer: s.DefaultSession.Claims.Issuer,
//set our custom extra map as claims.Extra
// set our custom extra map as claims.Extra
Extra: topLevelExtraWithMirrorExt,
ExpiresAt: s.GetExpiresAt(fosite.AccessToken),
IssuedAt: time.Now(),
Expand Down Expand Up @@ -185,8 +184,11 @@ func (s *Session) UnmarshalJSON(original []byte) (err error) {
}
}

// https://github.com/go-jose/go-jose/issues/144
dec := jjson.NewDecoder(bytes.NewReader(transformed))
dec.SetNumberType(jjson.UnmarshalIntOrFloat)
type t Session
if err := json.Unmarshal(transformed, (*t)(s)); err != nil {
if err := dec.Decode((*t)(s)); err != nil {
return errors.WithStack(err)
}

Expand Down
5 changes: 3 additions & 2 deletions oauth2/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func TestUnmarshalSession(t *testing.T) {
AuthenticationMethodsReferences: []string{},
CodeHash: "",
Extra: map[string]interface{}{
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"timestamp": 1723546027,
},
},
Headers: &jwt.Headers{Extra: map[string]interface{}{
Expand Down Expand Up @@ -85,7 +86,7 @@ func TestUnmarshalSession(t *testing.T) {
snapshotx.SnapshotTExcept(t, &actual, nil)
})

t.Run("v1.11.9", func(t *testing.T) {
t.Run("v1.11.9" /* and later versions */, func(t *testing.T) {
var actual Session
require.NoError(t, json.Unmarshal(v1119Session, &actual))
assertx.EqualAsJSON(t, expect, &actual)
Expand Down

0 comments on commit a28aed9

Please sign in to comment.