Skip to content

Commit

Permalink
fix: don't query by raw signature
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Aug 3, 2023
1 parent ca68fe9 commit 8b0fae6
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 36 deletions.
12 changes: 7 additions & 5 deletions persistence/sql/persister_nid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,24 @@ import (
type PersisterTestSuite struct {
suite.Suite
registries map[string]driver.Registry
clean func(*testing.T)
t1 context.Context
t2 context.Context
t1NID uuid.UUID
t2NID uuid.UUID
}

var _ PersisterTestSuite = PersisterTestSuite{}
var _ interface {
suite.SetupAllSuite
suite.TearDownTestSuite
} = (*PersisterTestSuite)(nil)

func (s *PersisterTestSuite) SetupSuite() {
s.registries = map[string]driver.Registry{
"memory": internal.NewRegistrySQLFromURL(s.T(), dbal.NewSQLiteTestDatabase(s.T()), true, &contextx.Default{}),
}

if !testing.Short() {
s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], s.clean = internal.ConnectDatabases(s.T(), true, &contextx.Default{})
s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = internal.ConnectDatabases(s.T(), true, &contextx.Default{})
}

s.t1NID, s.t2NID = uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4())
Expand Down Expand Up @@ -558,11 +560,11 @@ func (s *PersisterTestSuite) DeleteAccessTokenSession() {
require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t2, sig))

actual := persistencesql.OAuth2RequestSQL{Table: "access"}
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
require.Equal(t, s.t1NID, actual.NID)

require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t1, sig))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
})
}
}
Expand Down
44 changes: 17 additions & 27 deletions persistence/sql/persister_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,15 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin
}
}

signature := rawSignature
if table == sqlTableAccess {
signature = SignatureHash(signature)
}

return &OAuth2RequestSQL{
Request: r.GetID(),
ConsentChallenge: challenge,
ID: p.hashSignature(ctx, rawSignature, table),
ID: signature,
RequestedAt: r.GetRequestedAt(),
Client: r.GetClient().GetID(),
Scopes: strings.Join(r.GetRequestedScopes(), "|"),
Expand Down Expand Up @@ -166,14 +171,6 @@ func SignatureHash(signature string) string {
return fmt.Sprintf("%x", sha512.Sum384([]byte(signature)))
}

// hashSignature prevents errors where the signature is longer than 128 characters (and thus doesn't fit into the pk).
func (p *Persister) hashSignature(_ context.Context, signature string, table tableName) string {
if table == sqlTableAccess {
return SignatureHash(signature)
}
return signature
}

func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid")
defer otelx.End(span, &err)
Expand Down Expand Up @@ -242,19 +239,16 @@ func (p *Persister) createSession(ctx context.Context, signature string, request
return nil
}

func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) {
func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findSessionBySignature")
defer otelx.End(span, &err)

r := OAuth2RequestSQL{Table: table}
if table == sqlTableAccess {
signature = SignatureHash(signature)
}

// We look for the signature as well as the hash of the signature here.
// This is because we now always store the hash of the signature in the database,
// regardless of the type of the signature. In previous versions, we only stored
// the hash of the signature for JWT tokens.
//
// This code will be removed in a future version.
err = p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r)
r := OAuth2RequestSQL{Table: table}
err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r)
if errors.Is(err, sql.ErrNoRows) {
return nil, errorsx.WithStack(fosite.ErrNotFound)
} else if err != nil {
Expand All @@ -276,17 +270,13 @@ func (p *Persister) deleteSessionBySignature(ctx context.Context, signature stri
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionBySignature")
defer otelx.End(span, &err)

signature = p.hashSignature(ctx, signature, table)
if table == sqlTableAccess {
signature = SignatureHash(signature)
}

// We look for the signature as well as the hash of the signature here.
// This is because we now always store the hash of the signature in the database,
// regardless of the type of the signature. In previous versions, we only stored
// the hash of the signature for JWT tokens.
//
// This code will be removed in a future version.
err = sqlcon.HandleError(
p.QueryWithNetwork(ctx).
Where("signature IN (?, ?)", signature, SignatureHash(signature)).
Where("signature = ?", signature).
Delete(&OAuth2RequestSQL{Table: table}))

if errors.Is(err, sqlcon.ErrNoRows) {
Expand Down Expand Up @@ -356,7 +346,7 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur
return sqlcon.HandleError(
p.Connection(ctx).
RawQuery(
fmt.Sprintf("UPDATE %s SET active=false WHERE signature=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()),
fmt.Sprintf("UPDATE %s SET active = false WHERE signature = ? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()),
signature,
p.NetworkID(ctx),
).
Expand Down
2 changes: 0 additions & 2 deletions x/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ func TestLogAudit(t *testing.T) {
l.Logger.Out = buf
LogAudit(r, tc.message, l)

t.Logf("%s", buf.String())

assert.Contains(t, buf.String(), "audience=audit")
for _, expectContain := range tc.expectContains {
assert.Contains(t, buf.String(), expectContain)
Expand Down
3 changes: 1 addition & 2 deletions x/clean_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
)

func DeleteHydraRows(t *testing.T, c *pop.Connection) {
t.Logf("Deleting hydra rows in database: %s", c.Dialect.Name())
for _, tb := range []string{
"hydra_oauth2_access",
"hydra_oauth2_refresh",
Expand Down Expand Up @@ -57,7 +56,7 @@ func CleanSQLPop(t *testing.T, c *pop.Connection) {
"schema_migration",
} {
if err := c.RawQuery("DROP TABLE IF EXISTS " + tb).Exec(); err != nil {
t.Logf(`Unable to clean up table "%s": %s`, tb, err)
t.Fatalf(`Unable to clean up table "%s": %s`, tb, err)
}
}
t.Logf("Successfully cleaned up database: %s", c.Dialect.Name())
Expand Down

0 comments on commit 8b0fae6

Please sign in to comment.