diff --git a/persistence/sql/persister_nid_test.go b/persistence/sql/persister_nid_test.go index 6ad1c937aec..83fad7c1452 100644 --- a/persistence/sql/persister_nid_test.go +++ b/persistence/sql/persister_nid_test.go @@ -40,14 +40,16 @@ import ( type PersisterTestSuite struct { suite.Suite registries map[string]driver.Registry - clean func(*testing.T) t1 context.Context t2 context.Context t1NID uuid.UUID t2NID uuid.UUID } -var _ PersisterTestSuite = PersisterTestSuite{} +var _ interface { + suite.SetupAllSuite + suite.TearDownTestSuite +} = (*PersisterTestSuite)(nil) func (s *PersisterTestSuite) SetupSuite() { s.registries = map[string]driver.Registry{ @@ -55,7 +57,7 @@ func (s *PersisterTestSuite) SetupSuite() { } if !testing.Short() { - s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], s.clean = internal.ConnectDatabases(s.T(), true, &contextx.Default{}) + s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = internal.ConnectDatabases(s.T(), true, &contextx.Default{}) } s.t1NID, s.t2NID = uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4()) @@ -558,11 +560,11 @@ func (s *PersisterTestSuite) DeleteAccessTokenSession() { require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t2, sig)) actual := persistencesql.OAuth2RequestSQL{Table: "access"} - require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig)) + require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig))) require.Equal(t, s.t1NID, actual.NID) require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t1, sig)) - require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, sig)) + require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig))) }) } } diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index fb2faba6c0c..245a6f55efc 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -67,7 +67,7 @@ func (r OAuth2RequestSQL) TableName() string { return "hydra_oauth2_" + string(r.Table) } -func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature string, r fosite.Requester, table tableName) (*OAuth2RequestSQL, error) { +func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, r fosite.Requester, table tableName) (*OAuth2RequestSQL, error) { subject := "" if r.GetSession() == nil { p.l.Debugf("Got an empty session in sqlSchemaFromRequest") @@ -101,7 +101,7 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin return &OAuth2RequestSQL{ Request: r.GetID(), ConsentChallenge: challenge, - ID: p.hashSignature(ctx, rawSignature, table), + ID: signature, RequestedAt: r.GetRequestedAt(), Client: r.GetClient().GetID(), Scopes: strings.Join(r.GetRequestedScopes(), "|"), @@ -160,20 +160,6 @@ func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session }, nil } -// SignatureHash hashes the signature to prevent errors where the signature is -// longer than 128 characters (and thus doesn't fit into the pk). -func SignatureHash(signature string) string { - return fmt.Sprintf("%x", sha512.Sum384([]byte(signature))) -} - -// hashSignature prevents errors where the signature is longer than 128 characters (and thus doesn't fit into the pk). -func (p *Persister) hashSignature(_ context.Context, signature string, table tableName) string { - if table == sqlTableAccess { - return SignatureHash(signature) - } - return signature -} - func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid") defer otelx.End(span, &err) @@ -242,19 +228,12 @@ func (p *Persister) createSession(ctx context.Context, signature string, request return nil } -func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) { +func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findSessionBySignature") defer otelx.End(span, &err) r := OAuth2RequestSQL{Table: table} - - // We look for the signature as well as the hash of the signature here. - // This is because we now always store the hash of the signature in the database, - // regardless of the type of the signature. In previous versions, we only stored - // the hash of the signature for JWT tokens. - // - // This code will be removed in a future version. - err = p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r) + err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r) if errors.Is(err, sql.ErrNoRows) { return nil, errorsx.WithStack(fosite.ErrNotFound) } else if err != nil { @@ -276,17 +255,9 @@ func (p *Persister) deleteSessionBySignature(ctx context.Context, signature stri ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionBySignature") defer otelx.End(span, &err) - signature = p.hashSignature(ctx, signature, table) - - // We look for the signature as well as the hash of the signature here. - // This is because we now always store the hash of the signature in the database, - // regardless of the type of the signature. In previous versions, we only stored - // the hash of the signature for JWT tokens. - // - // This code will be removed in a future version. err = sqlcon.HandleError( p.QueryWithNetwork(ctx). - Where("signature IN (?, ?)", signature, SignatureHash(signature)). + Where("signature = ?", signature). Delete(&OAuth2RequestSQL{Table: table})) if errors.Is(err, sqlcon.ErrNoRows) { @@ -356,7 +327,7 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur return sqlcon.HandleError( p.Connection(ctx). RawQuery( - fmt.Sprintf("UPDATE %s SET active=false WHERE signature=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()), + fmt.Sprintf("UPDATE %s SET active = false WHERE signature = ? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()), signature, p.NetworkID(ctx), ). @@ -364,6 +335,12 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur ) } +// SignatureHash hashes the signature to prevent errors where the signature is +// longer than 128 characters (and thus doesn't fit into the pk). +func SignatureHash(signature string) string { + return fmt.Sprintf("%x", sha512.Sum384([]byte(signature))) +} + func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateAccessTokenSession") defer otelx.End(span, &err) @@ -372,19 +349,19 @@ func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature stri append(toEventOptions(requester), events.WithGrantType(requester.GetRequestForm().Get("grant_type")))..., ) - return p.createSession(ctx, signature, requester, sqlTableAccess) + return p.createSession(ctx, SignatureHash(signature), requester, sqlTableAccess) } func (p *Persister) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAccessTokenSession") defer otelx.End(span, &err) - return p.findSessionBySignature(ctx, signature, session, sqlTableAccess) + return p.findSessionBySignature(ctx, SignatureHash(signature), session, sqlTableAccess) } func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokenSession") defer otelx.End(span, &err) - return p.deleteSessionBySignature(ctx, signature, sqlTableAccess) + return p.deleteSessionBySignature(ctx, SignatureHash(signature), sqlTableAccess) } func toEventOptions(requester fosite.Requester) []trace.EventOption { diff --git a/x/audit_test.go b/x/audit_test.go index ef563c04a53..0a4061551d2 100644 --- a/x/audit_test.go +++ b/x/audit_test.go @@ -43,8 +43,6 @@ func TestLogAudit(t *testing.T) { l.Logger.Out = buf LogAudit(r, tc.message, l) - t.Logf("%s", buf.String()) - assert.Contains(t, buf.String(), "audience=audit") for _, expectContain := range tc.expectContains { assert.Contains(t, buf.String(), expectContain) diff --git a/x/clean_sql.go b/x/clean_sql.go index 59628fb3f97..a02a9a054ce 100644 --- a/x/clean_sql.go +++ b/x/clean_sql.go @@ -10,7 +10,6 @@ import ( ) func DeleteHydraRows(t *testing.T, c *pop.Connection) { - t.Logf("Deleting hydra rows in database: %s", c.Dialect.Name()) for _, tb := range []string{ "hydra_oauth2_access", "hydra_oauth2_refresh", @@ -57,7 +56,7 @@ func CleanSQLPop(t *testing.T, c *pop.Connection) { "schema_migration", } { if err := c.RawQuery("DROP TABLE IF EXISTS " + tb).Exec(); err != nil { - t.Logf(`Unable to clean up table "%s": %s`, tb, err) + t.Fatalf(`Unable to clean up table "%s": %s`, tb, err) } } t.Logf("Successfully cleaned up database: %s", c.Dialect.Name())