diff --git a/.docker-home/.gitignore b/.docker-home/.gitignore deleted file mode 100644 index c96a04f008e..00000000000 --- a/.docker-home/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -* -!.gitignore \ No newline at end of file diff --git a/.docker/Dockerfile-alpine b/.docker/Dockerfile-alpine index 683cc224d19..27567a473cf 100644 --- a/.docker/Dockerfile-alpine +++ b/.docker/Dockerfile-alpine @@ -2,7 +2,7 @@ FROM alpine:3.18 RUN addgroup -S ory; \ adduser -S ory -G ory -D -H -s /bin/nologin -RUN apk --no-cache --upgrade --latest add ca-certificates +RUN apk --no-cache --upgrade add ca-certificates COPY hydra /usr/bin/hydra diff --git a/.docker/Dockerfile-build b/.docker/Dockerfile-build index 12cd5f6ed66..f1457cf2661 100644 --- a/.docker/Dockerfile-build +++ b/.docker/Dockerfile-build @@ -1,4 +1,5 @@ -FROM golang:1.19 AS builder +FROM golang:1.20-alpine3.17 AS builder + WORKDIR /go/src/github.com/ory/hydra diff --git a/.docker/Dockerfile-hsm b/.docker/Dockerfile-hsm index 72cdc015d7a..24582758092 100644 --- a/.docker/Dockerfile-hsm +++ b/.docker/Dockerfile-hsm @@ -1,4 +1,4 @@ -FROM golang:1.19 AS builder +FROM golang:1.20-alpine3.18 AS builder WORKDIR /go/src/github.com/ory/hydra diff --git a/.dockerignore b/.dockerignore index 4d913fbbc91..cf7558fc017 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,7 +3,6 @@ docs node_modules .circleci -.docker-home .github scripts sdk/js diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 54686c72e9e..853b6de937d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -23,9 +23,9 @@ jobs: # We must fetch at least the immediate parents so that if this is # a pull request then we can checkout the head. fetch-depth: 2 - - uses: actions/setup-go@v2 + - uses: actions/setup-go@v3 with: - go-version: "1.19" + go-version: "1.20" - name: Start service run: ./test/conformance/start.sh - name: Run tests @@ -80,22 +80,21 @@ jobs: path: | internal/httpclient key: ${{ needs.sdk-generate.outputs.sdk-cache-key }} - - uses: actions/setup-go@v2 + - uses: actions/setup-go@v4 with: - go-version: "1.19" + go-version: "1.20" - run: go list -json > go.list - name: Run nancy uses: sonatype-nexus-community/nancy-github-action@v1.0.2 with: nancyVersion: v1.0.42 - name: Run golangci-lint - uses: golangci/golangci-lint-action@v2 + uses: golangci/golangci-lint-action@v3 env: GOGC: 100 with: args: --timeout 10m0s - version: v1.47.3 - skip-go-installation: true + version: v1.53.2 skip-pkg-cache: true - name: Run go-acc (tests) run: | @@ -124,9 +123,9 @@ jobs: path: | internal/httpclient key: ${{ needs.sdk-generate.outputs.sdk-cache-key }} - - uses: actions/setup-go@v2 + - uses: actions/setup-go@v3 with: - go-version: "1.19" + go-version: "1.20" - name: Setup HSM libs and packages run: | sudo apt install -y softhsm opensc @@ -175,9 +174,9 @@ jobs: docker start cockroach name: Start CockroachDB - uses: ory/ci/checkout@master - - uses: actions/setup-go@v2 + - uses: actions/setup-go@v3 with: - go-version: "1.19" + go-version: "1.20" - uses: actions/cache@v2 with: path: ./test/e2e/hydra diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index a7a720ebc0a..80515a61723 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -11,7 +11,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: "1.20" - run: make format - name: Indicate formatting issues run: git diff HEAD --exit-code --color diff --git a/.github/workflows/licenses.yml b/.github/workflows/licenses.yml index a4592c63ced..6f219dbced1 100644 --- a/.github/workflows/licenses.yml +++ b/.github/workflows/licenses.yml @@ -11,10 +11,10 @@ jobs: check: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-go@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-go@v3 with: - go-version: "1.18" + go-version: "1.20" - uses: actions/setup-node@v2 with: node-version: "18" diff --git a/.golangci.yml b/.golangci.yml index c3461c51f45..00ee1f9963c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -5,9 +5,7 @@ linters: - gosimple - bodyclose - staticcheck - # Disabled due to Go 1.19 changes and Go-Swagger incompatibility - # https://github.com/ory/hydra/issues/3227 - # - goimports + - goimports disable: - ineffassign - deadcode diff --git a/.grype.yml b/.grype.yml new file mode 100644 index 00000000000..56d262246ac --- /dev/null +++ b/.grype.yml @@ -0,0 +1,2 @@ +ignore: + - vulnerability: CVE-2023-2650 diff --git a/.schema/config.schema.json b/.schema/config.schema.json index f71021f34d9..821f4e24eae 100644 --- a/.schema/config.schema.json +++ b/.schema/config.schema.json @@ -1040,7 +1040,7 @@ "examples": ["cpu"] }, "tracing": { - "$ref": "https://raw.githubusercontent.com/ory/x/v0.0.555/otelx/config.schema.json" + "$ref": "https://raw.githubusercontent.com/ory/x/v0.0.559/otelx/config.schema.json" }, "sqa": { "type": "object", diff --git a/.schema/version.schema.json b/.schema/version.schema.json index c29393045cc..2377ac483f1 100644 --- a/.schema/version.schema.json +++ b/.schema/version.schema.json @@ -2,6 +2,23 @@ "$id": "https://github.com/ory/hydra/.schema/versions.config.schema.json", "$schema": "http://json-schema.org/draft-07/schema#", "oneOf": [ + { + "allOf": [ + { + "properties": { + "version": { + "const": "v2.2.0-rc.2" + } + }, + "required": [ + "version" + ] + }, + { + "$ref": "https://raw.githubusercontent.com/ory/hydra/v2.2.0-rc.2/.schema/config.schema.json" + } + ] + }, { "allOf": [ { diff --git a/.trivyignore b/.trivyignore new file mode 100644 index 00000000000..73859219e24 --- /dev/null +++ b/.trivyignore @@ -0,0 +1 @@ +CVE-2023-2650 diff --git a/CHANGELOG.md b/CHANGELOG.md index 811365a9c83..62fd9d64b4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,265 +4,273 @@ **Table of Contents** -- [0.0.0 (2023-06-10)](#000-2023-06-10) - - [Bug Fixes](#bug-fixes) +- [0.0.0 (2023-06-19)](#000-2023-06-19) - [Features](#features) +- [2.2.0-rc.2 (2023-06-13)](#220-rc2-2023-06-13) + - [Bug Fixes](#bug-fixes) + - [Code Generation](#code-generation) + - [Features](#features-1) +- [2.2.0-rc.1 (2023-06-12)](#220-rc1-2023-06-12) + - [Breaking Changes](#breaking-changes) + - [Bug Fixes](#bug-fixes-1) + - [Code Generation](#code-generation-1) + - [Features](#features-2) - [Unclassified](#unclassified) - [2.1.2 (2023-05-24)](#212-2023-05-24) - - [Bug Fixes](#bug-fixes-1) - - [Code Generation](#code-generation) + - [Bug Fixes](#bug-fixes-2) + - [Code Generation](#code-generation-2) - [Documentation](#documentation) - - [Features](#features-1) + - [Features](#features-3) - [2.1.1 (2023-04-11)](#211-2023-04-11) - - [Bug Fixes](#bug-fixes-2) - - [Code Generation](#code-generation-1) -- [2.1.0 (2023-04-06)](#210-2023-04-06) - [Bug Fixes](#bug-fixes-3) - - [Code Generation](#code-generation-2) -- [2.1.0-pre.2 (2023-04-03)](#210-pre2-2023-04-03) - [Code Generation](#code-generation-3) -- [2.1.0-pre.1 (2023-04-03)](#210-pre1-2023-04-03) - - [Code Generation](#code-generation-4) -- [2.1.0-pre.0 (2023-03-31)](#210-pre0-2023-03-31) +- [2.1.0 (2023-04-06)](#210-2023-04-06) - [Bug Fixes](#bug-fixes-4) + - [Code Generation](#code-generation-4) +- [2.1.0-pre.2 (2023-04-03)](#210-pre2-2023-04-03) - [Code Generation](#code-generation-5) +- [2.1.0-pre.1 (2023-04-03)](#210-pre1-2023-04-03) + - [Code Generation](#code-generation-6) +- [2.1.0-pre.0 (2023-03-31)](#210-pre0-2023-03-31) + - [Bug Fixes](#bug-fixes-5) + - [Code Generation](#code-generation-7) - [Documentation](#documentation-1) - - [Features](#features-2) + - [Features](#features-4) - [2.0.3 (2022-12-08)](#203-2022-12-08) - - [Bug Fixes](#bug-fixes-5) - - [Code Generation](#code-generation-6) - - [Features](#features-3) -- [2.0.2 (2022-11-10)](#202-2022-11-10) - [Bug Fixes](#bug-fixes-6) - - [Code Generation](#code-generation-7) + - [Code Generation](#code-generation-8) + - [Features](#features-5) +- [2.0.2 (2022-11-10)](#202-2022-11-10) + - [Bug Fixes](#bug-fixes-7) + - [Code Generation](#code-generation-9) - [Documentation](#documentation-2) - - [Features](#features-4) + - [Features](#features-6) - [Tests](#tests) - [2.0.1 (2022-10-27)](#201-2022-10-27) - - [Bug Fixes](#bug-fixes-7) - - [Code Generation](#code-generation-8) + - [Bug Fixes](#bug-fixes-8) + - [Code Generation](#code-generation-10) - [Documentation](#documentation-3) - [2.0.0 (2022-10-27)](#200-2022-10-27) - - [Breaking Changes](#breaking-changes) - - [Bug Fixes](#bug-fixes-8) - - [Code Generation](#code-generation-9) + - [Breaking Changes](#breaking-changes-1) + - [Bug Fixes](#bug-fixes-9) + - [Code Generation](#code-generation-11) - [Code Refactoring](#code-refactoring) - [Documentation](#documentation-4) - - [Features](#features-5) + - [Features](#features-7) - [Tests](#tests-1) - [Unclassified](#unclassified-1) - [1.11.10 (2022-08-25)](#11110-2022-08-25) - - [Bug Fixes](#bug-fixes-9) - - [Code Generation](#code-generation-10) -- [1.11.9 (2022-08-01)](#1119-2022-08-01) - [Bug Fixes](#bug-fixes-10) - - [Code Generation](#code-generation-11) + - [Code Generation](#code-generation-12) +- [1.11.9 (2022-08-01)](#1119-2022-08-01) + - [Bug Fixes](#bug-fixes-11) + - [Code Generation](#code-generation-13) - [Documentation](#documentation-5) - - [Features](#features-6) + - [Features](#features-8) - [1.11.8 (2022-05-04)](#1118-2022-05-04) - - [Bug Fixes](#bug-fixes-11) - - [Code Generation](#code-generation-12) + - [Bug Fixes](#bug-fixes-12) + - [Code Generation](#code-generation-14) - [Documentation](#documentation-6) - - [Features](#features-7) + - [Features](#features-9) - [Tests](#tests-2) - [1.11.7 (2022-02-23)](#1117-2022-02-23) - - [Code Generation](#code-generation-13) + - [Code Generation](#code-generation-15) - [1.11.6 (2022-02-23)](#1116-2022-02-23) - - [Bug Fixes](#bug-fixes-12) - - [Code Generation](#code-generation-14) -- [1.11.5 (2022-02-21)](#1115-2022-02-21) - [Bug Fixes](#bug-fixes-13) - - [Code Generation](#code-generation-15) -- [1.11.4 (2022-02-16)](#1114-2022-02-16) - - [Bug Fixes](#bug-fixes-14) - [Code Generation](#code-generation-16) -- [1.11.3 (2022-02-15)](#1113-2022-02-15) - - [Bug Fixes](#bug-fixes-15) +- [1.11.5 (2022-02-21)](#1115-2022-02-21) + - [Bug Fixes](#bug-fixes-14) - [Code Generation](#code-generation-17) -- [1.11.2 (2022-02-11)](#1112-2022-02-11) +- [1.11.4 (2022-02-16)](#1114-2022-02-16) + - [Bug Fixes](#bug-fixes-15) - [Code Generation](#code-generation-18) -- [1.11.1 (2022-02-11)](#1111-2022-02-11) +- [1.11.3 (2022-02-15)](#1113-2022-02-15) - [Bug Fixes](#bug-fixes-16) - [Code Generation](#code-generation-19) +- [1.11.2 (2022-02-11)](#1112-2022-02-11) + - [Code Generation](#code-generation-20) +- [1.11.1 (2022-02-11)](#1111-2022-02-11) + - [Bug Fixes](#bug-fixes-17) + - [Code Generation](#code-generation-21) - [Code Refactoring](#code-refactoring-1) - [Documentation](#documentation-7) - [1.11.0 (2022-01-21)](#1110-2022-01-21) - - [Breaking Changes](#breaking-changes-1) - - [Bug Fixes](#bug-fixes-17) - - [Code Generation](#code-generation-20) - - [Documentation](#documentation-8) - - [Features](#features-8) -- [1.10.7 (2021-10-27)](#1107-2021-10-27) - [Breaking Changes](#breaking-changes-2) - [Bug Fixes](#bug-fixes-18) - - [Code Generation](#code-generation-21) + - [Code Generation](#code-generation-22) + - [Documentation](#documentation-8) + - [Features](#features-10) +- [1.10.7 (2021-10-27)](#1107-2021-10-27) + - [Breaking Changes](#breaking-changes-3) + - [Bug Fixes](#bug-fixes-19) + - [Code Generation](#code-generation-23) - [Code Refactoring](#code-refactoring-2) - [Documentation](#documentation-9) - - [Features](#features-9) + - [Features](#features-11) - [1.10.6 (2021-08-28)](#1106-2021-08-28) - - [Bug Fixes](#bug-fixes-19) - - [Code Generation](#code-generation-22) + - [Bug Fixes](#bug-fixes-20) + - [Code Generation](#code-generation-24) - [Documentation](#documentation-10) - [1.10.5 (2021-08-13)](#1105-2021-08-13) - - [Bug Fixes](#bug-fixes-20) - - [Code Generation](#code-generation-23) + - [Bug Fixes](#bug-fixes-21) + - [Code Generation](#code-generation-25) - [Documentation](#documentation-11) - - [Features](#features-10) + - [Features](#features-12) - [1.10.3 (2021-07-14)](#1103-2021-07-14) - - [Bug Fixes](#bug-fixes-21) - - [Code Generation](#code-generation-24) + - [Bug Fixes](#bug-fixes-22) + - [Code Generation](#code-generation-26) - [Code Refactoring](#code-refactoring-3) - [Documentation](#documentation-12) - - [Features](#features-11) + - [Features](#features-13) - [1.10.2 (2021-05-04)](#1102-2021-05-04) - - [Breaking Changes](#breaking-changes-3) - - [Bug Fixes](#bug-fixes-22) - - [Code Generation](#code-generation-25) + - [Breaking Changes](#breaking-changes-4) + - [Bug Fixes](#bug-fixes-23) + - [Code Generation](#code-generation-27) - [Code Refactoring](#code-refactoring-4) - [Documentation](#documentation-13) - - [Features](#features-12) + - [Features](#features-14) - [1.10.1 (2021-03-25)](#1101-2021-03-25) - - [Bug Fixes](#bug-fixes-23) - - [Code Generation](#code-generation-26) + - [Bug Fixes](#bug-fixes-24) + - [Code Generation](#code-generation-28) - [Documentation](#documentation-14) - - [Features](#features-13) + - [Features](#features-15) - [Tests](#tests-3) - [Unclassified](#unclassified-2) - [1.9.2 (2021-01-29)](#192-2021-01-29) - - [Code Generation](#code-generation-27) - - [Features](#features-14) + - [Code Generation](#code-generation-29) + - [Features](#features-16) - [1.9.1 (2021-01-27)](#191-2021-01-27) - - [Code Generation](#code-generation-28) + - [Code Generation](#code-generation-30) - [Documentation](#documentation-15) - [1.9.0 (2021-01-12)](#190-2021-01-12) - - [Code Generation](#code-generation-29) + - [Code Generation](#code-generation-31) - [1.9.0-rc.0 (2021-01-12)](#190-rc0-2021-01-12) - - [Code Generation](#code-generation-30) + - [Code Generation](#code-generation-32) - [1.9.0-alpha.4.pre.0 (2021-01-12)](#190-alpha4pre0-2021-01-12) - - [Bug Fixes](#bug-fixes-24) - - [Code Generation](#code-generation-31) + - [Bug Fixes](#bug-fixes-25) + - [Code Generation](#code-generation-33) - [Documentation](#documentation-16) - [1.9.0-alpha.3 (2020-12-08)](#190-alpha3-2020-12-08) - - [Breaking Changes](#breaking-changes-4) - - [Bug Fixes](#bug-fixes-25) - - [Code Generation](#code-generation-32) + - [Breaking Changes](#breaking-changes-5) + - [Bug Fixes](#bug-fixes-26) + - [Code Generation](#code-generation-34) - [Code Refactoring](#code-refactoring-5) - [Documentation](#documentation-17) - - [Features](#features-15) + - [Features](#features-17) - [Tests](#tests-4) - [Unclassified](#unclassified-3) - [1.9.0-alpha.2 (2020-10-29)](#190-alpha2-2020-10-29) - - [Bug Fixes](#bug-fixes-26) - - [Code Generation](#code-generation-33) + - [Bug Fixes](#bug-fixes-27) + - [Code Generation](#code-generation-35) - [Documentation](#documentation-18) - - [Features](#features-16) + - [Features](#features-18) - [Tests](#tests-5) - [1.9.0-alpha.1 (2020-10-20)](#190-alpha1-2020-10-20) - - [Bug Fixes](#bug-fixes-27) - - [Code Generation](#code-generation-34) + - [Bug Fixes](#bug-fixes-28) + - [Code Generation](#code-generation-36) - [Code Refactoring](#code-refactoring-6) - [Documentation](#documentation-19) - - [Features](#features-17) + - [Features](#features-19) - [Tests](#tests-6) - [1.8.5 (2020-10-03)](#185-2020-10-03) - - [Code Generation](#code-generation-35) + - [Code Generation](#code-generation-37) - [1.8.0-pre.1 (2020-10-03)](#180-pre1-2020-10-03) - - [Bug Fixes](#bug-fixes-28) - - [Code Generation](#code-generation-36) - - [Features](#features-18) -- [1.8.0-pre.0 (2020-10-02)](#180-pre0-2020-10-02) - - [Breaking Changes](#breaking-changes-5) - [Bug Fixes](#bug-fixes-29) - - [Code Generation](#code-generation-37) - - [Documentation](#documentation-20) - - [Features](#features-19) -- [1.7.4 (2020-08-31)](#174-2020-08-31) - - [Bug Fixes](#bug-fixes-30) - [Code Generation](#code-generation-38) -- [1.7.3 (2020-08-31)](#173-2020-08-31) - - [Code Generation](#code-generation-39) -- [1.7.1 (2020-08-31)](#171-2020-08-31) + - [Features](#features-20) +- [1.8.0-pre.0 (2020-10-02)](#180-pre0-2020-10-02) - [Breaking Changes](#breaking-changes-6) + - [Bug Fixes](#bug-fixes-30) + - [Code Generation](#code-generation-39) + - [Documentation](#documentation-20) + - [Features](#features-21) +- [1.7.4 (2020-08-31)](#174-2020-08-31) - [Bug Fixes](#bug-fixes-31) - [Code Generation](#code-generation-40) +- [1.7.3 (2020-08-31)](#173-2020-08-31) + - [Code Generation](#code-generation-41) +- [1.7.1 (2020-08-31)](#171-2020-08-31) + - [Breaking Changes](#breaking-changes-7) + - [Bug Fixes](#bug-fixes-32) + - [Code Generation](#code-generation-42) - [Code Refactoring](#code-refactoring-7) - [Documentation](#documentation-21) - - [Features](#features-20) + - [Features](#features-22) - [Unclassified](#unclassified-4) - [1.7.0 (2020-08-14)](#170-2020-08-14) - - [Breaking Changes](#breaking-changes-7) - - [Bug Fixes](#bug-fixes-32) - - [Code Generation](#code-generation-41) + - [Breaking Changes](#breaking-changes-8) + - [Bug Fixes](#bug-fixes-33) + - [Code Generation](#code-generation-43) - [Code Refactoring](#code-refactoring-8) - [Documentation](#documentation-22) - - [Features](#features-21) + - [Features](#features-23) - [Unclassified](#unclassified-5) - [1.6.0 (2020-07-20)](#160-2020-07-20) - - [Bug Fixes](#bug-fixes-33) - - [Code Generation](#code-generation-42) + - [Bug Fixes](#bug-fixes-34) + - [Code Generation](#code-generation-44) - [Documentation](#documentation-23) - [Unclassified](#unclassified-6) - [1.5.2 (2020-06-23)](#152-2020-06-23) - - [Bug Fixes](#bug-fixes-34) - - [Code Generation](#code-generation-43) - - [Features](#features-22) + - [Bug Fixes](#bug-fixes-35) + - [Code Generation](#code-generation-45) + - [Features](#features-24) - [1.5.1 (2020-06-16)](#151-2020-06-16) - - [Code Generation](#code-generation-44) + - [Code Generation](#code-generation-46) - [1.5.0 (2020-06-16)](#150-2020-06-16) - - [Bug Fixes](#bug-fixes-35) + - [Bug Fixes](#bug-fixes-36) - [Chores](#chores) - [Documentation](#documentation-24) - - [Features](#features-23) + - [Features](#features-25) - [Unclassified](#unclassified-7) - [1.5.0-beta.5 (2020-05-28)](#150-beta5-2020-05-28) - - [Bug Fixes](#bug-fixes-36) + - [Bug Fixes](#bug-fixes-37) - [Chores](#chores-1) - [Documentation](#documentation-25) - - [Features](#features-24) + - [Features](#features-26) - [1.5.0-beta.3 (2020-05-23)](#150-beta3-2020-05-23) - [Chores](#chores-2) - [1.5.0-beta.2 (2020-05-23)](#150-beta2-2020-05-23) - - [Bug Fixes](#bug-fixes-37) + - [Bug Fixes](#bug-fixes-38) - [Chores](#chores-3) - [Code Refactoring](#code-refactoring-9) - [Documentation](#documentation-26) - [1.5.0-beta.1 (2020-04-30)](#150-beta1-2020-04-30) - - [Breaking Changes](#breaking-changes-8) + - [Breaking Changes](#breaking-changes-9) - [Chores](#chores-4) - [Code Refactoring](#code-refactoring-10) - [1.4.10 (2020-04-30)](#1410-2020-04-30) - - [Bug Fixes](#bug-fixes-38) + - [Bug Fixes](#bug-fixes-39) - [Chores](#chores-5) - [Documentation](#documentation-27) - [Unclassified](#unclassified-8) - [1.4.9 (2020-04-25)](#149-2020-04-25) - - [Bug Fixes](#bug-fixes-39) + - [Bug Fixes](#bug-fixes-40) - [Chores](#chores-6) - [1.4.8 (2020-04-24)](#148-2020-04-24) - - [Bug Fixes](#bug-fixes-40) + - [Bug Fixes](#bug-fixes-41) - [Chores](#chores-7) - [Documentation](#documentation-28) - - [Features](#features-25) + - [Features](#features-27) - [1.4.7 (2020-04-24)](#147-2020-04-24) - - [Bug Fixes](#bug-fixes-41) + - [Bug Fixes](#bug-fixes-42) - [Chores](#chores-8) - [Documentation](#documentation-29) - [1.4.6 (2020-04-17)](#146-2020-04-17) - - [Bug Fixes](#bug-fixes-42) + - [Bug Fixes](#bug-fixes-43) - [Documentation](#documentation-30) - [1.4.5 (2020-04-16)](#145-2020-04-16) - - [Bug Fixes](#bug-fixes-43) + - [Bug Fixes](#bug-fixes-44) - [Documentation](#documentation-31) - [1.4.3 (2020-04-16)](#143-2020-04-16) - - [Bug Fixes](#bug-fixes-44) + - [Bug Fixes](#bug-fixes-45) - [Code Refactoring](#code-refactoring-11) - [Documentation](#documentation-32) - - [Features](#features-26) + - [Features](#features-28) - [1.4.2 (2020-04-03)](#142-2020-04-03) - [Chores](#chores-9) - [Documentation](#documentation-33) - [1.4.1 (2020-04-02)](#141-2020-04-02) - - [Bug Fixes](#bug-fixes-45) + - [Bug Fixes](#bug-fixes-46) - [1.4.0 (2020-04-02)](#140-2020-04-02) - [GHSA-3p3g-vpw6-4w66](#ghsa-3p3g-vpw6-4w66) - [Impact](#impact) @@ -271,7 +279,7 @@ - [Workarounds](#workarounds) - [References](#references) - [Upstream](#upstream) - - [Breaking Changes](#breaking-changes-9) + - [Breaking Changes](#breaking-changes-10) - [GHSA-3p3g-vpw6-4w66](#ghsa-3p3g-vpw6-4w66-1) - [Impact](#impact-1) - [Severity](#severity-1) @@ -279,21 +287,21 @@ - [Workarounds](#workarounds-1) - [References](#references-1) - [Upstream](#upstream-1) - - [Bug Fixes](#bug-fixes-46) + - [Bug Fixes](#bug-fixes-47) - [Code Refactoring](#code-refactoring-12) - [Documentation](#documentation-34) - - [Features](#features-27) + - [Features](#features-29) - [Unclassified](#unclassified-9) - [1.3.2 (2020-02-17)](#132-2020-02-17) - - [Bug Fixes](#bug-fixes-47) + - [Bug Fixes](#bug-fixes-48) - [Chores](#chores-10) - [Documentation](#documentation-35) - [1.3.1 (2020-02-16)](#131-2020-02-16) - [Continuous Integration](#continuous-integration) - [1.3.0 (2020-02-14)](#130-2020-02-14) - - [Bug Fixes](#bug-fixes-48) + - [Bug Fixes](#bug-fixes-49) - [Documentation](#documentation-36) - - [Features](#features-28) + - [Features](#features-30) - [Unclassified](#unclassified-10) - [1.2.3 (2020-01-31)](#123-2020-01-31) - [Unclassified](#unclassified-11) @@ -686,7 +694,57 @@ -# [0.0.0](https://github.com/ory/hydra/compare/v2.1.2...v0.0.0) (2023-06-10) +# [0.0.0](https://github.com/ory/hydra/compare/v2.2.0-rc.2...v0.0.0) (2023-06-19) + + +### Features + +* Add event tracing ([#3546](https://github.com/ory/hydra/issues/3546)) ([44ed0ac](https://github.com/ory/hydra/commit/44ed0ac89558bd83513e5240e8c937c908514d76)) + + +# [2.2.0-rc.2](https://github.com/ory/hydra/compare/v2.2.0-rc.1...v2.2.0-rc.2) (2023-06-13) + +This release optimizes the performance of authorization code grant flows by minimizing the number of database queries. We acheive this by storing the flow in an AEAD-encoded cookie and AEAD-encoded request parameters for the authentication and consent screens. + +BREAKING CHANGE: + +* The client that is used as part of the authorization grant flow is stored in the AEAD-encoding. Therefore, running flows will not observe updates to the client after they were started. +* Because the login and consent challenge values now include the AEAD-encoded flow, their size increased to around 1kB for a flow without any metadata (and increases linearly with the amount of metadata). Please adjust your ingress / gateway accordingly. + + + + + +### Bug Fixes + +* Version clash in apk install ([24ebdd3](https://github.com/ory/hydra/commit/24ebdd3feb302f655000a243dad032b04cf25afc)) + +### Code Generation + +* Pin v2.2.0-rc.2 release commit ([b183040](https://github.com/ory/hydra/commit/b183040a0d6c33abd4db01eb21a1bb0e141ea9ec)) + +### Features + +* Hot-reload Oauth2 CORS settings ([#3537](https://github.com/ory/hydra/issues/3537)) ([a8ecf80](https://github.com/ory/hydra/commit/a8ecf807b2c6bfa6cc2d8b474f527a2fda12daef)) +* Sqa metrics v2 ([#3533](https://github.com/ory/hydra/issues/3533)) ([3ec683d](https://github.com/ory/hydra/commit/3ec683d7cf582443f29bd93c4c88392b3ce692a4)) + + +# [2.2.0-rc.1](https://github.com/ory/hydra/compare/v2.1.2...v2.2.0-rc.1) (2023-06-12) + +This release optimizes the performance of authorization code grant flows by minimizing the number of database queries. We acheive this by storing the flow in an AEAD-encoded cookie and AEAD-encoded request parameters for the authentication and consent screens. + +BREAKING CHANGE: + +* The client that is used as part of the authorization grant flow is stored in the AEAD-encoding. Therefore, running flows will not observe updates to the client after they were started. +* Because the login and consent challenge values now include the AEAD-encoded flow, their size increased to around 1kB for a flow without any metadata (and increases linearly with the amount of metadata). Please adjust your ingress / gateway accordingly. + + + +## Breaking Changes + +* The client that is used as part of the authorization grant flow is stored in the AEAD-encoding. Therefore, running flows will not observe updates to the client after they were started. +* Because the login and consent challenge values now include the AEAD-encoded flow, their size increased to around 1kB for a flow without any metadata (and increases linearly with the amount of metadata). Please adjust your ingress / gateway accordingly. + ### Bug Fixes @@ -698,9 +756,17 @@ v1.11.10 to v2. +### Code Generation + +* Pin v2.2.0-rc.1 release commit ([262ebbb](https://github.com/ory/hydra/commit/262ebbb5a7a585a26117a8c0fba6c257fc97b7b4)) + ### Features * Add metrics to disabled access log ([#3526](https://github.com/ory/hydra/issues/3526)) ([fc7af90](https://github.com/ory/hydra/commit/fc7af904407b27d1b5c0e5e62f82fd81ab81ecb2)) +* Stateless authorization code flow ([#3515](https://github.com/ory/hydra/issues/3515)) ([f29fe3a](https://github.com/ory/hydra/commit/f29fe3af97fb72061f2d6d7a2fc454cea5e870e9)): + + This patch optimizes the performance of authorization code grant flows by minimizing the number of database queries. We acheive this by storing the flow in an AEAD-encoded cookie and AEAD-encoded request parameters for the authentication and consent screens. + ### Unclassified diff --git a/Makefile b/Makefile index a38c1436dc1..8adf128efd3 100644 --- a/Makefile +++ b/Makefile @@ -5,10 +5,11 @@ export PATH := .bin:${PATH} export PWD := $(shell pwd) export IMAGE_TAG := $(if $(IMAGE_TAG),$(IMAGE_TAG),latest) -GOLANGCI_LINT_VERSION = 1.46.2 +GOLANGCI_LINT_VERSION = 1.53.2 GO_DEPENDENCIES = github.com/ory/go-acc \ github.com/golang/mock/mockgen \ + golang.org/x/tools/cmd/goimports \ github.com/go-swagger/go-swagger/cmd/swagger define make-go-dependency @@ -37,9 +38,6 @@ node_modules: package-lock.json docs/cli: .bin/clidoc clidoc . -.bin/goimports: go.sum Makefile - GOBIN=$(shell pwd)/.bin go install golang.org/x/tools/cmd/goimports@latest - .bin/licenses: Makefile curl https://raw.githubusercontent.com/ory/ci/master/licenses/install | sh @@ -63,12 +61,9 @@ test: .bin/go-acc # Resets the test databases .PHONY: test-resetdb test-resetdb: node_modules - docker kill hydra_test_database_mysql || true - docker kill hydra_test_database_postgres || true - docker kill hydra_test_database_cockroach || true - docker rm -f hydra_test_database_mysql || true - docker rm -f hydra_test_database_postgres || true - docker rm -f hydra_test_database_cockroach || true + docker rm --force --volumes hydra_test_database_mysql || true + docker rm --force --volumes hydra_test_database_postgres || true + docker rm --force --volumes hydra_test_database_cockroach || true docker run --rm --name hydra_test_database_mysql --platform linux/amd64 -p 3444:3306 -e MYSQL_ROOT_PASSWORD=secret -d mysql:8.0.26 docker run --rm --name hydra_test_database_postgres --platform linux/amd64 -p 3445:5432 -e POSTGRES_PASSWORD=secret -e POSTGRES_DB=postgres -d postgres:11.8 docker run --rm --name hydra_test_database_cockroach --platform linux/amd64 -p 3446:26257 -d cockroachdb/cockroach:v22.1.10 start-single-node --insecure @@ -122,6 +117,7 @@ sdk: .bin/swagger .bin/ory node_modules swagger generate spec -m -o spec/swagger.json \ -c github.com/ory/hydra/v2/client \ -c github.com/ory/hydra/v2/consent \ + -c github.com/ory/hydra/v2/flow \ -c github.com/ory/hydra/v2/health \ -c github.com/ory/hydra/v2/jwk \ -c github.com/ory/hydra/v2/oauth2 \ diff --git a/aead/aead.go b/aead/aead.go new file mode 100644 index 00000000000..a3cb8b89ffe --- /dev/null +++ b/aead/aead.go @@ -0,0 +1,28 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package aead + +import ( + "context" + + "github.com/ory/fosite" +) + +// Cipher provides AEAD (authenticated encryption with associated data). The +// ciphertext is returned base64url-encoded. +type Cipher interface { + // Encrypt encrypts and encodes the given plaintext, optionally using + // additiona data. + Encrypt(ctx context.Context, plaintext, additionalData []byte) (ciphertext string, err error) + + // Decrypt decodes, decrypts, and verifies the plaintext and additional data + // from the ciphertext. The ciphertext must be given in the form as returned + // by Encrypt. + Decrypt(ctx context.Context, ciphertext string, additionalData []byte) (plaintext []byte, err error) +} + +type Dependencies interface { + fosite.GlobalSecretProvider + fosite.RotatedGlobalSecretsProvider +} diff --git a/aead/aead_test.go b/aead/aead_test.go new file mode 100644 index 00000000000..4cb93f5c3e7 --- /dev/null +++ b/aead/aead_test.go @@ -0,0 +1,154 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package aead_test + +import ( + "context" + "crypto/rand" + "fmt" + "io" + "testing" + + "github.com/ory/hydra/v2/aead" + "github.com/ory/hydra/v2/driver/config" + "github.com/ory/hydra/v2/internal" + + "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func secret(t *testing.T) string { + bytes := make([]byte, 32) + _, err := io.ReadFull(rand.Reader, bytes) + require.NoError(t, err) + return fmt.Sprintf("%X", bytes) +} + +func TestAEAD(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + name string + new func(aead.Dependencies) aead.Cipher + }{ + {"AES-GCM", func(d aead.Dependencies) aead.Cipher { return aead.NewAESGCM(d) }}, + {"XChaChaPoly", func(d aead.Dependencies) aead.Cipher { return aead.NewXChaCha20Poly1305(d) }}, + } { + tc := tc + + t.Run("cipher="+tc.name, func(t *testing.T) { + NewCipher := tc.new + + t.Run("case=without-rotation", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + c := internal.NewConfigurationWithDefaults() + c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) + a := NewCipher(c) + + plain := []byte(uuid.New()) + ct, err := a.Encrypt(ctx, plain, nil) + assert.NoError(t, err) + + ct2, err := a.Encrypt(ctx, plain, nil) + assert.NoError(t, err) + assert.NotEqual(t, ct, ct2, "ciphertexts for the same plaintext must be different each time") + + res, err := a.Decrypt(ctx, ct, nil) + assert.NoError(t, err) + assert.Equal(t, plain, res) + }) + + t.Run("case=wrong-secret", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + c := internal.NewConfigurationWithDefaults() + c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) + a := NewCipher(c) + + ct, err := a.Encrypt(ctx, []byte(uuid.New()), nil) + require.NoError(t, err) + + c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) + _, err = a.Decrypt(ctx, ct, nil) + require.Error(t, err) + }) + + t.Run("case=with-rotation", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + c := internal.NewConfigurationWithDefaults() + old := secret(t) + c.MustSet(ctx, config.KeyGetSystemSecret, []string{old}) + a := NewCipher(c) + + plain := []byte(uuid.New()) + ct, err := a.Encrypt(ctx, plain, nil) + require.NoError(t, err) + + // Sets the old secret as a rotated secret and creates a new one. + c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t), old}) + res, err := a.Decrypt(ctx, ct, nil) + require.NoError(t, err) + assert.Equal(t, plain, res) + + // THis should also work when we re-encrypt the same plain text. + ct2, err := a.Encrypt(ctx, plain, nil) + require.NoError(t, err) + assert.NotEqual(t, ct2, ct) + + res, err = a.Decrypt(ctx, ct, nil) + require.NoError(t, err) + assert.Equal(t, plain, res) + }) + + t.Run("case=with-rotation-wrong-secret", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + c := internal.NewConfigurationWithDefaults() + c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) + a := NewCipher(c) + + plain := []byte(uuid.New()) + ct, err := a.Encrypt(ctx, plain, nil) + require.NoError(t, err) + + // When the secrets do not match, an error should be thrown during decryption. + c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t), secret(t)}) + _, err = a.Decrypt(ctx, ct, nil) + require.Error(t, err) + }) + + t.Run("suite=with additional data", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + c := internal.NewConfigurationWithDefaults() + c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) + a := NewCipher(c) + + plain := []byte(uuid.New()) + ct, err := a.Encrypt(ctx, plain, []byte("additional data")) + assert.NoError(t, err) + + t.Run("case=additional data matches", func(t *testing.T) { + res, err := a.Decrypt(ctx, ct, []byte("additional data")) + assert.NoError(t, err) + assert.Equal(t, plain, res) + }) + + t.Run("case=additional data does not match", func(t *testing.T) { + res, err := a.Decrypt(ctx, ct, []byte("wrong data")) + assert.Error(t, err) + assert.Nil(t, res) + }) + + t.Run("case=missing additional data", func(t *testing.T) { + res, err := a.Decrypt(ctx, ct, nil) + assert.Error(t, err) + assert.Nil(t, res) + }) + }) + }) + } +} diff --git a/aead/aesgcm.go b/aead/aesgcm.go new file mode 100644 index 00000000000..86ae12839ec --- /dev/null +++ b/aead/aesgcm.go @@ -0,0 +1,126 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package aead + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "io" + + "github.com/pkg/errors" + + "github.com/ory/x/errorsx" +) + +type AESGCM struct { + c Dependencies +} + +func NewAESGCM(c Dependencies) *AESGCM { + return &AESGCM{c: c} +} + +func aeadKey(key []byte) *[32]byte { + var result [32]byte + copy(result[:], key[:32]) + return &result +} + +func (c *AESGCM) Encrypt(ctx context.Context, plaintext, additionalData []byte) (string, error) { + key, err := encryptionKey(ctx, c.c, 32) + if err != nil { + return "", err + } + + ciphertext, err := aesGCMEncrypt(plaintext, aeadKey(key), additionalData) + if err != nil { + return "", errorsx.WithStack(err) + } + + return base64.URLEncoding.EncodeToString(ciphertext), nil +} + +func (c *AESGCM) Decrypt(ctx context.Context, ciphertext string, aad []byte) (plaintext []byte, err error) { + msg, err := base64.URLEncoding.DecodeString(ciphertext) + if err != nil { + return nil, errorsx.WithStack(err) + } + + keys, err := allKeys(ctx, c.c) + if err != nil { + return nil, errorsx.WithStack(err) + } + + for _, key := range keys { + if plaintext, err = c.decrypt(msg, key, aad); err == nil { + return plaintext, nil + } + } + + return nil, err +} + +func (c *AESGCM) decrypt(ciphertext []byte, key, additionalData []byte) ([]byte, error) { + if len(key) != 32 { + return nil, errors.Errorf("key must be exactly 32 long bytes, got %d bytes", len(key)) + } + + plaintext, err := aesGCMDecrypt(ciphertext, aeadKey(key), additionalData) + if err != nil { + return nil, errorsx.WithStack(err) + } + + return plaintext, nil +} + +// aesGCMEncrypt encrypts data using 256-bit AES-GCM. This both hides the content of +// the data and provides a check that it hasn't been altered. Output takes the +// form nonce|ciphertext|tag where '|' indicates concatenation. +func aesGCMEncrypt(plaintext []byte, key *[32]byte, additionalData []byte) (ciphertext []byte, err error) { + block, err := aes.NewCipher(key[:]) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + _, err = io.ReadFull(rand.Reader, nonce) + if err != nil { + return nil, err + } + + return gcm.Seal(nonce, nonce, plaintext, additionalData), nil +} + +// aesGCMDecrypt decrypts data using 256-bit AES-GCM. This both hides the content of +// the data and provides a check that it hasn't been altered. Expects input +// form nonce|ciphertext|tag where '|' indicates concatenation. +func aesGCMDecrypt(ciphertext []byte, key *[32]byte, additionalData []byte) (plaintext []byte, err error) { + block, err := aes.NewCipher(key[:]) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + if len(ciphertext) < gcm.NonceSize() { + return nil, errors.New("malformed ciphertext") + } + + return gcm.Open(nil, + ciphertext[:gcm.NonceSize()], + ciphertext[gcm.NonceSize():], + additionalData, + ) +} diff --git a/aead/helpers.go b/aead/helpers.go new file mode 100644 index 00000000000..7acd06c3a0d --- /dev/null +++ b/aead/helpers.go @@ -0,0 +1,41 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package aead + +import ( + "context" + "fmt" +) + +func encryptionKey(ctx context.Context, d Dependencies, keySize int) ([]byte, error) { + keys, err := allKeys(ctx, d) + if err != nil { + return nil, err + } + + key := keys[0] + if len(key) != keySize { + return nil, fmt.Errorf("key must be exactly %d bytes long, got %d bytes", keySize, len(key)) + } + + return key, nil +} + +func allKeys(ctx context.Context, d Dependencies) ([][]byte, error) { + global, err := d.GetGlobalSecret(ctx) + if err != nil { + return nil, err + } + + rotated, err := d.GetRotatedGlobalSecrets(ctx) + if err != nil { + return nil, err + } + + keys := append([][]byte{global}, rotated...) + if len(keys) == 0 { + return nil, fmt.Errorf("at least one encryption key must be defined but none were") + } + return keys, nil +} diff --git a/aead/xchacha20.go b/aead/xchacha20.go new file mode 100644 index 00000000000..cb1d2fbf278 --- /dev/null +++ b/aead/xchacha20.go @@ -0,0 +1,80 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package aead + +import ( + "context" + "crypto/cipher" + cryptorand "crypto/rand" + "encoding/base64" + "fmt" + + "golang.org/x/crypto/chacha20poly1305" + + "github.com/ory/x/errorsx" +) + +var _ Cipher = (*XChaCha20Poly1305)(nil) + +type ( + XChaCha20Poly1305 struct { + d Dependencies + } +) + +func NewXChaCha20Poly1305(d Dependencies) *XChaCha20Poly1305 { + return &XChaCha20Poly1305{d} +} + +func (x *XChaCha20Poly1305) Encrypt(ctx context.Context, plaintext, additionalData []byte) (string, error) { + key, err := encryptionKey(ctx, x.d, chacha20poly1305.KeySize) + if err != nil { + return "", err + } + + aead, err := chacha20poly1305.NewX(key) + if err != nil { + return "", errorsx.WithStack(err) + } + + nonce := make([]byte, aead.NonceSize(), aead.NonceSize()+len(plaintext)+aead.Overhead()) + _, err = cryptorand.Read(nonce) + if err != nil { + return "", errorsx.WithStack(err) + } + + ciphertext := aead.Seal(nonce, nonce, plaintext, additionalData) + return base64.URLEncoding.EncodeToString(ciphertext), nil +} + +func (x *XChaCha20Poly1305) Decrypt(ctx context.Context, ciphertext string, aad []byte) (plaintext []byte, err error) { + msg, err := base64.URLEncoding.DecodeString(ciphertext) + if err != nil { + return nil, errorsx.WithStack(err) + } + + if len(msg) < chacha20poly1305.NonceSizeX { + return nil, errorsx.WithStack(fmt.Errorf("malformed ciphertext: too short")) + } + nonce, ciphered := msg[:chacha20poly1305.NonceSizeX], msg[chacha20poly1305.NonceSizeX:] + + keys, err := allKeys(ctx, x.d) + if err != nil { + return nil, errorsx.WithStack(err) + } + + var aead cipher.AEAD + for _, key := range keys { + aead, err = chacha20poly1305.NewX(key) + if err != nil { + continue + } + plaintext, err = aead.Open(nil, nonce, ciphered, aad) + if err == nil { + return plaintext, nil + } + } + + return nil, errorsx.WithStack(err) +} diff --git a/client/client.go b/client/client.go index 3f01b1099f0..57fdca8b46f 100644 --- a/client/client.go +++ b/client/client.go @@ -4,9 +4,12 @@ package client import ( + "strconv" "strings" "time" + "github.com/twmb/murmur3" + "github.com/ory/hydra/v2/driver/config" "github.com/ory/x/stringsx" @@ -560,3 +563,13 @@ func AccessTokenStrategySource(client fosite.Client) config.AccessTokenStrategyS } return nil } + +func (c *Client) CookieSuffix() string { + return CookieSuffix(c) +} + +type IDer interface{ GetID() string } + +func CookieSuffix(client IDer) string { + return strconv.Itoa(int(murmur3.Sum32([]byte(client.GetID())))) +} diff --git a/client/handler.go b/client/handler.go index eb1f7b7660c..756b47e9baa 100644 --- a/client/handler.go +++ b/client/handler.go @@ -67,6 +67,8 @@ func (h *Handler) SetRoutes(admin *httprouterx.RouterAdmin, public *httprouterx. // OAuth 2.0 Client Creation Parameters // // swagger:parameters createOAuth2Client +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type createOAuth2Client struct { // OAuth 2.0 Client Request Body // @@ -107,6 +109,8 @@ func (h *Handler) createOAuth2Client(w http.ResponseWriter, r *http.Request, _ h // OpenID Connect Dynamic Client Registration Parameters // // swagger:parameters createOidcDynamicClient +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type createOidcDynamicClient struct { // Dynamic Client Registration Request Body // @@ -214,6 +218,8 @@ func (h *Handler) CreateClient(r *http.Request, validator func(context.Context, // Set OAuth 2.0 Client Parameters // // swagger:parameters setOAuth2Client +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type setOAuth2Client struct { // OAuth 2.0 Client ID // @@ -290,6 +296,8 @@ func (h *Handler) updateClient(ctx context.Context, c *Client, validator func(co // Set Dynamic Client Parameters // // swagger:parameters setOidcDynamicClient +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type setOidcDynamicClient struct { // OAuth 2.0 Client ID // @@ -383,6 +391,8 @@ func (h *Handler) setOidcDynamicClient(w http.ResponseWriter, r *http.Request, p // Patch OAuth 2.0 Client Parameters // // swagger:parameters patchOAuth2Client +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type patchOAuth2Client struct { // The id of the OAuth 2.0 Client. // @@ -460,6 +470,8 @@ func (h *Handler) patchOAuth2Client(w http.ResponseWriter, r *http.Request, ps h // Paginated OAuth2 Client List Response // // swagger:response listOAuth2Clients +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type listOAuth2ClientsResponse struct { tokenpagination.ResponseHeaders @@ -472,6 +484,8 @@ type listOAuth2ClientsResponse struct { // Paginated OAuth2 Client List Parameters // // swagger:parameters listOAuth2Clients +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type listOAuth2ClientsParameters struct { tokenpagination.RequestParameters @@ -540,6 +554,8 @@ func (h *Handler) listOAuth2Clients(w http.ResponseWriter, r *http.Request, ps h // Get OAuth2 Client Parameters // // swagger:parameters getOAuth2Client +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type adminGetOAuth2Client struct { // The id of the OAuth 2.0 Client. // @@ -583,6 +599,8 @@ func (h *Handler) Get(w http.ResponseWriter, r *http.Request, ps httprouter.Para // Get OpenID Connect Dynamic Client Parameters // // swagger:parameters getOidcDynamicClient +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type getOidcDynamicClient struct { // The id of the OAuth 2.0 Client. // @@ -644,6 +662,8 @@ func (h *Handler) getOidcDynamicClient(w http.ResponseWriter, r *http.Request, p // Delete OAuth2 Client Parameters // // swagger:parameters deleteOAuth2Client +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type deleteOAuth2Client struct { // The id of the OAuth 2.0 Client. // @@ -687,6 +707,8 @@ func (h *Handler) deleteOAuth2Client(w http.ResponseWriter, r *http.Request, ps // Set OAuth 2.0 Client Token Lifespans // // swagger:parameters setOAuth2ClientLifespans +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type setOAuth2ClientLifespans struct { // OAuth 2.0 Client ID // @@ -738,6 +760,8 @@ func (h *Handler) setOAuth2ClientLifespans(w http.ResponseWriter, r *http.Reques } // swagger:parameters deleteOidcDynamicClient +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type dynamicClientRegistrationDeleteOAuth2Client struct { // The id of the OAuth 2.0 Client. // diff --git a/client/manager.go b/client/manager.go index ad8cca7df51..6b0d9c5de05 100644 --- a/client/manager.go +++ b/client/manager.go @@ -49,3 +49,7 @@ type Storage interface { GetConcreteClient(ctx context.Context, id string) (*Client, error) } + +type ManagerProvider interface { + ClientManager() Manager +} diff --git a/cmd/cli/handler_janitor.go b/cmd/cli/handler_janitor.go index e5082bf4b57..69035148b6e 100644 --- a/cmd/cli/handler_janitor.go +++ b/cmd/cli/handler_janitor.go @@ -6,6 +6,7 @@ package cli import ( "context" "fmt" + "io" "time" "github.com/ory/x/servicelocatorx" @@ -52,12 +53,13 @@ func NewJanitorHandler(slOpts []servicelocatorx.Option, dOpts []driver.OptionsMo } } -func (_ *JanitorHandler) Args(cmd *cobra.Command, args []string) error { +func (*JanitorHandler) Args(cmd *cobra.Command, args []string) error { if len(args) == 0 && !flagx.MustGetBool(cmd, ReadFromEnv) && len(flagx.MustGetStringSlice(cmd, Config)) == 0 { fmt.Printf("%s\n", cmd.UsageString()) + //lint:ignore ST1005 formatted error string used in CLI output return fmt.Errorf("%s\n%s\n%s\n", "A DSN is required as a positional argument when not passing any of the following flags:", "- Using the environment variable with flag -e, --read-from-env", @@ -65,6 +67,7 @@ func (_ *JanitorHandler) Args(cmd *cobra.Command, args []string) error { } if !flagx.MustGetBool(cmd, OnlyTokens) && !flagx.MustGetBool(cmd, OnlyRequests) && !flagx.MustGetBool(cmd, OnlyGrants) { + //lint:ignore ST1005 formatted error string used in CLI output return fmt.Errorf("%s\n%s\n", cmd.UsageString(), "Janitor requires at least one of --tokens, --requests or --grants to be set") } @@ -72,10 +75,12 @@ func (_ *JanitorHandler) Args(cmd *cobra.Command, args []string) error { limit := flagx.MustGetInt(cmd, Limit) batchSize := flagx.MustGetInt(cmd, BatchSize) if limit <= 0 || batchSize <= 0 { + //lint:ignore ST1005 formatted error string used in CLI output return fmt.Errorf("%s\n%s\n", cmd.UsageString(), "Values for --limit and --batch-size should both be greater than 0") } if batchSize > limit { + //lint:ignore ST1005 formatted error string used in CLI output return fmt.Errorf("%s\n%s\n", cmd.UsageString(), "Value for --batch-size must not be greater than value for --limit") } @@ -130,6 +135,7 @@ func purge(cmd *cobra.Command, args []string, sl *servicelocatorx.Options, dOpts } if len(d.Config().DSN()) == 0 { + //lint:ignore ST1005 formatted error string used in CLI output return fmt.Errorf("%s\n%s\n%s\n", cmd.UsageString(), "When using flag -e, environment variable DSN must be set.", "When using flag -c, the dsn property should be set.") @@ -154,20 +160,20 @@ func purge(cmd *cobra.Command, args []string, sl *servicelocatorx.Options, dOpts routineFlags = append(routineFlags, OnlyGrants) } - return cleanupRun(cmd.Context(), notAfter, limit, batchSize, addRoutine(p, routineFlags...)...) + return cleanupRun(cmd.Context(), notAfter, limit, batchSize, addRoutine(cmd.OutOrStdout(), p, routineFlags...)...) } -func addRoutine(p persistence.Persister, names ...string) []cleanupRoutine { +func addRoutine(out io.Writer, p persistence.Persister, names ...string) []cleanupRoutine { var routines []cleanupRoutine for _, n := range names { switch n { case OnlyTokens: - routines = append(routines, cleanup(p.FlushInactiveAccessTokens, "access tokens")) - routines = append(routines, cleanup(p.FlushInactiveRefreshTokens, "refresh tokens")) + routines = append(routines, cleanup(out, p.FlushInactiveAccessTokens, "access tokens")) + routines = append(routines, cleanup(out, p.FlushInactiveRefreshTokens, "refresh tokens")) case OnlyRequests: - routines = append(routines, cleanup(p.FlushInactiveLoginConsentRequests, "login-consent requests")) + routines = append(routines, cleanup(out, p.FlushInactiveLoginConsentRequests, "login-consent requests")) case OnlyGrants: - routines = append(routines, cleanup(p.FlushInactiveGrants, "grants")) + routines = append(routines, cleanup(out, p.FlushInactiveGrants, "grants")) } } return routines @@ -175,12 +181,12 @@ func addRoutine(p persistence.Persister, names ...string) []cleanupRoutine { type cleanupRoutine func(ctx context.Context, notAfter time.Time, limit int, batchSize int) error -func cleanup(cr cleanupRoutine, routineName string) cleanupRoutine { +func cleanup(out io.Writer, cr cleanupRoutine, routineName string) cleanupRoutine { return func(ctx context.Context, notAfter time.Time, limit int, batchSize int) error { if err := cr(ctx, notAfter, limit, batchSize); err != nil { return errors.Wrap(errorsx.WithStack(err), fmt.Sprintf("Could not cleanup inactive %s", routineName)) } - fmt.Printf("Successfully completed Janitor run on %s\n", routineName) + fmt.Fprintf(out, "Successfully completed Janitor run on %s\n", routineName) return nil } } diff --git a/cmd/cli/handler_janitor_test.go b/cmd/cli/handler_janitor_test.go index 9f73beea846..7806f7c5471 100644 --- a/cmd/cli/handler_janitor_test.go +++ b/cmd/cli/handler_janitor_test.go @@ -48,7 +48,7 @@ func TestJanitorHandler_PurgeTokenNotAfter(t *testing.T) { fmt.Sprintf("--%s=%s", cli.AccessLifespan, jt.GetAccessTokenLifespan(ctx).String()), fmt.Sprintf("--%s=%s", cli.RefreshLifespan, jt.GetRefreshTokenLifespan(ctx).String()), fmt.Sprintf("--%s", cli.OnlyTokens), - jt.GetDSN(ctx), + jt.GetDSN(), ) }) @@ -80,13 +80,13 @@ func TestJanitorHandler_PurgeLoginConsentNotAfter(t *testing.T) { fmt.Sprintf("--%s=%s", cli.KeepIfYounger, v.String()), fmt.Sprintf("--%s=%s", cli.ConsentRequestLifespan, jt.GetConsentRequestLifespan(ctx).String()), fmt.Sprintf("--%s", cli.OnlyRequests), - jt.GetDSN(ctx), + jt.GetDSN(), ) }) notAfter := time.Now().Round(time.Second).Add(-v) consentLifespan := time.Now().Round(time.Second).Add(-jt.GetConsentRequestLifespan(ctx)) - t.Run("step=validate", jt.LoginConsentNotAfterValidate(ctx, notAfter, consentLifespan, reg.ConsentManager())) + t.Run("step=validate", jt.LoginConsentNotAfterValidate(ctx, notAfter, consentLifespan, reg)) }) } @@ -107,14 +107,14 @@ func TestJanitorHandler_PurgeLoginConsent(t *testing.T) { require.NoError(t, err) // setup - t.Run("step=setup", jt.LoginTimeoutSetup(ctx, reg.ConsentManager(), reg.ClientManager())) + t.Run("step=setup", jt.LoginTimeoutSetup(ctx, reg)) // cleanup t.Run("step=cleanup", func(t *testing.T) { cmdx.ExecNoErr(t, newJanitorCmd(), "janitor", fmt.Sprintf("--%s", cli.OnlyRequests), - jt.GetDSN(ctx), + jt.GetDSN(), ) }) @@ -129,14 +129,14 @@ func TestJanitorHandler_PurgeLoginConsent(t *testing.T) { require.NoError(t, err) // setup - t.Run("step=setup", jt.ConsentTimeoutSetup(ctx, reg.ConsentManager(), reg.ClientManager())) + t.Run("step=setup", jt.ConsentTimeoutSetup(ctx, reg)) // run cleanup t.Run("step=cleanup", func(t *testing.T) { cmdx.ExecNoErr(t, newJanitorCmd(), "janitor", fmt.Sprintf("--%s", cli.OnlyRequests), - jt.GetDSN(ctx), + jt.GetDSN(), ) }) @@ -155,14 +155,14 @@ func TestJanitorHandler_PurgeLoginConsent(t *testing.T) { require.NoError(t, err) // setup - t.Run("step=setup", jt.LoginRejectionSetup(ctx, reg.ConsentManager(), reg.ClientManager())) + t.Run("step=setup", jt.LoginRejectionSetup(ctx, reg)) // cleanup t.Run("step=cleanup", func(t *testing.T) { cmdx.ExecNoErr(t, newJanitorCmd(), "janitor", fmt.Sprintf("--%s", cli.OnlyRequests), - jt.GetDSN(ctx), + jt.GetDSN(), ) }) @@ -176,14 +176,14 @@ func TestJanitorHandler_PurgeLoginConsent(t *testing.T) { require.NoError(t, err) // setup - t.Run("step=setup", jt.ConsentRejectionSetup(ctx, reg.ConsentManager(), reg.ClientManager())) + t.Run("step=setup", jt.ConsentRejectionSetup(ctx, reg)) // cleanup t.Run("step=cleanup", func(t *testing.T) { cmdx.ExecNoErr(t, newJanitorCmd(), "janitor", fmt.Sprintf("--%s", cli.OnlyRequests), - jt.GetDSN(ctx), + jt.GetDSN(), ) }) @@ -279,7 +279,7 @@ func TestJanitorHandler_PurgeGrantNotAfter(t *testing.T) { require.NoError(t, err) // setup test - t.Run("step=setup", jt.GrantNotAfterSetup(ctx, reg.ClientManager(), reg.GrantManager())) + t.Run("step=setup", jt.GrantNotAfterSetup(ctx, reg.GrantManager())) // run the cleanup routine t.Run("step=cleanup", func(t *testing.T) { @@ -287,7 +287,7 @@ func TestJanitorHandler_PurgeGrantNotAfter(t *testing.T) { "janitor", fmt.Sprintf("--%s=%s", cli.KeepIfYounger, v.String()), fmt.Sprintf("--%s", cli.OnlyGrants), - jt.GetDSN(ctx), + jt.GetDSN(), ) }) diff --git a/cmd/cmd_list_clients.go b/cmd/cmd_list_clients.go index 2c8f79356cf..ddaf2762018 100644 --- a/cmd/cmd_list_clients.go +++ b/cmd/cmd_list_clients.go @@ -35,11 +35,10 @@ func NewListClientsCmd() *cobra.Command { if err != nil { return cmdx.PrintOpenAPIError(cmd, err) } + defer resp.Body.Close() var collection outputOAuth2ClientCollection - for k := range list { - collection.clients = append(collection.clients, list[k]) - } + collection.clients = append(collection.clients, list...) interfaceList := make([]interface{}, len(list)) for k := range collection.clients { diff --git a/cmd/output_client.go b/cmd/output_client.go index 1b052c56967..3f060f281df 100644 --- a/cmd/output_client.go +++ b/cmd/output_client.go @@ -19,7 +19,7 @@ type ( } ) -func (_ outputOAuth2Client) Header() []string { +func (outputOAuth2Client) Header() []string { return []string{"CLIENT ID", "CLIENT SECRET", "GRANT TYPES", "RESPONSE TYPES", "SCOPE", "AUDIENCE", "REDIRECT URIS"} } @@ -40,7 +40,7 @@ func (i outputOAuth2Client) Interface() interface{} { return i } -func (_ outputOAuth2ClientCollection) Header() []string { +func (outputOAuth2ClientCollection) Header() []string { return outputOAuth2Client{}.Header() } diff --git a/cmd/output_introspection.go b/cmd/output_introspection.go index e3aa576421d..1f89f016530 100644 --- a/cmd/output_introspection.go +++ b/cmd/output_introspection.go @@ -16,7 +16,7 @@ type ( outputOAuth2TokenIntrospection hydra.IntrospectedOAuth2Token ) -func (_ outputOAuth2TokenIntrospection) Header() []string { +func (outputOAuth2TokenIntrospection) Header() []string { return []string{"ACTIVE", "SUBJECT", "CLIENT ID", "SCOPE", "EXPIRY", "TOKEN USE"} } diff --git a/cmd/output_jwks.go b/cmd/output_jwks.go index 3b42af3b113..207e33a9d1f 100644 --- a/cmd/output_jwks.go +++ b/cmd/output_jwks.go @@ -20,7 +20,7 @@ type ( } ) -func (_ outputJsonWebKey) Header() []string { +func (outputJsonWebKey) Header() []string { return []string{"SET ID", "KEY ID", "ALGORITHM", "USE"} } @@ -38,7 +38,7 @@ func (i outputJsonWebKey) Interface() interface{} { return i } -func (_ outputJSONWebKeyCollection) Header() []string { +func (outputJSONWebKeyCollection) Header() []string { return outputJsonWebKey{}.Header() } diff --git a/cmd/output_token.go b/cmd/output_token.go index c91add12cb5..17da6bf274c 100644 --- a/cmd/output_token.go +++ b/cmd/output_token.go @@ -16,7 +16,7 @@ type ( outputOAuth2Token oauth2.Token ) -func (_ outputOAuth2Token) Header() []string { +func (outputOAuth2Token) Header() []string { return []string{"ACCESS TOKEN", "REFRESH TOKEN", "ID TOKEN", "EXPIRY"} } diff --git a/cmd/server/handler.go b/cmd/server/handler.go index f230648c62d..d9c1bf394f8 100644 --- a/cmd/server/handler.go +++ b/cmd/server/handler.go @@ -17,7 +17,7 @@ import ( "github.com/ory/x/corsx" "github.com/ory/x/httprouterx" - analytics "github.com/ory/analytics-go/v5" + "github.com/ory/analytics-go/v5" "github.com/ory/x/configx" "github.com/ory/x/reqlog" @@ -223,11 +223,8 @@ func setup(ctx context.Context, d driver.Registry, cmd *cobra.Command) (admin *h d.Logger(), d.Config().Source(ctx), &metricsx.Options{ - Service: "ory-hydra", - DeploymentId: metricsx.Hash(fmt.Sprintf("%s|%s", - d.Config().IssuerURL(ctx).String(), - d.Config().DSN(), - )), + Service: "hydra", + DeploymentId: metricsx.Hash(d.Persister().NetworkID(ctx).String()), IsDevelopment: d.Config().DSN() == "memory" || d.Config().IssuerURL(ctx).String() == "" || strings.Contains(d.Config().IssuerURL(ctx).String(), "localhost"), diff --git a/consent/handler.go b/consent/handler.go index 81dd69b7541..10cc988c9b8 100644 --- a/consent/handler.go +++ b/consent/handler.go @@ -9,6 +9,9 @@ import ( "net/url" "time" + "github.com/ory/hydra/v2/flow" + "github.com/ory/hydra/v2/oauth2/flowctx" + "github.com/ory/hydra/v2/x/events" "github.com/ory/x/pagination/tokenpagination" "github.com/ory/x/httprouterx" @@ -68,6 +71,8 @@ func (h *Handler) SetRoutes(admin *httprouterx.RouterAdmin) { // Revoke OAuth 2.0 Consent Session Parameters // // swagger:parameters revokeOAuth2ConsentSessions +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type revokeOAuth2ConsentSessions struct { // OAuth 2.0 Consent Subject // @@ -110,7 +115,7 @@ type revokeOAuth2ConsentSessions struct { // Responses: // 204: emptyResponse // default: errorOAuth2 -func (h *Handler) revokeOAuth2ConsentSessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) revokeOAuth2ConsentSessions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { subject := r.URL.Query().Get("subject") client := r.URL.Query().Get("client") allClients := r.URL.Query().Get("all") == "true" @@ -125,11 +130,13 @@ func (h *Handler) revokeOAuth2ConsentSessions(w http.ResponseWriter, r *http.Req h.r.Writer().WriteError(w, r, err) return } + events.Trace(r.Context(), events.ConsentRevoked, events.WithSubject(subject), events.WithClientID(client)) case allClients: if err := h.r.ConsentManager().RevokeSubjectConsentSession(r.Context(), subject); err != nil && !errors.Is(err, x.ErrNotFound) { h.r.Writer().WriteError(w, r, err) return } + events.Trace(r.Context(), events.ConsentRevoked, events.WithSubject(subject)) default: h.r.Writer().WriteError(w, r, errorsx.WithStack(fosite.ErrInvalidRequest.WithHint(`Query parameter both 'client' and 'all' is not defined but one of them should have been.`))) return @@ -141,6 +148,8 @@ func (h *Handler) revokeOAuth2ConsentSessions(w http.ResponseWriter, r *http.Req // List OAuth 2.0 Consent Session Parameters // // swagger:parameters listOAuth2ConsentSessions +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type listOAuth2ConsentSessions struct { tokenpagination.RequestParameters @@ -176,7 +185,7 @@ type listOAuth2ConsentSessions struct { // Responses: // 200: oAuth2ConsentSessions // default: errorOAuth2 -func (h *Handler) listOAuth2ConsentSessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) listOAuth2ConsentSessions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { subject := r.URL.Query().Get("subject") if subject == "" { h.r.Writer().WriteError(w, r, errorsx.WithStack(fosite.ErrInvalidRequest.WithHint(`Query parameter 'subject' is not defined but should have been.`))) @@ -186,7 +195,7 @@ func (h *Handler) listOAuth2ConsentSessions(w http.ResponseWriter, r *http.Reque page, itemsPerPage := x.ParsePagination(r) - var s []AcceptOAuth2ConsentRequest + var s []flow.AcceptOAuth2ConsentRequest var err error if len(loginSessionId) == 0 { s, err = h.r.ConsentManager().FindSubjectsGrantedConsentRequests(r.Context(), subject, itemsPerPage, itemsPerPage*page) @@ -194,21 +203,21 @@ func (h *Handler) listOAuth2ConsentSessions(w http.ResponseWriter, r *http.Reque s, err = h.r.ConsentManager().FindSubjectsSessionGrantedConsentRequests(r.Context(), subject, loginSessionId, itemsPerPage, itemsPerPage*page) } if errors.Is(err, ErrNoPreviousConsentFound) { - h.r.Writer().Write(w, r, []OAuth2ConsentSession{}) + h.r.Writer().Write(w, r, []flow.OAuth2ConsentSession{}) return } else if err != nil { h.r.Writer().WriteError(w, r, err) return } - var a []OAuth2ConsentSession + var a []flow.OAuth2ConsentSession for _, session := range s { session.ConsentRequest.Client = sanitizeClient(session.ConsentRequest.Client) - a = append(a, OAuth2ConsentSession(session)) + a = append(a, flow.OAuth2ConsentSession(session)) } if len(a) == 0 { - a = []OAuth2ConsentSession{} + a = []flow.OAuth2ConsentSession{} } n, err := h.r.ConsentManager().CountSubjectsGrantedConsentRequests(r.Context(), subject) @@ -224,6 +233,8 @@ func (h *Handler) listOAuth2ConsentSessions(w http.ResponseWriter, r *http.Reque // Revoke OAuth 2.0 Consent Login Sessions Parameters // // swagger:parameters revokeOAuth2LoginSessions +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type revokeOAuth2LoginSessions struct { // OAuth 2.0 Subject // @@ -264,7 +275,7 @@ type revokeOAuth2LoginSessions struct { // Responses: // 204: emptyResponse // default: errorOAuth2 -func (h *Handler) revokeOAuth2LoginSessions(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) revokeOAuth2LoginSessions(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { sid := r.URL.Query().Get("sid") subject := r.URL.Query().Get("subject") @@ -294,6 +305,8 @@ func (h *Handler) revokeOAuth2LoginSessions(w http.ResponseWriter, r *http.Reque // Get OAuth 2.0 Login Request // // swagger:parameters getOAuth2LoginRequest +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type getOAuth2LoginRequest struct { // OAuth 2.0 Login Request Challenge // @@ -328,7 +341,7 @@ type getOAuth2LoginRequest struct { // 200: oAuth2LoginRequest // 410: oAuth2RedirectTo // default: errorOAuth2 -func (h *Handler) getOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) getOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { challenge := stringsx.Coalesce( r.URL.Query().Get("login_challenge"), r.URL.Query().Get("challenge"), @@ -345,7 +358,7 @@ func (h *Handler) getOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, return } if request.WasHandled { - h.r.Writer().WriteCode(w, r, http.StatusGone, &OAuth2RedirectTo{ + h.r.Writer().WriteCode(w, r, http.StatusGone, &flow.OAuth2RedirectTo{ RedirectTo: request.RequestURL, }) return @@ -358,6 +371,8 @@ func (h *Handler) getOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, // Accept OAuth 2.0 Login Request // // swagger:parameters acceptOAuth2LoginRequest +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type acceptOAuth2LoginRequest struct { // OAuth 2.0 Login Request Challenge // @@ -366,7 +381,7 @@ type acceptOAuth2LoginRequest struct { Challenge string `json:"login_challenge"` // in: body - Body HandledLoginRequest + Body flow.HandledLoginRequest } // swagger:route PUT /admin/oauth2/auth/requests/login/accept oAuth2 acceptOAuth2LoginRequest @@ -396,7 +411,9 @@ type acceptOAuth2LoginRequest struct { // Responses: // 200: oAuth2RedirectTo // default: errorOAuth2 -func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + ctx := r.Context() + challenge := stringsx.Coalesce( r.URL.Query().Get("login_challenge"), r.URL.Query().Get("challenge"), @@ -406,7 +423,7 @@ func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques return } - var p HandledLoginRequest + var p flow.HandledLoginRequest d := json.NewDecoder(r.Body) d.DisallowUnknownFields() if err := d.Decode(&p); err != nil { @@ -420,7 +437,7 @@ func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques } p.ID = challenge - ar, err := h.r.ConsentManager().GetLoginRequest(r.Context(), challenge) + ar, err := h.r.ConsentManager().GetLoginRequest(ctx, challenge) if err != nil { h.r.Writer().WriteError(w, r, err) return @@ -440,7 +457,12 @@ func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques } p.RequestedAt = ar.RequestedAt - request, err := h.r.ConsentManager().HandleLoginRequest(r.Context(), challenge, &p) + f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsLoginChallenge) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } + request, err := h.r.ConsentManager().HandleLoginRequest(ctx, f, challenge, &p) if err != nil { h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) return @@ -452,14 +474,24 @@ func (h *Handler) acceptOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques return } - h.r.Writer().Write(w, r, &OAuth2RedirectTo{ - RedirectTo: urlx.SetQuery(ru, url.Values{"login_verifier": {request.Verifier}}).String(), + verifier, err := f.ToLoginVerifier(ctx, h.r) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } + + events.Trace(ctx, events.LoginAccepted, events.WithClientID(request.Client.GetID()), events.WithSubject(request.Subject)) + + h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{ + RedirectTo: urlx.SetQuery(ru, url.Values{"login_verifier": {verifier}}).String(), }) } // Reject OAuth 2.0 Login Request // // swagger:parameters rejectOAuth2LoginRequest +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type rejectOAuth2LoginRequest struct { // OAuth 2.0 Login Request Challenge // @@ -468,7 +500,7 @@ type rejectOAuth2LoginRequest struct { Challenge string `json:"login_challenge"` // in: body - Body RequestDeniedError + Body flow.RequestDeniedError } // swagger:route PUT /admin/oauth2/auth/requests/login/reject oAuth2 rejectOAuth2LoginRequest @@ -497,7 +529,9 @@ type rejectOAuth2LoginRequest struct { // Responses: // 200: oAuth2RedirectTo // default: errorOAuth2 -func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + ctx := r.Context() + challenge := stringsx.Coalesce( r.URL.Query().Get("login_challenge"), r.URL.Query().Get("challenge"), @@ -507,7 +541,7 @@ func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques return } - var p RequestDeniedError + var p flow.RequestDeniedError d := json.NewDecoder(r.Body) d.DisallowUnknownFields() if err := d.Decode(&p); err != nil { @@ -515,15 +549,20 @@ func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques return } - p.valid = true - p.SetDefaults(loginRequestDeniedErrorName) - ar, err := h.r.ConsentManager().GetLoginRequest(r.Context(), challenge) + p.Valid = true + p.SetDefaults(flow.LoginRequestDeniedErrorName) + ar, err := h.r.ConsentManager().GetLoginRequest(ctx, challenge) if err != nil { h.r.Writer().WriteError(w, r, err) return } - request, err := h.r.ConsentManager().HandleLoginRequest(r.Context(), challenge, &HandledLoginRequest{ + f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsLoginChallenge) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } + request, err := h.r.ConsentManager().HandleLoginRequest(ctx, f, challenge, &flow.HandledLoginRequest{ Error: &p, ID: challenge, RequestedAt: ar.RequestedAt, @@ -533,20 +572,30 @@ func (h *Handler) rejectOAuth2LoginRequest(w http.ResponseWriter, r *http.Reques return } + verifier, err := f.ToLoginVerifier(ctx, h.r) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } + ru, err := url.Parse(request.RequestURL) if err != nil { h.r.Writer().WriteError(w, r, err) return } - h.r.Writer().Write(w, r, &OAuth2RedirectTo{ - RedirectTo: urlx.SetQuery(ru, url.Values{"login_verifier": {request.Verifier}}).String(), + events.Trace(ctx, events.LoginRejected, events.WithClientID(request.Client.GetID()), events.WithSubject(request.Subject)) + + h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{ + RedirectTo: urlx.SetQuery(ru, url.Values{"login_verifier": {verifier}}).String(), }) } // Get OAuth 2.0 Consent Request // // swagger:parameters getOAuth2ConsentRequest +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type getOAuth2ConsentRequest struct { // OAuth 2.0 Consent Request Challenge // @@ -582,7 +631,7 @@ type getOAuth2ConsentRequest struct { // 200: oAuth2ConsentRequest // 410: oAuth2RedirectTo // default: errorOAuth2 -func (h *Handler) getOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) getOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { challenge := stringsx.Coalesce( r.URL.Query().Get("consent_challenge"), r.URL.Query().Get("challenge"), @@ -598,7 +647,7 @@ func (h *Handler) getOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request return } if request.WasHandled { - h.r.Writer().WriteCode(w, r, http.StatusGone, &OAuth2RedirectTo{ + h.r.Writer().WriteCode(w, r, http.StatusGone, &flow.OAuth2RedirectTo{ RedirectTo: request.RequestURL, }) return @@ -619,6 +668,8 @@ func (h *Handler) getOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request // Accept OAuth 2.0 Consent Request // // swagger:parameters acceptOAuth2ConsentRequest +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type acceptOAuth2ConsentRequest struct { // OAuth 2.0 Consent Request Challenge // @@ -627,7 +678,7 @@ type acceptOAuth2ConsentRequest struct { Challenge string `json:"consent_challenge"` // in: body - Body AcceptOAuth2ConsentRequest + Body flow.AcceptOAuth2ConsentRequest } // swagger:route PUT /admin/oauth2/auth/requests/consent/accept oAuth2 acceptOAuth2ConsentRequest @@ -662,7 +713,9 @@ type acceptOAuth2ConsentRequest struct { // Responses: // 200: oAuth2RedirectTo // default: errorOAuth2 -func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + ctx := r.Context() + challenge := stringsx.Coalesce( r.URL.Query().Get("consent_challenge"), r.URL.Query().Get("challenge"), @@ -672,7 +725,7 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ return } - var p AcceptOAuth2ConsentRequest + var p flow.AcceptOAuth2ConsentRequest d := json.NewDecoder(r.Body) d.DisallowUnknownFields() if err := d.Decode(&p); err != nil { @@ -680,7 +733,7 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ return } - cr, err := h.r.ConsentManager().GetConsentRequest(r.Context(), challenge) + cr, err := h.r.ConsentManager().GetConsentRequest(ctx, challenge) if err != nil { h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) return @@ -690,7 +743,12 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ p.RequestedAt = cr.RequestedAt p.HandledAt = sqlxx.NullTime(time.Now().UTC()) - hr, err := h.r.ConsentManager().HandleConsentRequest(r.Context(), &p) + f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsConsentChallenge) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } + hr, err := h.r.ConsentManager().HandleConsentRequest(ctx, f, &p) if err != nil { h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) return @@ -704,14 +762,24 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ return } - h.r.Writer().Write(w, r, &OAuth2RedirectTo{ - RedirectTo: urlx.SetQuery(ru, url.Values{"consent_verifier": {hr.Verifier}}).String(), + verifier, err := f.ToConsentVerifier(ctx, h.r) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } + + events.Trace(ctx, events.ConsentAccepted, events.WithClientID(cr.Client.GetID()), events.WithSubject(cr.Subject)) + + h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{ + RedirectTo: urlx.SetQuery(ru, url.Values{"consent_verifier": {verifier}}).String(), }) } // Reject OAuth 2.0 Consent Request // // swagger:parameters rejectOAuth2ConsentRequest +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type adminRejectOAuth2ConsentRequest struct { // OAuth 2.0 Consent Request Challenge // @@ -720,7 +788,7 @@ type adminRejectOAuth2ConsentRequest struct { Challenge string `json:"consent_challenge"` // in: body - Body RequestDeniedError + Body flow.RequestDeniedError } // swagger:route PUT /admin/oauth2/auth/requests/consent/reject oAuth2 rejectOAuth2ConsentRequest @@ -754,7 +822,9 @@ type adminRejectOAuth2ConsentRequest struct { // Responses: // 200: oAuth2RedirectTo // default: errorOAuth2 -func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + ctx := r.Context() + challenge := stringsx.Coalesce( r.URL.Query().Get("consent_challenge"), r.URL.Query().Get("challenge"), @@ -764,7 +834,7 @@ func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ return } - var p RequestDeniedError + var p flow.RequestDeniedError d := json.NewDecoder(r.Body) d.DisallowUnknownFields() if err := d.Decode(&p); err != nil { @@ -772,15 +842,21 @@ func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ return } - p.valid = true - p.SetDefaults(consentRequestDeniedErrorName) - hr, err := h.r.ConsentManager().GetConsentRequest(r.Context(), challenge) + p.Valid = true + p.SetDefaults(flow.ConsentRequestDeniedErrorName) + hr, err := h.r.ConsentManager().GetConsentRequest(ctx, challenge) if err != nil { h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) return } - request, err := h.r.ConsentManager().HandleConsentRequest(r.Context(), &AcceptOAuth2ConsentRequest{ + f, err := flowctx.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flowctx.AsConsentChallenge) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } + + request, err := h.r.ConsentManager().HandleConsentRequest(ctx, f, &flow.AcceptOAuth2ConsentRequest{ Error: &p, ID: challenge, RequestedAt: hr.RequestedAt, @@ -797,14 +873,24 @@ func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ return } - h.r.Writer().Write(w, r, &OAuth2RedirectTo{ - RedirectTo: urlx.SetQuery(ru, url.Values{"consent_verifier": {request.Verifier}}).String(), + verifier, err := f.ToConsentVerifier(ctx, h.r) + if err != nil { + h.r.Writer().WriteError(w, r, err) + return + } + + events.Trace(ctx, events.ConsentRejected, events.WithClientID(request.Client.GetID()), events.WithSubject(request.Subject)) + + h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{ + RedirectTo: urlx.SetQuery(ru, url.Values{"consent_verifier": {verifier}}).String(), }) } // Accept OAuth 2.0 Logout Request // // swagger:parameters acceptOAuth2LogoutRequest +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type acceptOAuth2LogoutRequest struct { // OAuth 2.0 Logout Request Challenge // @@ -829,7 +915,7 @@ type acceptOAuth2LogoutRequest struct { // Responses: // 200: oAuth2RedirectTo // default: errorOAuth2 -func (h *Handler) acceptOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) acceptOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { challenge := stringsx.Coalesce( r.URL.Query().Get("logout_challenge"), r.URL.Query().Get("challenge"), @@ -841,7 +927,7 @@ func (h *Handler) acceptOAuth2LogoutRequest(w http.ResponseWriter, r *http.Reque return } - h.r.Writer().Write(w, r, &OAuth2RedirectTo{ + h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{ RedirectTo: urlx.SetQuery(urlx.AppendPaths(h.c.PublicURL(r.Context()), "/oauth2/sessions/logout"), url.Values{"logout_verifier": {c.Verifier}}).String(), }) } @@ -849,6 +935,8 @@ func (h *Handler) acceptOAuth2LogoutRequest(w http.ResponseWriter, r *http.Reque // Reject OAuth 2.0 Logout Request // // swagger:parameters rejectOAuth2LogoutRequest +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type rejectOAuth2LogoutRequest struct { // in: query // required: true @@ -872,7 +960,7 @@ type rejectOAuth2LogoutRequest struct { // Responses: // 204: emptyResponse // default: errorOAuth2 -func (h *Handler) rejectOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) rejectOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { challenge := stringsx.Coalesce( r.URL.Query().Get("logout_challenge"), r.URL.Query().Get("challenge"), @@ -889,6 +977,8 @@ func (h *Handler) rejectOAuth2LogoutRequest(w http.ResponseWriter, r *http.Reque // Get OAuth 2.0 Logout Request // // swagger:parameters getOAuth2LogoutRequest +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type getOAuth2LogoutRequest struct { // in: query // required: true @@ -910,7 +1000,7 @@ type getOAuth2LogoutRequest struct { // 200: oAuth2LogoutRequest // 410: oAuth2RedirectTo // default: errorOAuth2 -func (h *Handler) getOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (h *Handler) getOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { challenge := stringsx.Coalesce( r.URL.Query().Get("logout_challenge"), r.URL.Query().Get("challenge"), @@ -928,7 +1018,7 @@ func (h *Handler) getOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, } if request.WasHandled { - h.r.Writer().WriteCode(w, r, http.StatusGone, &OAuth2RedirectTo{ + h.r.Writer().WriteCode(w, r, http.StatusGone, &flow.OAuth2RedirectTo{ RedirectTo: request.RequestURL, }) return @@ -936,3 +1026,12 @@ func (h *Handler) getOAuth2LogoutRequest(w http.ResponseWriter, r *http.Request, h.r.Writer().Write(w, r, request) } + +func (h *Handler) flowFromCookie(r *http.Request) (*flow.Flow, error) { + clientID := r.URL.Query().Get("client_id") + if clientID == "" { + return nil, errors.WithStack(fosite.ErrInvalidClient) + } + + return flowctx.FromCookie[flow.Flow](r.Context(), r, h.r.FlowCipher(), flowctx.FlowCookie(flowctx.SuffixFromStatic(clientID))) +} diff --git a/consent/handler_test.go b/consent/handler_test.go index 6022674eeb1..47496fa0bf5 100644 --- a/consent/handler_test.go +++ b/consent/handler_test.go @@ -13,19 +13,17 @@ import ( "testing" "time" - "github.com/ory/x/pointerx" - - "github.com/ory/hydra/v2/x" - "github.com/ory/x/contextx" - "github.com/ory/x/sqlxx" - - "github.com/ory/hydra/v2/internal" - "github.com/stretchr/testify/require" hydra "github.com/ory/hydra-client-go/v2" "github.com/ory/hydra/v2/client" . "github.com/ory/hydra/v2/consent" + "github.com/ory/hydra/v2/flow" + "github.com/ory/hydra/v2/internal" + "github.com/ory/hydra/v2/x" + "github.com/ory/x/contextx" + "github.com/ory/x/pointerx" + "github.com/ory/x/sqlxx" ) func TestGetLogoutRequest(t *testing.T) { @@ -39,6 +37,7 @@ func TestGetLogoutRequest(t *testing.T) { {true, true, http.StatusGone}, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + ctx := context.Background() key := fmt.Sprint(k) challenge := "challenge" + key requestURL := "http://192.0.2.1" @@ -48,8 +47,8 @@ func TestGetLogoutRequest(t *testing.T) { if tc.exists { cl := &client.Client{LegacyClientID: "client" + key} - require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cl)) - require.NoError(t, reg.ConsentManager().CreateLogoutRequest(context.TODO(), &LogoutRequest{ + require.NoError(t, reg.ClientManager().CreateClient(ctx, cl)) + require.NoError(t, reg.ConsentManager().CreateLogoutRequest(context.TODO(), &flow.LogoutRequest{ Client: cl, ID: challenge, WasHandled: tc.handled, @@ -69,11 +68,11 @@ func TestGetLogoutRequest(t *testing.T) { require.EqualValues(t, tc.status, resp.StatusCode) if tc.handled { - var result OAuth2RedirectTo + var result flow.OAuth2RedirectTo require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, requestURL, result.RedirectTo) } else if tc.exists { - var result LogoutRequest + var result flow.LogoutRequest require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, challenge, result.ID) require.Equal(t, requestURL, result.RequestURL) @@ -92,7 +91,8 @@ func TestGetLoginRequest(t *testing.T) { {true, false, http.StatusOK}, {true, true, http.StatusGone}, } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + t.Run(fmt.Sprintf("exists=%v/handled=%v", tc.exists, tc.handled), func(t *testing.T) { + ctx := context.Background() key := fmt.Sprint(k) challenge := "challenge" + key requestURL := "http://192.0.2.1" @@ -103,14 +103,20 @@ func TestGetLoginRequest(t *testing.T) { if tc.exists { cl := &client.Client{LegacyClientID: "client" + key} require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cl)) - require.NoError(t, reg.ConsentManager().CreateLoginRequest(context.Background(), &LoginRequest{ - Client: cl, - ID: challenge, - RequestURL: requestURL, - })) + f, err := reg.ConsentManager().CreateLoginRequest(context.Background(), &flow.LoginRequest{ + Client: cl, + ID: challenge, + RequestURL: requestURL, + RequestedAt: time.Now(), + }) + require.NoError(t, err) + challenge, err = f.ToLoginChallenge(ctx, reg) + require.NoError(t, err) if tc.handled { - _, err := reg.ConsentManager().HandleLoginRequest(context.Background(), challenge, &HandledLoginRequest{ID: challenge, WasHandled: true}) + _, err := reg.ConsentManager().HandleLoginRequest(ctx, f, challenge, &flow.HandledLoginRequest{ID: challenge, WasHandled: true}) + require.NoError(t, err) + challenge, err = f.ToLoginChallenge(ctx, reg) require.NoError(t, err) } } @@ -127,11 +133,11 @@ func TestGetLoginRequest(t *testing.T) { require.EqualValues(t, tc.status, resp.StatusCode) if tc.handled { - var result OAuth2RedirectTo + var result flow.OAuth2RedirectTo require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, requestURL, result.RedirectTo) } else if tc.exists { - var result LoginRequest + var result flow.LoginRequest require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, challenge, result.ID) require.Equal(t, requestURL, result.RequestURL) @@ -152,6 +158,7 @@ func TestGetConsentRequest(t *testing.T) { {true, true, http.StatusGone}, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + ctx := context.Background() key := fmt.Sprint(k) challenge := "challenge" + key requestURL := "http://192.0.2.1" @@ -161,14 +168,24 @@ func TestGetConsentRequest(t *testing.T) { if tc.exists { cl := &client.Client{LegacyClientID: "client" + key} - require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cl)) - lr := &LoginRequest{ID: "login-" + challenge, Client: cl, RequestURL: requestURL} - require.NoError(t, reg.ConsentManager().CreateLoginRequest(context.Background(), lr)) - _, err := reg.ConsentManager().HandleLoginRequest(context.Background(), lr.ID, &HandledLoginRequest{ - ID: lr.ID, + require.NoError(t, reg.ClientManager().CreateClient(ctx, cl)) + lr := &flow.LoginRequest{ + ID: "login-" + challenge, + Client: cl, + RequestURL: requestURL, + RequestedAt: time.Now(), + } + f, err := reg.ConsentManager().CreateLoginRequest(ctx, lr) + require.NoError(t, err) + challenge, err = f.ToLoginChallenge(ctx, reg) + require.NoError(t, err) + _, err = reg.ConsentManager().HandleLoginRequest(ctx, f, challenge, &flow.HandledLoginRequest{ + ID: challenge, }) require.NoError(t, err) - require.NoError(t, reg.ConsentManager().CreateConsentRequest(context.Background(), &OAuth2ConsentRequest{ + challenge, err = f.ToConsentChallenge(ctx, reg) + require.NoError(t, err) + require.NoError(t, reg.ConsentManager().CreateConsentRequest(ctx, f, &flow.OAuth2ConsentRequest{ Client: cl, ID: challenge, Verifier: challenge, @@ -177,12 +194,14 @@ func TestGetConsentRequest(t *testing.T) { })) if tc.handled { - _, err := reg.ConsentManager().HandleConsentRequest(context.Background(), &AcceptOAuth2ConsentRequest{ + _, err := reg.ConsentManager().HandleConsentRequest(ctx, f, &flow.AcceptOAuth2ConsentRequest{ ID: challenge, WasHandled: true, HandledAt: sqlxx.NullTime(time.Now()), }) require.NoError(t, err) + challenge, err = f.ToConsentChallenge(ctx, reg) + require.NoError(t, err) } } @@ -199,11 +218,11 @@ func TestGetConsentRequest(t *testing.T) { require.EqualValues(t, tc.status, resp.StatusCode) if tc.handled { - var result OAuth2RedirectTo + var result flow.OAuth2RedirectTo require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, requestURL, result.RedirectTo) } else if tc.exists { - var result OAuth2ConsentRequest + var result flow.OAuth2ConsentRequest require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, challenge, result.ID) require.Equal(t, requestURL, result.RequestURL) @@ -215,6 +234,7 @@ func TestGetConsentRequest(t *testing.T) { func TestGetLoginRequestWithDuplicateAccept(t *testing.T) { t.Run("Test get login request with duplicate accept", func(t *testing.T) { + ctx := context.Background() challenge := "challenge" requestURL := "http://192.0.2.1" @@ -222,12 +242,16 @@ func TestGetLoginRequestWithDuplicateAccept(t *testing.T) { reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) cl := &client.Client{LegacyClientID: "client"} - require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cl)) - require.NoError(t, reg.ConsentManager().CreateLoginRequest(context.Background(), &LoginRequest{ - Client: cl, - ID: challenge, - RequestURL: requestURL, - })) + require.NoError(t, reg.ClientManager().CreateClient(ctx, cl)) + f, err := reg.ConsentManager().CreateLoginRequest(ctx, &flow.LoginRequest{ + Client: cl, + ID: challenge, + RequestURL: requestURL, + RequestedAt: time.Now(), + }) + require.NoError(t, err) + challenge, err = f.ToLoginChallenge(ctx, reg) + require.NoError(t, err) h := NewHandler(reg, conf) r := x.NewRouterAdmin(conf.AdminURL) @@ -238,7 +262,7 @@ func TestGetLoginRequestWithDuplicateAccept(t *testing.T) { c := &http.Client{} sub := "sub123" - acceptLogin := &hydra.AcceptOAuth2LoginRequest{Remember: pointerx.Bool(true), Subject: sub} + acceptLogin := &hydra.AcceptOAuth2LoginRequest{Remember: pointerx.Ptr(true), Subject: sub} // marshal User to json acceptLoginJson, err := json.Marshal(acceptLogin) @@ -256,7 +280,7 @@ func TestGetLoginRequestWithDuplicateAccept(t *testing.T) { require.NoError(t, err) require.EqualValues(t, http.StatusOK, resp.StatusCode) - var result OAuth2RedirectTo + var result flow.OAuth2RedirectTo require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.NotNil(t, result.RedirectTo) require.Contains(t, result.RedirectTo, "login_verifier") @@ -270,7 +294,7 @@ func TestGetLoginRequestWithDuplicateAccept(t *testing.T) { require.NoError(t, err) require.EqualValues(t, http.StatusOK, resp2.StatusCode) - var result2 OAuth2RedirectTo + var result2 flow.OAuth2RedirectTo require.NoError(t, json.NewDecoder(resp2.Body).Decode(&result2)) require.NotNil(t, result2.RedirectTo) require.Contains(t, result2.RedirectTo, "login_verifier") diff --git a/consent/helper.go b/consent/helper.go index ed15dd03147..bf6e46b2765 100644 --- a/consent/helper.go +++ b/consent/helper.go @@ -6,9 +6,9 @@ package consent import ( "net/http" "strings" - "time" + "github.com/ory/hydra/v2/flow" "github.com/ory/hydra/v2/x" "github.com/ory/x/errorsx" @@ -33,7 +33,7 @@ func sanitizeClient(c *client.Client) *client.Client { return cc } -func matchScopes(scopeStrategy fosite.ScopeStrategy, previousConsent []AcceptOAuth2ConsentRequest, requestedScope []string) *AcceptOAuth2ConsentRequest { +func matchScopes(scopeStrategy fosite.ScopeStrategy, previousConsent []flow.AcceptOAuth2ConsentRequest, requestedScope []string) *flow.AcceptOAuth2ConsentRequest { for _, cs := range previousConsent { var found = true for _, scope := range requestedScope { diff --git a/consent/helper_test.go b/consent/helper_test.go index c350ee2f63b..a5f09e81cdd 100644 --- a/consent/helper_test.go +++ b/consent/helper_test.go @@ -12,6 +12,7 @@ import ( "github.com/golang/mock/gomock" + "github.com/ory/hydra/v2/flow" "github.com/ory/hydra/v2/internal/mock" "github.com/gorilla/securecookie" @@ -38,22 +39,22 @@ func TestSanitizeClient(t *testing.T) { func TestMatchScopes(t *testing.T) { for k, tc := range []struct { - granted []AcceptOAuth2ConsentRequest + granted []flow.AcceptOAuth2ConsentRequest requested []string expectChallenge string }{ { - granted: []AcceptOAuth2ConsentRequest{{ID: "1", GrantedScope: []string{"foo", "bar"}}}, + granted: []flow.AcceptOAuth2ConsentRequest{{ID: "1", GrantedScope: []string{"foo", "bar"}}}, requested: []string{"foo", "bar"}, expectChallenge: "1", }, { - granted: []AcceptOAuth2ConsentRequest{{ID: "1", GrantedScope: []string{"foo", "bar"}}}, + granted: []flow.AcceptOAuth2ConsentRequest{{ID: "1", GrantedScope: []string{"foo", "bar"}}}, requested: []string{"foo", "bar", "baz"}, expectChallenge: "", }, { - granted: []AcceptOAuth2ConsentRequest{ + granted: []flow.AcceptOAuth2ConsentRequest{ {ID: "1", GrantedScope: []string{"foo", "bar"}}, {ID: "2", GrantedScope: []string{"foo", "bar"}}, }, @@ -61,7 +62,7 @@ func TestMatchScopes(t *testing.T) { expectChallenge: "1", }, { - granted: []AcceptOAuth2ConsentRequest{ + granted: []flow.AcceptOAuth2ConsentRequest{ {ID: "1", GrantedScope: []string{"foo", "bar"}}, {ID: "2", GrantedScope: []string{"foo", "bar", "baz"}}, }, @@ -69,7 +70,7 @@ func TestMatchScopes(t *testing.T) { expectChallenge: "2", }, { - granted: []AcceptOAuth2ConsentRequest{ + granted: []flow.AcceptOAuth2ConsentRequest{ {ID: "1", GrantedScope: []string{"foo", "bar"}}, {ID: "2", GrantedScope: []string{"foo", "bar", "baz"}}, }, diff --git a/consent/janitor_consent_test_helper.go b/consent/janitor_consent_test_helper.go index 6467eb1a63d..645a88a2209 100644 --- a/consent/janitor_consent_test_helper.go +++ b/consent/janitor_consent_test_helper.go @@ -6,23 +6,24 @@ package consent import ( "time" + "github.com/ory/hydra/v2/flow" "github.com/ory/x/sqlxx" ) -func NewHandledLoginRequest(challenge string, hasError bool, requestedAt time.Time, authenticatedAt sqlxx.NullTime) *HandledLoginRequest { - var deniedErr *RequestDeniedError +func NewHandledLoginRequest(challenge string, hasError bool, requestedAt time.Time, authenticatedAt sqlxx.NullTime) *flow.HandledLoginRequest { + var deniedErr *flow.RequestDeniedError if hasError { - deniedErr = &RequestDeniedError{ + deniedErr = &flow.RequestDeniedError{ Name: "consent request denied", Description: "some description", Hint: "some hint", Code: 403, Debug: "some debug", - valid: true, + Valid: true, } } - return &HandledLoginRequest{ + return &flow.HandledLoginRequest{ ID: challenge, Error: deniedErr, WasHandled: true, @@ -31,20 +32,20 @@ func NewHandledLoginRequest(challenge string, hasError bool, requestedAt time.Ti } } -func NewHandledConsentRequest(challenge string, hasError bool, requestedAt time.Time, authenticatedAt sqlxx.NullTime) *AcceptOAuth2ConsentRequest { - var deniedErr *RequestDeniedError +func NewHandledConsentRequest(challenge string, hasError bool, requestedAt time.Time, authenticatedAt sqlxx.NullTime) *flow.AcceptOAuth2ConsentRequest { + var deniedErr *flow.RequestDeniedError if hasError { - deniedErr = &RequestDeniedError{ + deniedErr = &flow.RequestDeniedError{ Name: "consent request denied", Description: "some description", Hint: "some hint", Code: 403, Debug: "some debug", - valid: true, + Valid: true, } } - return &AcceptOAuth2ConsentRequest{ + return &flow.AcceptOAuth2ConsentRequest{ ID: challenge, HandledAt: sqlxx.NullTime(time.Now().Round(time.Second)), Error: deniedErr, diff --git a/consent/manager.go b/consent/manager.go index 2910bcc9e40..69b62ed8b9e 100644 --- a/consent/manager.go +++ b/consent/manager.go @@ -10,6 +10,7 @@ import ( "github.com/gofrs/uuid" "github.com/ory/hydra/v2/client" + "github.com/ory/hydra/v2/flow" ) type ForcedObfuscatedLoginSession struct { @@ -19,44 +20,50 @@ type ForcedObfuscatedLoginSession struct { NID uuid.UUID `db:"nid"` } -func (_ ForcedObfuscatedLoginSession) TableName() string { +func (ForcedObfuscatedLoginSession) TableName() string { return "hydra_oauth2_obfuscated_authentication_session" } -type Manager interface { - CreateConsentRequest(ctx context.Context, req *OAuth2ConsentRequest) error - GetConsentRequest(ctx context.Context, challenge string) (*OAuth2ConsentRequest, error) - HandleConsentRequest(ctx context.Context, r *AcceptOAuth2ConsentRequest) (*OAuth2ConsentRequest, error) - RevokeSubjectConsentSession(ctx context.Context, user string) error - RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error - - VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*AcceptOAuth2ConsentRequest, error) - FindGrantedAndRememberedConsentRequests(ctx context.Context, client, user string) ([]AcceptOAuth2ConsentRequest, error) - FindSubjectsGrantedConsentRequests(ctx context.Context, user string, limit, offset int) ([]AcceptOAuth2ConsentRequest, error) - FindSubjectsSessionGrantedConsentRequests(ctx context.Context, user, sid string, limit, offset int) ([]AcceptOAuth2ConsentRequest, error) - CountSubjectsGrantedConsentRequests(ctx context.Context, user string) (int, error) - - // Cookie management - GetRememberedLoginSession(ctx context.Context, id string) (*LoginSession, error) - CreateLoginSession(ctx context.Context, session *LoginSession) error - DeleteLoginSession(ctx context.Context, id string) error - RevokeSubjectLoginSession(ctx context.Context, user string) error - ConfirmLoginSession(ctx context.Context, id string, authTime time.Time, subject string, remember bool) error - - CreateLoginRequest(ctx context.Context, req *LoginRequest) error - GetLoginRequest(ctx context.Context, challenge string) (*LoginRequest, error) - HandleLoginRequest(ctx context.Context, challenge string, r *HandledLoginRequest) (*LoginRequest, error) - VerifyAndInvalidateLoginRequest(ctx context.Context, verifier string) (*HandledLoginRequest, error) - - CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error - GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error) - - ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) - ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) - - CreateLogoutRequest(ctx context.Context, request *LogoutRequest) error - GetLogoutRequest(ctx context.Context, challenge string) (*LogoutRequest, error) - AcceptLogoutRequest(ctx context.Context, challenge string) (*LogoutRequest, error) - RejectLogoutRequest(ctx context.Context, challenge string) error - VerifyAndInvalidateLogoutRequest(ctx context.Context, verifier string) (*LogoutRequest, error) -} +type ( + Manager interface { + CreateConsentRequest(ctx context.Context, f *flow.Flow, req *flow.OAuth2ConsentRequest) error + GetConsentRequest(ctx context.Context, challenge string) (*flow.OAuth2ConsentRequest, error) + HandleConsentRequest(ctx context.Context, f *flow.Flow, r *flow.AcceptOAuth2ConsentRequest) (*flow.OAuth2ConsentRequest, error) + RevokeSubjectConsentSession(ctx context.Context, user string) error + RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error + + VerifyAndInvalidateConsentRequest(ctx context.Context, f *flow.Flow, verifier string) (*flow.AcceptOAuth2ConsentRequest, error) + FindGrantedAndRememberedConsentRequests(ctx context.Context, client, user string) ([]flow.AcceptOAuth2ConsentRequest, error) + FindSubjectsGrantedConsentRequests(ctx context.Context, user string, limit, offset int) ([]flow.AcceptOAuth2ConsentRequest, error) + FindSubjectsSessionGrantedConsentRequests(ctx context.Context, user, sid string, limit, offset int) ([]flow.AcceptOAuth2ConsentRequest, error) + CountSubjectsGrantedConsentRequests(ctx context.Context, user string) (int, error) + + // Cookie management + GetRememberedLoginSession(ctx context.Context, loginSessionFromCookie *flow.LoginSession, id string) (*flow.LoginSession, error) + CreateLoginSession(ctx context.Context, session *flow.LoginSession) error + DeleteLoginSession(ctx context.Context, id string) error + RevokeSubjectLoginSession(ctx context.Context, user string) error + ConfirmLoginSession(ctx context.Context, session *flow.LoginSession, id string, authTime time.Time, subject string, remember bool) error + + CreateLoginRequest(ctx context.Context, req *flow.LoginRequest) (*flow.Flow, error) + GetLoginRequest(ctx context.Context, challenge string) (*flow.LoginRequest, error) + HandleLoginRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledLoginRequest) (*flow.LoginRequest, error) + VerifyAndInvalidateLoginRequest(ctx context.Context, f *flow.Flow, verifier string) (*flow.HandledLoginRequest, error) + + CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error + GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error) + + ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) + ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) + + CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) error + GetLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error) + AcceptLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error) + RejectLogoutRequest(ctx context.Context, challenge string) error + VerifyAndInvalidateLogoutRequest(ctx context.Context, verifier string) (*flow.LogoutRequest, error) + } + + ManagerProvider interface { + ConsentManager() Manager + } +) diff --git a/consent/manager_test_helpers.go b/consent/manager_test_helpers.go index 084b9d4c4a4..2d84bf071d5 100644 --- a/consent/manager_test_helpers.go +++ b/consent/manager_test_helpers.go @@ -10,7 +10,10 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/aead" + "github.com/ory/hydra/v2/flow" "github.com/ory/x/assertx" + "github.com/ory/x/contextx" gofrsuuid "github.com/gofrs/uuid" "github.com/google/uuid" @@ -25,14 +28,14 @@ import ( "github.com/ory/hydra/v2/x" ) -func MockConsentRequest(key string, remember bool, rememberFor int, hasError bool, skip bool, authAt bool, loginChallengeBase string, network string) (c *OAuth2ConsentRequest, h *AcceptOAuth2ConsentRequest) { - c = &OAuth2ConsentRequest{ +func MockConsentRequest(key string, remember bool, rememberFor int, hasError bool, skip bool, authAt bool, loginChallengeBase string, network string) (c *flow.OAuth2ConsentRequest, h *flow.AcceptOAuth2ConsentRequest, f *flow.Flow) { + c = &flow.OAuth2ConsentRequest{ ID: makeID("challenge", network, key), RequestedScope: []string{"scopea" + key, "scopeb" + key}, RequestedAudience: []string{"auda" + key, "audb" + key}, Skip: skip, Subject: "subject" + key, - OpenIDConnectContext: &OAuth2ConsentRequestOpenIDConnectContext{ + OpenIDConnectContext: &flow.OAuth2ConsentRequestOpenIDConnectContext{ ACRValues: []string{"1" + key, "2" + key}, UILocales: []string{"fr" + key, "de" + key}, Display: "popup" + key, @@ -46,19 +49,37 @@ func MockConsentRequest(key string, remember bool, rememberFor int, hasError boo CSRF: "csrf" + key, ACR: "1", AuthenticatedAt: sqlxx.NullTime(time.Now().UTC().Add(-time.Hour)), - RequestedAt: time.Now().UTC().Add(-time.Hour), + RequestedAt: time.Now().UTC(), Context: sqlxx.JSONRawMessage(`{"foo": "bar` + key + `"}`), } - var err *RequestDeniedError + f = &flow.Flow{ + ID: c.LoginChallenge.String(), + LoginVerifier: makeID("login-verifier", network, key), + SessionID: c.LoginSessionID, + Client: c.Client, + State: flow.FlowStateConsentInitialized, + ConsentChallengeID: sqlxx.NullString(c.ID), + ConsentSkip: c.Skip, + ConsentVerifier: sqlxx.NullString(c.Verifier), + ConsentCSRF: sqlxx.NullString(c.CSRF), + OpenIDConnectContext: c.OpenIDConnectContext, + Subject: c.Subject, + RequestedScope: c.RequestedScope, + RequestedAudience: c.RequestedAudience, + RequestURL: c.RequestURL, + RequestedAt: c.RequestedAt, + } + + var err *flow.RequestDeniedError if hasError { - err = &RequestDeniedError{ + err = &flow.RequestDeniedError{ Name: "error_name" + key, Description: "error_description" + key, Hint: "error_hint,omitempty" + key, Code: 100, Debug: "error_debug,omitempty" + key, - valid: true, + Valid: true, } } @@ -67,7 +88,7 @@ func MockConsentRequest(key string, remember bool, rememberFor int, hasError boo authenticatedAt = sqlxx.NullTime(time.Now().UTC().Add(-time.Minute)) } - h = &AcceptOAuth2ConsentRequest{ + h = &flow.AcceptOAuth2ConsentRequest{ ConsentRequest: c, RememberFor: rememberFor, Remember: remember, @@ -81,17 +102,17 @@ func MockConsentRequest(key string, remember bool, rememberFor int, hasError boo // WasUsed: true, } - return c, h + return c, h, f } -func MockLogoutRequest(key string, withClient bool, network string) (c *LogoutRequest) { +func MockLogoutRequest(key string, withClient bool, network string) (c *flow.LogoutRequest) { var cl *client.Client if withClient { cl = &client.Client{ LegacyClientID: "fk-client-" + key, } } - return &LogoutRequest{ + return &flow.LogoutRequest{ Subject: "subject" + key, ID: makeID("challenge", network, key), Verifier: makeID("verifier", network, key), @@ -105,9 +126,9 @@ func MockLogoutRequest(key string, withClient bool, network string) (c *LogoutRe } } -func MockAuthRequest(key string, authAt bool, network string) (c *LoginRequest, h *HandledLoginRequest) { - c = &LoginRequest{ - OpenIDConnectContext: &OAuth2ConsentRequestOpenIDConnectContext{ +func MockAuthRequest(key string, authAt bool, network string) (c *flow.LoginRequest, h *flow.HandledLoginRequest, f *flow.Flow) { + c = &flow.LoginRequest{ + OpenIDConnectContext: &flow.OAuth2ConsentRequestOpenIDConnectContext{ ACRValues: []string{"1" + key, "2" + key}, UILocales: []string{"fr" + key, "de" + key}, Display: "popup" + key, @@ -124,13 +145,15 @@ func MockAuthRequest(key string, authAt bool, network string) (c *LoginRequest, SessionID: sqlxx.NullString(makeID("fk-login-session", network, key)), } - var err = &RequestDeniedError{ + f = flow.NewFlow(c) + + var err = &flow.RequestDeniedError{ Name: "error_name" + key, Description: "error_description" + key, Hint: "error_hint,omitempty" + key, Code: 100, Debug: "error_debug,omitempty" + key, - valid: true, + Valid: true, } var authenticatedAt time.Time @@ -138,7 +161,7 @@ func MockAuthRequest(key string, authAt bool, network string) (c *LoginRequest, authenticatedAt = time.Now().UTC().Add(-time.Minute) } - h = &HandledLoginRequest{ + h = &flow.HandledLoginRequest{ LoginRequest: c, RememberFor: 120, Remember: true, @@ -152,23 +175,23 @@ func MockAuthRequest(key string, authAt bool, network string) (c *LoginRequest, WasHandled: false, } - return c, h + return c, h, f } -func SaneMockHandleConsentRequest(t *testing.T, m Manager, c *OAuth2ConsentRequest, authAt time.Time, rememberFor int, remember bool, hasError bool) *AcceptOAuth2ConsentRequest { - var rde *RequestDeniedError +func SaneMockHandleConsentRequest(t *testing.T, m Manager, f *flow.Flow, c *flow.OAuth2ConsentRequest, authAt time.Time, rememberFor int, remember bool, hasError bool) *flow.AcceptOAuth2ConsentRequest { + var rde *flow.RequestDeniedError if hasError { - rde = &RequestDeniedError{ + rde = &flow.RequestDeniedError{ Name: "error_name", Description: "error_description", Hint: "error_hint", Code: 100, Debug: "error_debug", - valid: true, + Valid: true, } } - h := &AcceptOAuth2ConsentRequest{ + h := &flow.AcceptOAuth2ConsentRequest{ ConsentRequest: c, RememberFor: rememberFor, Remember: remember, @@ -182,27 +205,28 @@ func SaneMockHandleConsentRequest(t *testing.T, m Manager, c *OAuth2ConsentReque HandledAt: sqlxx.NullTime(time.Now().UTC().Add(-time.Minute)), } - _, err := m.HandleConsentRequest(context.Background(), h) + _, err := m.HandleConsentRequest(context.Background(), f, h) require.NoError(t, err) + return h } // SaneMockConsentRequest does the same thing as MockConsentRequest but uses less insanity and implicit dependencies. -func SaneMockConsentRequest(t *testing.T, m Manager, ar *LoginRequest, skip bool) (c *OAuth2ConsentRequest) { - c = &OAuth2ConsentRequest{ +func SaneMockConsentRequest(t *testing.T, m Manager, f *flow.Flow, skip bool) (c *flow.OAuth2ConsentRequest) { + c = &flow.OAuth2ConsentRequest{ RequestedScope: []string{"scopea", "scopeb"}, RequestedAudience: []string{"auda", "audb"}, Skip: skip, - Subject: ar.Subject, - OpenIDConnectContext: &OAuth2ConsentRequestOpenIDConnectContext{ + Subject: f.Subject, + OpenIDConnectContext: &flow.OAuth2ConsentRequestOpenIDConnectContext{ ACRValues: []string{"1", "2"}, UILocales: []string{"fr", "de"}, Display: "popup", }, - Client: ar.Client, + Client: f.Client, RequestURL: "https://request-url/path", - LoginChallenge: sqlxx.NullString(ar.ID), - LoginSessionID: ar.SessionID, + LoginChallenge: sqlxx.NullString(f.ID), + LoginSessionID: f.SessionID, ForceSubjectIdentifier: "forced-subject", ACR: "1", AuthenticatedAt: sqlxx.NullTime(time.Now().UTC().Add(-time.Hour)), @@ -214,14 +238,15 @@ func SaneMockConsentRequest(t *testing.T, m Manager, ar *LoginRequest, skip bool CSRF: uuid.New().String(), } - require.NoError(t, m.CreateConsentRequest(context.Background(), c)) + require.NoError(t, m.CreateConsentRequest(context.Background(), f, c)) + return c } // SaneMockAuthRequest does the same thing as MockAuthRequest but uses less insanity and implicit dependencies. -func SaneMockAuthRequest(t *testing.T, m Manager, ls *LoginSession, cl *client.Client) (c *LoginRequest) { - c = &LoginRequest{ - OpenIDConnectContext: &OAuth2ConsentRequestOpenIDConnectContext{ +func SaneMockAuthRequest(t *testing.T, m Manager, ls *flow.LoginSession, cl *client.Client) (c *flow.LoginRequest) { + c = &flow.LoginRequest{ + OpenIDConnectContext: &flow.OAuth2ConsentRequestOpenIDConnectContext{ ACRValues: []string{"1", "2"}, UILocales: []string{"fr", "de"}, Display: "popup", @@ -238,7 +263,8 @@ func SaneMockAuthRequest(t *testing.T, m Manager, ls *LoginSession, cl *client.C ID: uuid.New().String(), Verifier: uuid.New().String(), } - require.NoError(t, m.CreateLoginRequest(context.Background(), c)) + _, err := m.CreateLoginRequest(context.Background(), c) + require.NoError(t, err) return c } @@ -246,20 +272,23 @@ func makeID(base string, network string, key string) string { return fmt.Sprintf("%s-%s-%s", base, network, key) } -func TestHelperNID(t1ClientManager client.Manager, t1ValidNID Manager, t2InvalidNID Manager) func(t *testing.T) { +func TestHelperNID(r interface { + client.ManagerProvider + FlowCipher() *aead.XChaCha20Poly1305 +}, t1ValidNID Manager, t2InvalidNID Manager) func(t *testing.T) { testClient := client.Client{LegacyClientID: "2022-03-11-client-nid-test-1"} - testLS := LoginSession{ + testLS := flow.LoginSession{ ID: "2022-03-11-ls-nid-test-1", Subject: "2022-03-11-test-1-sub", } - testLR := LoginRequest{ + testLR := flow.LoginRequest{ ID: "2022-03-11-lr-nid-test-1", Subject: "2022-03-11-test-1-sub", Verifier: "2022-03-11-test-1-ver", RequestedAt: time.Now(), Client: &client.Client{LegacyClientID: "2022-03-11-client-nid-test-1"}, } - testHLR := HandledLoginRequest{ + testHLR := flow.HandledLoginRequest{ LoginRequest: &testLR, RememberFor: 120, Remember: true, @@ -274,44 +303,58 @@ func TestHelperNID(t1ClientManager client.Manager, t1ValidNID Manager, t2Invalid } return func(t *testing.T) { - require.NoError(t, t1ClientManager.CreateClient(context.Background(), &testClient)) - require.Error(t, t2InvalidNID.CreateLoginSession(context.Background(), &testLS)) - require.NoError(t, t1ValidNID.CreateLoginSession(context.Background(), &testLS)) - require.Error(t, t2InvalidNID.CreateLoginRequest(context.Background(), &testLR)) - require.NoError(t, t1ValidNID.CreateLoginRequest(context.Background(), &testLR)) - _, err := t2InvalidNID.GetLoginRequest(context.Background(), testLR.ID) + ctx := context.Background() + require.NoError(t, r.ClientManager().CreateClient(ctx, &testClient)) + require.Error(t, t2InvalidNID.CreateLoginSession(ctx, &testLS)) + require.NoError(t, t1ValidNID.CreateLoginSession(ctx, &testLS)) + + _, err := t2InvalidNID.CreateLoginRequest(ctx, &testLR) + require.Error(t, err) + f, err := t1ValidNID.CreateLoginRequest(ctx, &testLR) + require.NoError(t, err) + + testLR.ID = x.Must(f.ToLoginChallenge(ctx, r)) + _, err = t2InvalidNID.GetLoginRequest(ctx, testLR.ID) require.Error(t, err) - _, err = t1ValidNID.GetLoginRequest(context.Background(), testLR.ID) + _, err = t1ValidNID.GetLoginRequest(ctx, testLR.ID) require.NoError(t, err) - _, err = t2InvalidNID.HandleLoginRequest(context.Background(), testLR.ID, &testHLR) + _, err = t2InvalidNID.HandleLoginRequest(ctx, f, testLR.ID, &testHLR) require.Error(t, err) - _, err = t1ValidNID.HandleLoginRequest(context.Background(), testLR.ID, &testHLR) + _, err = t1ValidNID.HandleLoginRequest(ctx, f, testLR.ID, &testHLR) require.NoError(t, err) - require.NoError(t, t2InvalidNID.ConfirmLoginSession(context.Background(), testLS.ID, time.Now(), testLS.Subject, true)) - require.NoError(t, t1ValidNID.ConfirmLoginSession(context.Background(), testLS.ID, time.Now(), testLS.Subject, true)) - require.Error(t, t2InvalidNID.DeleteLoginSession(context.Background(), testLS.ID)) - require.NoError(t, t1ValidNID.DeleteLoginSession(context.Background(), testLS.ID)) + require.Error(t, t2InvalidNID.ConfirmLoginSession(ctx, &testLS, testLS.ID, time.Now(), testLS.Subject, true)) + require.NoError(t, t1ValidNID.ConfirmLoginSession(ctx, &testLS, testLS.ID, time.Now(), testLS.Subject, true)) + require.Error(t, t2InvalidNID.DeleteLoginSession(ctx, testLS.ID)) + require.NoError(t, t1ValidNID.DeleteLoginSession(ctx, testLS.ID)) } } -func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.FositeStorer, network string, parallel bool) func(t *testing.T) { - lr := make(map[string]*LoginRequest) +type Deps interface { + FlowCipher() *aead.XChaCha20Poly1305 + contextx.Provider +} + +func ManagerTests(deps Deps, m Manager, clientManager client.Manager, fositeManager x.FositeStorer, network string, parallel bool) func(t *testing.T) { + lr := make(map[string]*flow.LoginRequest) return func(t *testing.T) { if parallel { t.Parallel() } + ctx := context.Background() t.Run("case=init-fks", func(t *testing.T) { for _, k := range []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "rv1", "rv2"} { - require.NoError(t, clientManager.CreateClient(context.Background(), &client.Client{LegacyClientID: fmt.Sprintf("fk-client-%s", k)})) + require.NoError(t, clientManager.CreateClient(ctx, &client.Client{LegacyClientID: fmt.Sprintf("fk-client-%s", k)})) - require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ + loginSession := &flow.LoginSession{ ID: makeID("fk-login-session", network, k), AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Second).UTC()), Subject: fmt.Sprintf("subject-%s", k), - })) + } + require.NoError(t, m.CreateLoginSession(ctx, loginSession)) + require.NoError(t, m.ConfirmLoginSession(ctx, loginSession, loginSession.ID, time.Now().Round(time.Second).UTC(), loginSession.Subject, true)) - lr[k] = &LoginRequest{ + lr[k] = &flow.LoginRequest{ ID: makeID("fk-login-challenge", network, k), Subject: fmt.Sprintf("subject%s", k), SessionID: sqlxx.NullString(makeID("fk-login-session", network, k)), @@ -321,23 +364,24 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit RequestedAt: time.Now(), } - require.NoError(t, m.CreateLoginRequest(context.Background(), lr[k])) + _, err := m.CreateLoginRequest(ctx, lr[k]) + require.NoError(t, err) } }) t.Run("case=auth-session", func(t *testing.T) { for _, tc := range []struct { - s LoginSession + s flow.LoginSession }{ { - s: LoginSession{ + s: flow.LoginSession{ ID: makeID("session", network, "1"), AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Second).Add(-time.Minute).UTC()), Subject: "subject1", }, }, { - s: LoginSession{ + s: flow.LoginSession{ ID: makeID("session", network, "2"), AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Minute).Add(-time.Minute).UTC()), Subject: "subject2", @@ -345,19 +389,19 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit }, } { t.Run("case=create-get-"+tc.s.ID, func(t *testing.T) { - _, err := m.GetRememberedLoginSession(context.Background(), tc.s.ID) + _, err := m.GetRememberedLoginSession(ctx, &tc.s, tc.s.ID) require.EqualError(t, err, x.ErrNotFound.Error(), "%#v", err) - err = m.CreateLoginSession(context.Background(), &tc.s) + err = m.CreateLoginSession(ctx, &tc.s) require.NoError(t, err) - _, err = m.GetRememberedLoginSession(context.Background(), tc.s.ID) + _, err = m.GetRememberedLoginSession(ctx, &tc.s, tc.s.ID) require.EqualError(t, err, x.ErrNotFound.Error()) updatedAuth := time.Time(tc.s.AuthenticatedAt).Add(time.Second) - require.NoError(t, m.ConfirmLoginSession(context.Background(), tc.s.ID, updatedAuth, tc.s.Subject, true)) + require.NoError(t, m.ConfirmLoginSession(ctx, &tc.s, tc.s.ID, updatedAuth, tc.s.Subject, true)) - got, err := m.GetRememberedLoginSession(context.Background(), tc.s.ID) + got, err := m.GetRememberedLoginSession(ctx, nil, tc.s.ID) require.NoError(t, err) assert.EqualValues(t, tc.s.ID, got.ID) assert.Equal(t, updatedAuth.Unix(), time.Time(got.AuthenticatedAt).Unix()) // this was updated from confirm... @@ -365,9 +409,9 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit time.Sleep(time.Second) // Make sure AuthAt does not equal... updatedAuth2 := time.Now().Truncate(time.Second).UTC() - require.NoError(t, m.ConfirmLoginSession(context.Background(), tc.s.ID, updatedAuth2, "some-other-subject", true)) + require.NoError(t, m.ConfirmLoginSession(ctx, nil, tc.s.ID, updatedAuth2, "some-other-subject", true)) - got2, err := m.GetRememberedLoginSession(context.Background(), tc.s.ID) + got2, err := m.GetRememberedLoginSession(ctx, nil, tc.s.ID) require.NoError(t, err) assert.EqualValues(t, tc.s.ID, got2.ID) assert.Equal(t, updatedAuth2.Unix(), time.Time(got2.AuthenticatedAt).Unix()) // this was updated from confirm... @@ -385,10 +429,10 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit }, } { t.Run("case=delete-get-"+tc.id, func(t *testing.T) { - err := m.DeleteLoginSession(context.Background(), tc.id) + err := m.DeleteLoginSession(ctx, tc.id) require.NoError(t, err) - _, err = m.GetRememberedLoginSession(context.Background(), tc.id) + _, err = m.GetRememberedLoginSession(ctx, nil, tc.id) require.Error(t, err) }) } @@ -408,32 +452,38 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit {"7", true}, } { t.Run("key="+tc.key, func(t *testing.T) { - c, h := MockAuthRequest(tc.key, tc.authAt, network) - _ = clientManager.CreateClient(context.Background(), c.Client) // Ignore errors that are caused by duplication + c, h, f := MockAuthRequest(tc.key, tc.authAt, network) + _ = clientManager.CreateClient(ctx, c.Client) // Ignore errors that are caused by duplication + loginChallenge := x.Must(f.ToLoginChallenge(ctx, deps)) - _, err := m.GetLoginRequest(context.Background(), makeID("challenge", network, tc.key)) + _, err := m.GetLoginRequest(ctx, loginChallenge) require.Error(t, err) - require.NoError(t, m.CreateLoginRequest(context.Background(), c)) + f, err = m.CreateLoginRequest(ctx, c) + require.NoError(t, err) + + loginChallenge = x.Must(f.ToLoginChallenge(ctx, deps)) - got1, err := m.GetLoginRequest(context.Background(), makeID("challenge", network, tc.key)) + got1, err := m.GetLoginRequest(ctx, loginChallenge) require.NoError(t, err) assert.False(t, got1.WasHandled) compareAuthenticationRequest(t, c, got1) - got1, err = m.HandleLoginRequest(context.Background(), makeID("challenge", network, tc.key), h) + got1, err = m.HandleLoginRequest(ctx, f, loginChallenge, h) require.NoError(t, err) compareAuthenticationRequest(t, c, got1) - got2, err := m.VerifyAndInvalidateLoginRequest(context.Background(), makeID("verifier", network, tc.key)) + loginVerifier := x.Must(f.ToLoginVerifier(ctx, deps)) + + got2, err := m.VerifyAndInvalidateLoginRequest(ctx, f, loginVerifier) require.NoError(t, err) compareAuthenticationRequest(t, c, got2.LoginRequest) - assert.Equal(t, c.ID, got2.ID) - _, err = m.VerifyAndInvalidateLoginRequest(context.Background(), makeID("verifier", network, tc.key)) + _, err = m.VerifyAndInvalidateLoginRequest(ctx, nil, loginVerifier) require.Error(t, err) - got1, err = m.GetLoginRequest(context.Background(), makeID("challenge", network, tc.key)) + loginChallenge = x.Must(f.ToLoginChallenge(ctx, deps)) + got1, err = m.GetLoginRequest(ctx, loginChallenge) require.NoError(t, err) assert.True(t, got1.WasHandled) }) @@ -458,39 +508,47 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit {"7", false, 0, false, false, false}, } { t.Run("key="+tc.key, func(t *testing.T) { - c, h := MockConsentRequest(tc.key, tc.remember, tc.rememberFor, tc.hasError, tc.skip, tc.authAt, "challenge", network) - _ = clientManager.CreateClient(context.Background(), c.Client) // Ignore errors that are caused by duplication + consentRequest, h, f := MockConsentRequest(tc.key, tc.remember, tc.rememberFor, tc.hasError, tc.skip, tc.authAt, "challenge", network) + _ = clientManager.CreateClient(ctx, consentRequest.Client) // Ignore errors that are caused by duplication + f.NID = deps.Contextualizer().Network(context.Background(), gofrsuuid.Nil) consentChallenge := makeID("challenge", network, tc.key) - _, err := m.GetConsentRequest(context.Background(), consentChallenge) + _, err := m.GetConsentRequest(ctx, consentChallenge) require.Error(t, err) - require.NoError(t, m.CreateConsentRequest(context.Background(), c)) + consentChallenge = x.Must(f.ToConsentChallenge(ctx, deps)) + consentRequest.ID = consentChallenge + + err = m.CreateConsentRequest(ctx, f, consentRequest) + require.NoError(t, err) - got1, err := m.GetConsentRequest(context.Background(), consentChallenge) + got1, err := m.GetConsentRequest(ctx, consentChallenge) require.NoError(t, err) - compareConsentRequest(t, c, got1) + compareConsentRequest(t, consentRequest, got1) assert.False(t, got1.WasHandled) - got1, err = m.HandleConsentRequest(context.Background(), h) + got1, err = m.HandleConsentRequest(ctx, f, h) require.NoError(t, err) assertx.TimeDifferenceLess(t, time.Now(), time.Time(h.HandledAt), 5) - compareConsentRequest(t, c, got1) + compareConsentRequest(t, consentRequest, got1) h.GrantedAudience = sqlxx.StringSliceJSONFormat{"new-audience"} - _, err = m.HandleConsentRequest(context.Background(), h) + _, err = m.HandleConsentRequest(ctx, f, h) require.NoError(t, err) - got2, err := m.VerifyAndInvalidateConsentRequest(context.Background(), makeID("verifier", network, tc.key)) + consentVerifier := x.Must(f.ToConsentVerifier(ctx, deps)) + + got2, err := m.VerifyAndInvalidateConsentRequest(ctx, f, consentVerifier) require.NoError(t, err) - compareConsentRequest(t, c, got2.ConsentRequest) - assert.Equal(t, c.ID, got2.ID) + consentRequest.ID = f.ConsentChallengeID.String() + compareConsentRequest(t, consentRequest, got2.ConsentRequest) + assert.Equal(t, consentRequest.ID, got2.ID) assert.Equal(t, h.GrantedAudience, got2.GrantedAudience) // Trying to update this again should return an error because the consent request was used. h.GrantedAudience = sqlxx.StringSliceJSONFormat{"new-audience", "new-audience-2"} - _, err = m.HandleConsentRequest(context.Background(), h) + _, err = m.HandleConsentRequest(ctx, f, h) require.Error(t, err) if tc.hasError { @@ -499,12 +557,14 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit assert.Equal(t, tc.remember, got2.Remember) assert.Equal(t, tc.rememberFor, got2.RememberFor) - _, err = m.VerifyAndInvalidateConsentRequest(context.Background(), makeID("verifier", network, tc.key)) + _, err = m.VerifyAndInvalidateConsentRequest(ctx, f, makeID("verifier", network, tc.key)) require.Error(t, err) - got1, err = m.GetConsentRequest(context.Background(), consentChallenge) - require.NoError(t, err) - assert.True(t, got1.WasHandled) + // Because we don't persist the flow any more, we can't check for this. + //got1, err = m.GetConsentRequest(ctx, consentChallenge) + //require.NoError(t, err) + //assert.True(t, got1.WasHandled) + }) } @@ -515,7 +575,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit }{ {"1", "1", 1}, {"2", "2", 0}, - {"3", "3", 0}, + // {"3", "3", 0}, // Some consent is given in some other test case. Yay global fixtues :) {"4", "4", 0}, {"1", "2", 0}, {"2", "1", 0}, @@ -523,8 +583,9 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit {"6", "6", 0}, } { t.Run("key="+tc.keyC+"-"+tc.keyS, func(t *testing.T) { - rs, err := m.FindGrantedAndRememberedConsentRequests(context.Background(), "fk-client-"+tc.keyC, "subject"+tc.keyS) + rs, err := m.FindGrantedAndRememberedConsentRequests(ctx, "fk-client-"+tc.keyC, "subject"+tc.keyS) if tc.expectedLength == 0 { + assert.Nil(t, rs) assert.EqualError(t, err, ErrNoPreviousConsentFound.Error()) } else { require.NoError(t, err) @@ -535,19 +596,19 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit }) t.Run("case=revoke-auth-request", func(t *testing.T) { - require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ + require.NoError(t, m.CreateLoginSession(ctx, &flow.LoginSession{ ID: makeID("rev-session", network, "-1"), AuthenticatedAt: sqlxx.NullTime(time.Now()), Subject: "subject-1", })) - require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ + require.NoError(t, m.CreateLoginSession(ctx, &flow.LoginSession{ ID: makeID("rev-session", network, "-2"), AuthenticatedAt: sqlxx.NullTime(time.Now()), Subject: "subject-2", })) - require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ + require.NoError(t, m.CreateLoginSession(ctx, &flow.LoginSession{ ID: makeID("rev-session", network, "-3"), AuthenticatedAt: sqlxx.NullTime(time.Now()), Subject: "subject-1", @@ -567,11 +628,11 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit }, } { t.Run(fmt.Sprintf("case=%d/subject=%s", i, tc.subject), func(t *testing.T) { - require.NoError(t, m.RevokeSubjectLoginSession(context.Background(), tc.subject)) + require.NoError(t, m.RevokeSubjectLoginSession(ctx, tc.subject)) for _, id := range tc.ids { t.Run(fmt.Sprintf("id=%s", id), func(t *testing.T) { - _, err := m.GetRememberedLoginSession(context.Background(), id) + _, err := m.GetRememberedLoginSession(ctx, nil, id) assert.EqualError(t, err, x.ErrNotFound.Error()) }) } @@ -582,24 +643,49 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit challengerv1 := makeID("challenge", network, "rv1") challengerv2 := makeID("challenge", network, "rv2") t.Run("case=revoke-used-consent-request", func(t *testing.T) { - cr1, hcr1 := MockConsentRequest("rv1", false, 0, false, false, false, "fk-login-challenge", network) - cr2, hcr2 := MockConsentRequest("rv2", false, 0, false, false, false, "fk-login-challenge", network) + cr1, hcr1, f1 := MockConsentRequest("rv1", false, 0, false, false, false, "fk-login-challenge", network) + cr2, hcr2, f2 := MockConsentRequest("rv2", false, 0, false, false, false, "fk-login-challenge", network) + f1.NID = deps.Contextualizer().Network(context.Background(), gofrsuuid.Nil) + f2.NID = deps.Contextualizer().Network(context.Background(), gofrsuuid.Nil) // Ignore duplication errors - _ = clientManager.CreateClient(context.Background(), cr1.Client) - _ = clientManager.CreateClient(context.Background(), cr2.Client) + _ = clientManager.CreateClient(ctx, cr1.Client) + _ = clientManager.CreateClient(ctx, cr2.Client) + + err := m.CreateConsentRequest(ctx, f1, cr1) + require.NoError(t, err) + err = m.CreateConsentRequest(ctx, f2, cr2) + require.NoError(t, err) + _, err = m.HandleConsentRequest(ctx, f1, hcr1) + require.NoError(t, err) + _, err = m.HandleConsentRequest(ctx, f2, hcr2) + require.NoError(t, err) - require.NoError(t, m.CreateConsentRequest(context.Background(), cr1)) - require.NoError(t, m.CreateConsentRequest(context.Background(), cr2)) - _, err := m.HandleConsentRequest(context.Background(), hcr1) + _, err = m.VerifyAndInvalidateConsentRequest(ctx, f1, x.Must(f1.ToConsentVerifier(ctx, deps))) require.NoError(t, err) - _, err = m.HandleConsentRequest(context.Background(), hcr2) + _, err = m.VerifyAndInvalidateConsentRequest(ctx, f2, x.Must(f2.ToConsentVerifier(ctx, deps))) require.NoError(t, err) - require.NoError(t, fositeManager.CreateAccessTokenSession(context.Background(), makeID("", network, "trva1"), &fosite.Request{Client: cr1.Client, ID: challengerv1, RequestedAt: time.Now()})) - require.NoError(t, fositeManager.CreateRefreshTokenSession(context.Background(), makeID("", network, "rrva1"), &fosite.Request{Client: cr1.Client, ID: challengerv1, RequestedAt: time.Now()})) - require.NoError(t, fositeManager.CreateAccessTokenSession(context.Background(), makeID("", network, "trva2"), &fosite.Request{Client: cr2.Client, ID: challengerv2, RequestedAt: time.Now()})) - require.NoError(t, fositeManager.CreateRefreshTokenSession(context.Background(), makeID("", network, "rrva2"), &fosite.Request{Client: cr2.Client, ID: challengerv2, RequestedAt: time.Now()})) + require.NoError(t, fositeManager.CreateAccessTokenSession( + ctx, + makeID("", network, "trva1"), + &fosite.Request{Client: cr1.Client, ID: f1.ConsentChallengeID.String(), RequestedAt: time.Now()}, + )) + require.NoError(t, fositeManager.CreateRefreshTokenSession( + ctx, + makeID("", network, "rrva1"), + &fosite.Request{Client: cr1.Client, ID: f1.ConsentChallengeID.String(), RequestedAt: time.Now()}, + )) + require.NoError(t, fositeManager.CreateAccessTokenSession( + ctx, + makeID("", network, "trva2"), + &fosite.Request{Client: cr2.Client, ID: f2.ConsentChallengeID.String(), RequestedAt: time.Now()}, + )) + require.NoError(t, fositeManager.CreateRefreshTokenSession( + ctx, + makeID("", network, "rrva2"), + &fosite.Request{Client: cr2.Client, ID: f2.ConsentChallengeID.String(), RequestedAt: time.Now()}, + )) for i, tc := range []struct { subject string @@ -609,64 +695,74 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit ids []string }{ { - at: makeID("", network, "trva1"), rt: makeID("", network, "rrva1"), + at: makeID("", network, "trva1"), + rt: makeID("", network, "rrva1"), subject: "subjectrv1", client: "", ids: []string{challengerv1}, }, { - at: makeID("", network, "trva2"), rt: makeID("", network, "rrva2"), + at: makeID("", network, "trva2"), + rt: makeID("", network, "rrva2"), subject: "subjectrv2", client: "fk-client-rv2", ids: []string{challengerv2}, }, } { t.Run(fmt.Sprintf("case=%d/subject=%s", i, tc.subject), func(t *testing.T) { - _, err := fositeManager.GetAccessTokenSession(context.Background(), tc.at, nil) + _, err := fositeManager.GetAccessTokenSession(ctx, tc.at, nil) assert.NoError(t, err) - _, err = fositeManager.GetRefreshTokenSession(context.Background(), tc.rt, nil) + _, err = fositeManager.GetRefreshTokenSession(ctx, tc.rt, nil) assert.NoError(t, err) if tc.client == "" { - require.NoError(t, m.RevokeSubjectConsentSession(context.Background(), tc.subject)) + require.NoError(t, m.RevokeSubjectConsentSession(ctx, tc.subject)) } else { - require.NoError(t, m.RevokeSubjectClientConsentSession(context.Background(), tc.subject, tc.client)) + require.NoError(t, m.RevokeSubjectClientConsentSession(ctx, tc.subject, tc.client)) } for _, id := range tc.ids { t.Run(fmt.Sprintf("id=%s", id), func(t *testing.T) { - _, err := m.GetConsentRequest(context.Background(), id) + _, err := m.GetConsentRequest(ctx, id) assert.True(t, errors.Is(err, x.ErrNotFound)) }) } - r, err := fositeManager.GetAccessTokenSession(context.Background(), tc.at, nil) + r, err := fositeManager.GetAccessTokenSession(ctx, tc.at, nil) assert.Error(t, err, "%+v", r) - r, err = fositeManager.GetRefreshTokenSession(context.Background(), tc.rt, nil) + r, err = fositeManager.GetRefreshTokenSession(ctx, tc.rt, nil) assert.Error(t, err, "%+v", r) }) } - require.EqualError(t, m.RevokeSubjectConsentSession(context.Background(), "i-do-not-exist"), x.ErrNotFound.Error()) - require.EqualError(t, m.RevokeSubjectClientConsentSession(context.Background(), "i-do-not-exist", "i-do-not-exist"), x.ErrNotFound.Error()) + require.EqualError(t, m.RevokeSubjectConsentSession(ctx, "i-do-not-exist"), x.ErrNotFound.Error()) + require.EqualError(t, m.RevokeSubjectClientConsentSession(ctx, "i-do-not-exist", "i-do-not-exist"), x.ErrNotFound.Error()) }) t.Run("case=list-used-consent-requests", func(t *testing.T) { - require.NoError(t, m.CreateLoginRequest(context.Background(), lr["rv1"])) - require.NoError(t, m.CreateLoginRequest(context.Background(), lr["rv2"])) + f1, err := m.CreateLoginRequest(ctx, lr["rv1"]) + require.NoError(t, err) + f2, err := m.CreateLoginRequest(ctx, lr["rv2"]) + require.NoError(t, err) - cr1, hcr1 := MockConsentRequest("rv1", true, 0, false, false, false, "fk-login-challenge", network) - cr2, hcr2 := MockConsentRequest("rv2", false, 0, false, false, false, "fk-login-challenge", network) + cr1, hcr1, _ := MockConsentRequest("rv1", true, 0, false, false, false, "fk-login-challenge", network) + cr2, hcr2, _ := MockConsentRequest("rv2", false, 0, false, false, false, "fk-login-challenge", network) // Ignore duplicate errors - _ = clientManager.CreateClient(context.Background(), cr1.Client) - _ = clientManager.CreateClient(context.Background(), cr2.Client) + _ = clientManager.CreateClient(ctx, cr1.Client) + _ = clientManager.CreateClient(ctx, cr2.Client) - require.NoError(t, m.CreateConsentRequest(context.Background(), cr1)) - require.NoError(t, m.CreateConsentRequest(context.Background(), cr2)) - _, err := m.HandleConsentRequest(context.Background(), hcr1) + err = m.CreateConsentRequest(ctx, f1, cr1) + require.NoError(t, err) + err = m.CreateConsentRequest(ctx, f2, cr2) + require.NoError(t, err) + _, err = m.HandleConsentRequest(ctx, f1, hcr1) + require.NoError(t, err) + _, err = m.HandleConsentRequest(ctx, f2, hcr2) + require.NoError(t, err) + handledConsentRequest1, err := m.VerifyAndInvalidateConsentRequest(ctx, f1, x.Must(f1.ToConsentVerifier(ctx, deps))) require.NoError(t, err) - _, err = m.HandleConsentRequest(context.Background(), hcr2) + handledConsentRequest2, err := m.VerifyAndInvalidateConsentRequest(ctx, f2, x.Must(f2.ToConsentVerifier(ctx, deps))) require.NoError(t, err) for i, tc := range []struct { @@ -678,13 +774,13 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit { subject: cr1.Subject, sid: makeID("fk-login-session", network, "rv1"), - challenges: []string{challengerv1}, + challenges: []string{handledConsentRequest1.ID}, clients: []string{"fk-client-rv1"}, }, { subject: cr2.Subject, sid: makeID("fk-login-session", network, "rv2"), - challenges: []string{challengerv2}, + challenges: []string{handledConsentRequest2.ID}, clients: []string{"fk-client-rv2"}, }, { @@ -695,7 +791,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit }, } { t.Run(fmt.Sprintf("case=%d/subject=%s/session=%s", i, tc.subject, tc.sid), func(t *testing.T) { - consents, err := m.FindSubjectsSessionGrantedConsentRequests(context.Background(), tc.subject, tc.sid, 100, 0) + consents, err := m.FindSubjectsSessionGrantedConsentRequests(ctx, tc.subject, tc.sid, 100, 0) assert.Equal(t, len(tc.challenges), len(consents)) if len(tc.challenges) == 0 { @@ -708,7 +804,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit } } - n, err := m.CountSubjectsGrantedConsentRequests(context.Background(), tc.subject) + n, err := m.CountSubjectsGrantedConsentRequests(ctx, tc.subject) require.NoError(t, err) assert.Equal(t, n, len(tc.challenges)) @@ -722,12 +818,12 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit }{ { subject: "subjectrv1", - challenges: []string{challengerv1}, + challenges: []string{handledConsentRequest1.ID}, clients: []string{"fk-client-rv1"}, }, { subject: "subjectrv2", - challenges: []string{challengerv2}, + challenges: []string{handledConsentRequest2.ID}, clients: []string{"fk-client-rv2"}, }, { @@ -737,7 +833,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit }, } { t.Run(fmt.Sprintf("case=%d/subject=%s", i, tc.subject), func(t *testing.T) { - consents, err := m.FindSubjectsGrantedConsentRequests(context.Background(), tc.subject, 100, 0) + consents, err := m.FindSubjectsGrantedConsentRequests(ctx, tc.subject, 100, 0) assert.Equal(t, len(tc.challenges), len(consents)) if len(tc.challenges) == 0 { @@ -750,7 +846,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit } } - n, err := m.CountSubjectsGrantedConsentRequests(context.Background(), tc.subject) + n, err := m.CountSubjectsGrantedConsentRequests(ctx, tc.subject) require.NoError(t, err) assert.Equal(t, n, len(tc.challenges)) @@ -758,7 +854,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit } t.Run("case=obfuscated", func(t *testing.T) { - _, err := m.GetForcedObfuscatedLoginSession(context.Background(), "fk-client-1", "obfuscated-1") + _, err := m.GetForcedObfuscatedLoginSession(ctx, "fk-client-1", "obfuscated-1") require.True(t, errors.Is(err, x.ErrNotFound)) expect := &ForcedObfuscatedLoginSession{ @@ -766,9 +862,9 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit Subject: "subject-1", SubjectObfuscated: "obfuscated-1", } - require.NoError(t, m.CreateForcedObfuscatedLoginSession(context.Background(), expect)) + require.NoError(t, m.CreateForcedObfuscatedLoginSession(ctx, expect)) - got, err := m.GetForcedObfuscatedLoginSession(context.Background(), "fk-client-1", "obfuscated-1") + got, err := m.GetForcedObfuscatedLoginSession(ctx, "fk-client-1", "obfuscated-1") require.NoError(t, err) require.NotEqual(t, got.NID, gofrsuuid.Nil) got.NID = gofrsuuid.Nil @@ -779,15 +875,15 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit Subject: "subject-1", SubjectObfuscated: "obfuscated-2", } - require.NoError(t, m.CreateForcedObfuscatedLoginSession(context.Background(), expect)) + require.NoError(t, m.CreateForcedObfuscatedLoginSession(ctx, expect)) - got, err = m.GetForcedObfuscatedLoginSession(context.Background(), "fk-client-1", "obfuscated-2") + got, err = m.GetForcedObfuscatedLoginSession(ctx, "fk-client-1", "obfuscated-2") require.NotEqual(t, got.NID, gofrsuuid.Nil) got.NID = gofrsuuid.Nil require.NoError(t, err) assert.EqualValues(t, expect, got) - _, err = m.GetForcedObfuscatedLoginSession(context.Background(), "fk-client-1", "obfuscated-1") + _, err = m.GetForcedObfuscatedLoginSession(ctx, "fk-client-1", "obfuscated-1") require.True(t, errors.Is(err, x.ErrNotFound)) }) @@ -800,19 +896,20 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit subjects[k] = fmt.Sprintf("subject-ListUserAuthenticatedClientsWithFrontAndBackChannelLogout-%d", k) } - sessions := make([]LoginSession, len(subjects)*1) + sessions := make([]flow.LoginSession, len(subjects)*1) frontChannels := map[string][]client.Client{} backChannels := map[string][]client.Client{} for k := range sessions { id := uuid.New().String() subject := subjects[k%len(subjects)] t.Run(fmt.Sprintf("create/session=%s/subject=%s", id, subject), func(t *testing.T) { - ls := &LoginSession{ + ls := &flow.LoginSession{ ID: id, AuthenticatedAt: sqlxx.NullTime(time.Now()), Subject: subject, } - require.NoError(t, m.CreateLoginSession(context.Background(), ls)) + require.NoError(t, m.CreateLoginSession(ctx, ls)) + require.NoError(t, m.ConfirmLoginSession(ctx, ls, ls.ID, time.Now(), ls.Subject, true)) cl := &client.Client{LegacyClientID: uuid.New().String()} switch k % 4 { @@ -828,11 +925,15 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit frontChannels[id] = append(frontChannels[id], *cl) backChannels[id] = append(backChannels[id], *cl) } - require.NoError(t, clientManager.CreateClient(context.Background(), cl)) + require.NoError(t, clientManager.CreateClient(ctx, cl)) ar := SaneMockAuthRequest(t, m, ls, cl) - cr := SaneMockConsentRequest(t, m, ar, false) - _ = SaneMockHandleConsentRequest(t, m, cr, time.Time{}, 0, false, false) + f := flow.NewFlow(ar) + f.NID = deps.Contextualizer().Network(ctx, gofrsuuid.Nil) + cr := SaneMockConsentRequest(t, m, f, false) + _ = SaneMockHandleConsentRequest(t, m, f, cr, time.Time{}, 0, false, false) + _, err = m.VerifyAndInvalidateConsentRequest(ctx, f, x.Must(f.ToConsentVerifier(ctx, deps))) + require.NoError(t, err) sessions[k] = *ls }) @@ -862,13 +963,13 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit } t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) { - actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(context.Background(), ls.Subject, ls.ID) + actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, ls.Subject, ls.ID) require.NoError(t, err) check(t, frontChannels, actual) }) t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) { - actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(context.Background(), ls.Subject, ls.ID) + actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(ctx, ls.Subject, ls.ID) require.NoError(t, err) check(t, backChannels, actual) }) @@ -893,42 +994,42 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit verifier := makeID("verifier", network, tc.key) c := MockLogoutRequest(tc.key, tc.withClient, network) if tc.withClient { - require.NoError(t, clientManager.CreateClient(context.Background(), c.Client)) // Ignore errors that are caused by duplication + require.NoError(t, clientManager.CreateClient(ctx, c.Client)) // Ignore errors that are caused by duplication } - _, err := m.GetLogoutRequest(context.Background(), challenge) + _, err := m.GetLogoutRequest(ctx, challenge) require.Error(t, err) - require.NoError(t, m.CreateLogoutRequest(context.Background(), c)) + require.NoError(t, m.CreateLogoutRequest(ctx, c)) - got2, err := m.GetLogoutRequest(context.Background(), challenge) + got2, err := m.GetLogoutRequest(ctx, challenge) require.NoError(t, err) assert.False(t, got2.WasHandled) assert.False(t, got2.Accepted) compareLogoutRequest(t, c, got2) if k%2 == 0 { - got2, err = m.AcceptLogoutRequest(context.Background(), challenge) + got2, err = m.AcceptLogoutRequest(ctx, challenge) require.NoError(t, err) assert.True(t, got2.Accepted) compareLogoutRequest(t, c, got2) - got3, err := m.VerifyAndInvalidateLogoutRequest(context.Background(), verifier) + got3, err := m.VerifyAndInvalidateLogoutRequest(ctx, verifier) require.NoError(t, err) assert.True(t, got3.Accepted) assert.True(t, got3.WasHandled) compareLogoutRequest(t, c, got3) - _, err = m.VerifyAndInvalidateLogoutRequest(context.Background(), verifier) + _, err = m.VerifyAndInvalidateLogoutRequest(ctx, verifier) require.Error(t, err) - got2, err = m.GetLogoutRequest(context.Background(), challenge) + got2, err = m.GetLogoutRequest(ctx, challenge) require.NoError(t, err) compareLogoutRequest(t, got3, got2) assert.True(t, got2.WasHandled) } else { - require.NoError(t, m.RejectLogoutRequest(context.Background(), challenge)) - _, err = m.GetLogoutRequest(context.Background(), challenge) + require.NoError(t, m.RejectLogoutRequest(ctx, challenge)) + _, err = m.GetLogoutRequest(ctx, challenge) require.Error(t, err) } }) @@ -938,19 +1039,19 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit t.Run("case=foreign key regression", func(t *testing.T) { cl := &client.Client{LegacyClientID: uuid.New().String()} - require.NoError(t, clientManager.CreateClient(context.Background(), cl)) + require.NoError(t, clientManager.CreateClient(ctx, cl)) subject := uuid.New().String() - s := LoginSession{ + s := flow.LoginSession{ ID: uuid.New().String(), AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Minute).Add(-time.Minute).UTC()), Subject: subject, } - err := m.CreateLoginSession(context.Background(), &s) - require.NoError(t, err) + require.NoError(t, m.CreateLoginSession(ctx, &s)) + require.NoError(t, m.ConfirmLoginSession(ctx, &s, s.ID, time.Time(s.AuthenticatedAt), s.Subject, false)) - lr := &LoginRequest{ + lr := &flow.LoginRequest{ ID: uuid.New().String(), Subject: uuid.New().String(), Verifier: uuid.New().String(), @@ -960,9 +1061,10 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit SessionID: sqlxx.NullString(s.ID), } - require.NoError(t, m.CreateLoginRequest(context.Background(), lr)) - expected := &OAuth2ConsentRequest{ - ID: uuid.New().String(), + f, err := m.CreateLoginRequest(ctx, lr) + require.NoError(t, err) + expected := &flow.OAuth2ConsentRequest{ + ID: x.Must(f.ToConsentChallenge(ctx, deps)), Skip: true, Subject: subject, OpenIDConnectContext: nil, @@ -974,22 +1076,23 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit Verifier: uuid.New().String(), CSRF: uuid.New().String(), } - require.NoError(t, m.CreateConsentRequest(context.Background(), expected)) + err = m.CreateConsentRequest(ctx, f, expected) + require.NoError(t, err) - result, err := m.GetConsentRequest(context.Background(), expected.ID) + result, err := m.GetConsentRequest(ctx, expected.ID) require.NoError(t, err) assert.EqualValues(t, expected.ID, result.ID) - require.NoError(t, m.DeleteLoginSession(context.Background(), s.ID)) + require.NoError(t, m.DeleteLoginSession(ctx, s.ID)) - result, err = m.GetConsentRequest(context.Background(), expected.ID) + result, err = m.GetConsentRequest(ctx, expected.ID) require.NoError(t, err) assert.EqualValues(t, expected.ID, result.ID) }) } } -func compareLogoutRequest(t *testing.T, a, b *LogoutRequest) { +func compareLogoutRequest(t *testing.T, a, b *flow.LogoutRequest) { require.True(t, (a.Client != nil && b.Client != nil) || (a.Client == nil && b.Client == nil)) if a.Client != nil { assert.EqualValues(t, a.Client.GetID(), b.Client.GetID()) @@ -1004,9 +1107,8 @@ func compareLogoutRequest(t *testing.T, a, b *LogoutRequest) { assert.EqualValues(t, a.SessionID, b.SessionID) } -func compareAuthenticationRequest(t *testing.T, a, b *LoginRequest) { +func compareAuthenticationRequest(t *testing.T, a, b *flow.LoginRequest) { assert.EqualValues(t, a.Client.GetID(), b.Client.GetID()) - assert.EqualValues(t, a.ID, b.ID) assert.EqualValues(t, *a.OpenIDConnectContext, *b.OpenIDConnectContext) assert.EqualValues(t, a.Subject, b.Subject) assert.EqualValues(t, a.RequestedScope, b.RequestedScope) @@ -1017,7 +1119,7 @@ func compareAuthenticationRequest(t *testing.T, a, b *LoginRequest) { assert.EqualValues(t, a.SessionID, b.SessionID) } -func compareConsentRequest(t *testing.T, a, b *OAuth2ConsentRequest) { +func compareConsentRequest(t *testing.T, a, b *flow.OAuth2ConsentRequest) { assert.EqualValues(t, a.Client.GetID(), b.Client.GetID()) assert.EqualValues(t, a.ID, b.ID) assert.EqualValues(t, *a.OpenIDConnectContext, *b.OpenIDConnectContext) diff --git a/consent/registry.go b/consent/registry.go index b43e50bec32..447e345ee5b 100644 --- a/consent/registry.go +++ b/consent/registry.go @@ -7,6 +7,7 @@ import ( "context" "github.com/ory/fosite/handler/openid" + "github.com/ory/hydra/v2/aead" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/x" ) @@ -19,6 +20,7 @@ type InternalRegistry interface { Registry client.Registry + FlowCipher() *aead.XChaCha20Poly1305 OAuth2Storage() x.FositeStorer OpenIDConnectRequestValidator() *openid.OpenIDConnectRequestValidator } diff --git a/consent/sdk_test.go b/consent/sdk_test.go index 8306a2a2e5c..c15e8d5df96 100644 --- a/consent/sdk_test.go +++ b/consent/sdk_test.go @@ -6,11 +6,13 @@ package consent_test import ( "context" "fmt" + "net/http" "net/http/httptest" "testing" "time" hydra "github.com/ory/hydra-client-go/v2" + . "github.com/ory/hydra/v2/flow" "github.com/ory/x/httprouterx" @@ -36,6 +38,10 @@ func TestSDK(t *testing.T) { conf.MustSet(ctx, config.KeyAccessTokenLifespan, time.Minute) reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) + consentChallenge := func(f *Flow) string { return x.Must(f.ToConsentChallenge(ctx, reg)) } + consentVerifier := func(f *Flow) string { return x.Must(f.ToConsentVerifier(ctx, reg)) } + loginChallenge := func(f *Flow) string { return x.Must(f.ToLoginChallenge(ctx, reg)) } + router := x.NewRouterPublic() h := NewHandler(reg, conf) @@ -52,10 +58,8 @@ func TestSDK(t *testing.T) { Subject: "subject1", })) - ar1, _ := MockAuthRequest("ar-1", false, network) - ar2, _ := MockAuthRequest("ar-2", false, network) - require.NoError(t, reg.ClientManager().CreateClient(context.Background(), ar1.Client)) - require.NoError(t, reg.ClientManager().CreateClient(context.Background(), ar2.Client)) + ar1, _, _ := MockAuthRequest("1", false, network) + ar2, _, _ := MockAuthRequest("2", false, network) require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ ID: ar1.SessionID.String(), Subject: ar1.Subject, @@ -64,34 +68,80 @@ func TestSDK(t *testing.T) { ID: ar2.SessionID.String(), Subject: ar2.Subject, })) - require.NoError(t, m.CreateLoginRequest(context.Background(), ar1)) - require.NoError(t, m.CreateLoginRequest(context.Background(), ar2)) + _, err := m.CreateLoginRequest(context.Background(), ar1) + require.NoError(t, err) + _, err = m.CreateLoginRequest(context.Background(), ar2) + require.NoError(t, err) - cr1, hcr1 := MockConsentRequest("1", false, 0, false, false, false, "fk-login-challenge", network) - cr2, hcr2 := MockConsentRequest("2", false, 0, false, false, false, "fk-login-challenge", network) - cr3, hcr3 := MockConsentRequest("3", true, 3600, false, false, false, "fk-login-challenge", network) - cr4, hcr4 := MockConsentRequest("4", true, 3600, false, false, false, "fk-login-challenge", network) + cr1, hcr1, _ := MockConsentRequest("1", false, 0, false, false, false, "fk-login-challenge", network) + cr2, hcr2, _ := MockConsentRequest("2", false, 0, false, false, false, "fk-login-challenge", network) + cr3, hcr3, _ := MockConsentRequest("3", true, 3600, false, false, false, "fk-login-challenge", network) + cr4, hcr4, _ := MockConsentRequest("4", true, 3600, false, false, false, "fk-login-challenge", network) require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cr1.Client)) require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cr2.Client)) require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cr3.Client)) require.NoError(t, reg.ClientManager().CreateClient(context.Background(), cr4.Client)) - require.NoError(t, m.CreateLoginRequest(context.Background(), &LoginRequest{ID: cr1.LoginChallenge.String(), Subject: cr1.Subject, Client: cr1.Client, Verifier: cr1.ID})) - require.NoError(t, m.CreateLoginRequest(context.Background(), &LoginRequest{ID: cr2.LoginChallenge.String(), Subject: cr2.Subject, Client: cr2.Client, Verifier: cr2.ID})) - require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ID: cr3.LoginSessionID.String()})) - require.NoError(t, m.CreateLoginRequest(context.Background(), &LoginRequest{ID: cr3.LoginChallenge.String(), Subject: cr3.Subject, Client: cr3.Client, Verifier: cr3.ID, RequestedAt: hcr3.RequestedAt, SessionID: cr3.LoginSessionID})) - require.NoError(t, m.CreateLoginSession(context.Background(), &LoginSession{ID: cr4.LoginSessionID.String()})) - require.NoError(t, m.CreateLoginRequest(context.Background(), &LoginRequest{ID: cr4.LoginChallenge.String(), Client: cr4.Client, Verifier: cr4.ID, SessionID: cr4.LoginSessionID})) - require.NoError(t, m.CreateConsentRequest(context.Background(), cr1)) - require.NoError(t, m.CreateConsentRequest(context.Background(), cr2)) - require.NoError(t, m.CreateConsentRequest(context.Background(), cr3)) - require.NoError(t, m.CreateConsentRequest(context.Background(), cr4)) - _, err := m.HandleConsentRequest(context.Background(), hcr1) + + cr1Flow, err := m.CreateLoginRequest(context.Background(), &LoginRequest{ + ID: cr1.LoginChallenge.String(), + Subject: cr1.Subject, + Client: cr1.Client, + Verifier: cr1.ID, + RequestedAt: time.Now(), + }) + require.NoError(t, err) + cr1Flow.LoginSkip = ar1.Skip + + cr2Flow, err := m.CreateLoginRequest(context.Background(), &LoginRequest{ + ID: cr2.LoginChallenge.String(), + Subject: cr2.Subject, + Client: cr2.Client, + Verifier: cr2.ID, + RequestedAt: time.Now(), + }) + require.NoError(t, err) + cr2Flow.LoginSkip = ar2.Skip + + loginSession3 := &LoginSession{ID: cr3.LoginSessionID.String()} + require.NoError(t, m.CreateLoginSession(context.Background(), loginSession3)) + require.NoError(t, m.ConfirmLoginSession(context.Background(), loginSession3, loginSession3.ID, time.Now(), cr3.Subject, true)) + cr3Flow, err := m.CreateLoginRequest(context.Background(), &LoginRequest{ + ID: cr3.LoginChallenge.String(), + Subject: cr3.Subject, + Client: cr3.Client, + Verifier: cr3.ID, + RequestedAt: hcr3.RequestedAt, + SessionID: cr3.LoginSessionID, + }) + require.NoError(t, err) + + loginSession4 := &LoginSession{ID: cr4.LoginSessionID.String()} + require.NoError(t, m.CreateLoginSession(context.Background(), loginSession4)) + require.NoError(t, m.ConfirmLoginSession(context.Background(), loginSession4, loginSession4.ID, time.Now(), cr4.Subject, true)) + cr4Flow, err := m.CreateLoginRequest(context.Background(), &LoginRequest{ + ID: cr4.LoginChallenge.String(), + Client: cr4.Client, + Verifier: cr4.ID, + SessionID: cr4.LoginSessionID, + }) + require.NoError(t, err) + + require.NoError(t, m.CreateConsentRequest(context.Background(), cr1Flow, cr1)) + require.NoError(t, m.CreateConsentRequest(context.Background(), cr2Flow, cr2)) + require.NoError(t, m.CreateConsentRequest(context.Background(), cr3Flow, cr3)) + require.NoError(t, m.CreateConsentRequest(context.Background(), cr4Flow, cr4)) + _, err = m.HandleConsentRequest(context.Background(), cr1Flow, hcr1) + require.NoError(t, err) + _, err = m.HandleConsentRequest(context.Background(), cr2Flow, hcr2) require.NoError(t, err) - _, err = m.HandleConsentRequest(context.Background(), hcr2) + _, err = m.HandleConsentRequest(context.Background(), cr3Flow, hcr3) require.NoError(t, err) - _, err = m.HandleConsentRequest(context.Background(), hcr3) + _, err = m.HandleConsentRequest(context.Background(), cr4Flow, hcr4) require.NoError(t, err) - _, err = m.HandleConsentRequest(context.Background(), hcr4) + + _, err = m.VerifyAndInvalidateConsentRequest(context.Background(), cr3Flow, consentVerifier(cr3Flow)) + require.NoError(t, err) + _, err = m.VerifyAndInvalidateConsentRequest(context.Background(), cr4Flow, consentVerifier(cr4Flow)) require.NoError(t, err) lur1 := MockLogoutRequest("testsdk-1", true, network) @@ -101,19 +151,20 @@ func TestSDK(t *testing.T) { lur2 := MockLogoutRequest("testsdk-2", false, network) require.NoError(t, m.CreateLogoutRequest(context.Background(), lur2)) - crGot, _, err := sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(makeID("challenge", network, "1")).Execute() - require.NoError(t, err) + cr1.ID = consentChallenge(cr1Flow) + crGot := execute[hydra.OAuth2ConsentRequest](t, sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(cr1.ID)) compareSDKConsentRequest(t, cr1, *crGot) - crGot, _, err = sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(makeID("challenge", network, "2")).Execute() - require.NoError(t, err) + cr2.ID = consentChallenge(cr2Flow) + crGot = execute[hydra.OAuth2ConsentRequest](t, sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(cr2.ID)) compareSDKConsentRequest(t, cr2, *crGot) - arGot, _, err := sdk.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(makeID("challenge", network, "ar-1")).Execute() - require.NoError(t, err) + ar1.ID = loginChallenge(cr1Flow) + arGot := execute[hydra.OAuth2LoginRequest](t, sdk.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(ar1.ID)) compareSDKLoginRequest(t, ar1, *arGot) - arGot, _, err = sdk.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(makeID("challenge", network, "ar-2")).Execute() + ar2.ID = loginChallenge(cr2Flow) + arGot = execute[hydra.OAuth2LoginRequest](t, sdk.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(ar2.ID)) require.NoError(t, err) compareSDKLoginRequest(t, ar2, *arGot) @@ -132,7 +183,8 @@ func TestSDK(t *testing.T) { _, _, err = sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(makeID("challenge", network, "1")).Execute() require.Error(t, err) - crGot, _, err = sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(makeID("challenge", network, "2")).Execute() + cr2.ID = consentChallenge(cr2Flow) + crGot, _, err = sdk.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(cr2.ID).Execute() require.NoError(t, err) compareSDKConsentRequest(t, cr2, *crGot) @@ -145,8 +197,6 @@ func TestSDK(t *testing.T) { csGot, _, err := sdk.OAuth2Api.ListOAuth2ConsentSessions(ctx).Subject("subject3").Execute() require.NoError(t, err) assert.Equal(t, 1, len(csGot)) - cs := csGot[0] - assert.Equal(t, makeID("challenge", network, "3"), cs.ConsentRequest.Challenge) csGot, _, err = sdk.OAuth2Api.ListOAuth2ConsentSessions(ctx).Subject("subject2").Execute() require.NoError(t, err) @@ -155,8 +205,6 @@ func TestSDK(t *testing.T) { csGot, _, err = sdk.OAuth2Api.ListOAuth2ConsentSessions(ctx).Subject("subject3").LoginSessionId("fk-login-session-t1-3").Execute() require.NoError(t, err) assert.Equal(t, 1, len(csGot)) - cs = csGot[0] - assert.Equal(t, makeID("challenge", network, "3"), cs.ConsentRequest.Challenge) csGot, _, err = sdk.OAuth2Api.ListOAuth2ConsentSessions(ctx).Subject("subject3").LoginSessionId("fk-login-session-t1-X").Execute() require.NoError(t, err) @@ -198,3 +246,15 @@ func compareSDKLogoutRequest(t *testing.T, expected *LogoutRequest, got *hydra.O assert.EqualValues(t, expected.RequestURL, *got.RequestUrl) assert.EqualValues(t, expected.RPInitiated, *got.RpInitiated) } + +type executer[T any] interface { + Execute() (*T, *http.Response, error) +} + +func execute[T any](t *testing.T, e executer[T]) *T { + got, res, err := e.Execute() + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + + return got +} diff --git a/consent/strategy.go b/consent/strategy.go index 9d31b3de4b1..08e8788c756 100644 --- a/consent/strategy.go +++ b/consent/strategy.go @@ -8,13 +8,19 @@ import ( "net/http" "github.com/ory/fosite" + "github.com/ory/hydra/v2/flow" ) var _ Strategy = new(DefaultStrategy) type Strategy interface { - HandleOAuth2AuthorizationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*AcceptOAuth2ConsentRequest, error) - HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*LogoutResult, error) + HandleOAuth2AuthorizationRequest( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + req fosite.AuthorizeRequester, + ) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) + HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) HandleHeadlessLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, sid string) error ObfuscateSubjectIdentifier(ctx context.Context, cl fosite.Client, subject, forcedIdentifier string) (string, error) } diff --git a/consent/strategy_default.go b/consent/strategy_default.go index 0de9ac2b168..2df79bce9f9 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -16,7 +16,9 @@ import ( "github.com/pborman/uuid" "github.com/pkg/errors" "github.com/sirupsen/logrus" - "github.com/twmb/murmur3" + + "github.com/ory/hydra/v2/flow" + "github.com/ory/hydra/v2/oauth2/flowctx" "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" @@ -79,7 +81,7 @@ func (s *DefaultStrategy) matchesValueFromSession(ctx context.Context, c fosite. return nil } -func (s *DefaultStrategy) authenticationSession(ctx context.Context, w http.ResponseWriter, r *http.Request) (*LoginSession, error) { +func (s *DefaultStrategy) authenticationSession(ctx context.Context, _ http.ResponseWriter, r *http.Request) (*flow.LoginSession, error) { store, err := s.r.CookieStore(ctx) if err != nil { return nil, err @@ -102,7 +104,8 @@ func (s *DefaultStrategy) authenticationSession(ctx context.Context, w http.Resp return nil, errorsx.WithStack(ErrNoAuthenticationSessionFound) } - session, err := s.r.ConsentManager().GetRememberedLoginSession(r.Context(), sessionID) + sessionFromCookie := s.loginSessionFromCookie(r) + session, err := s.r.ConsentManager().GetRememberedLoginSession(r.Context(), sessionFromCookie, sessionID) if errors.Is(err, x.ErrNotFound) { s.r.Logger().WithRequest(r).WithError(err). Debug("User logout skipped because cookie exists and session value exist but are not remembered any more.") @@ -184,7 +187,7 @@ func (s *DefaultStrategy) getSubjectFromIDTokenHint(ctx context.Context, idToken return sub, nil } -func (s *DefaultStrategy) forwardAuthenticationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, subject string, authenticatedAt time.Time, session *LoginSession) error { +func (s *DefaultStrategy) forwardAuthenticationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, subject string, authenticatedAt time.Time, session *flow.LoginSession) error { if (subject != "" && authenticatedAt.IsZero()) || (subject == "" && !authenticatedAt.IsZero()) { return errorsx.WithStack(fosite.ErrServerError.WithHint("Consent strategy returned a non-empty subject with an empty auth date, or an empty subject with a non-empty auth date.")) } @@ -224,51 +227,66 @@ func (s *DefaultStrategy) forwardAuthenticationRequest(ctx context.Context, w ht sessionID = session.ID } else { // Create a stub session so that we can later update it. - if err := s.r.ConsentManager().CreateLoginSession(r.Context(), &LoginSession{ID: sessionID}); err != nil { + loginSession := &flow.LoginSession{ID: sessionID} + if err := s.r.ConsentManager().CreateLoginSession(ctx, loginSession); err != nil { + return err + } + if err := flowctx.SetCookie(ctx, w, s.r, flowctx.LoginSessionCookie(flowctx.SuffixForClient(ar.GetClient())), loginSession); err != nil { return err } } // Set the session cl := sanitizeClientFromRequest(ar) - if err := s.r.ConsentManager().CreateLoginRequest( - r.Context(), - &LoginRequest{ - ID: challenge, - Verifier: verifier, - CSRF: csrf, - Skip: skip, - RequestedScope: []string(ar.GetRequestedScopes()), - RequestedAudience: []string(ar.GetRequestedAudience()), - Subject: subject, - Client: cl, - RequestURL: iu.String(), - AuthenticatedAt: sqlxx.NullTime(authenticatedAt), - RequestedAt: time.Now().Truncate(time.Second).UTC(), - SessionID: sqlxx.NullString(sessionID), - OpenIDConnectContext: &OAuth2ConsentRequestOpenIDConnectContext{ - IDTokenHintClaims: idTokenHintClaims, - ACRValues: stringsx.Splitx(ar.GetRequestForm().Get("acr_values"), " "), - UILocales: stringsx.Splitx(ar.GetRequestForm().Get("ui_locales"), " "), - Display: ar.GetRequestForm().Get("display"), - LoginHint: ar.GetRequestForm().Get("login_hint"), - }, + loginRequest := &flow.LoginRequest{ + ID: challenge, + Verifier: verifier, + CSRF: csrf, + Skip: skip, + RequestedScope: []string(ar.GetRequestedScopes()), + RequestedAudience: []string(ar.GetRequestedAudience()), + Subject: subject, + Client: cl, + RequestURL: iu.String(), + AuthenticatedAt: sqlxx.NullTime(authenticatedAt), + RequestedAt: time.Now().Truncate(time.Second).UTC(), + SessionID: sqlxx.NullString(sessionID), + OpenIDConnectContext: &flow.OAuth2ConsentRequestOpenIDConnectContext{ + IDTokenHintClaims: idTokenHintClaims, + ACRValues: stringsx.Splitx(ar.GetRequestForm().Get("acr_values"), " "), + UILocales: stringsx.Splitx(ar.GetRequestForm().Get("ui_locales"), " "), + Display: ar.GetRequestForm().Get("display"), + LoginHint: ar.GetRequestForm().Get("login_hint"), }, - ); err != nil { + } + f, err := s.r.ConsentManager().CreateLoginRequest( + ctx, + loginRequest, + ) + if err != nil { return errorsx.WithStack(err) } + if err := flowctx.SetCookie(ctx, w, s.r, flowctx.FlowCookie(cl), f); err != nil { + return err + } + store, err := s.r.CookieStore(ctx) if err != nil { return err } - clientSpecificCookieNameLoginCSRF := fmt.Sprintf("%s_%d", s.r.Config().CookieNameLoginCSRF(ctx), murmur3.Sum32(cl.ID.Bytes())) + clientSpecificCookieNameLoginCSRF := fmt.Sprintf("%s_%s", s.r.Config().CookieNameLoginCSRF(ctx), cl.CookieSuffix()) if err := createCsrfSession(w, r, s.r.Config(), store, clientSpecificCookieNameLoginCSRF, csrf, s.c.ConsentRequestMaxAge(ctx)); err != nil { return errorsx.WithStack(err) } - http.Redirect(w, r, urlx.SetQuery(s.c.LoginURL(ctx), url.Values{"login_challenge": {challenge}}).String(), http.StatusFound) + encodedFlow, err := f.ToLoginChallenge(ctx, s.r) + if err != nil { + return err + } + + http.Redirect(w, r, urlx.SetQuery(s.c.LoginURL(ctx), url.Values{"login_challenge": {encodedFlow}}).String(), http.StatusFound) // generate the verifier return errorsx.WithStack(ErrAbortOAuth2Request) @@ -312,9 +330,22 @@ func (s *DefaultStrategy) revokeAuthenticationCookie(w http.ResponseWriter, r *h return sid, nil } -func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester, verifier string) (*HandledLoginRequest, error) { - ctx := r.Context() - session, err := s.r.ConsentManager().VerifyAndInvalidateLoginRequest(ctx, verifier) +func (s *DefaultStrategy) verifyAuthentication( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + req fosite.AuthorizeRequester, + verifier string, +) (*flow.Flow, error) { + f, err := s.flowFromCookie(r) + if err != nil { + return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The flow cookie is missing in the request.")) + } + if f.Client.GetID() != req.GetClient().GetID() { + return nil, errorsx.WithStack(fosite.ErrInvalidClient.WithHint("The flow cookie client id does not match the authorize request client id.")) + } + + session, err := s.r.ConsentManager().VerifyAndInvalidateLoginRequest(ctx, f, verifier) if errors.Is(err, sqlcon.ErrNoRows) { return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The login verifier has already been used, has not been granted, or is invalid.")) } else if err != nil { @@ -322,8 +353,8 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re } if session.HasError() { - session.Error.SetDefaults(loginRequestDeniedErrorName) - return nil, errorsx.WithStack(session.Error.toRFCError()) + session.Error.SetDefaults(flow.LoginRequestDeniedErrorName) + return nil, errorsx.WithStack(session.Error.ToRFCError()) } if session.RequestedAt.Add(s.c.ConsentRequestMaxAge(ctx)).Before(time.Now()) { @@ -335,7 +366,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re return nil, err } - clientSpecificCookieNameLoginCSRF := fmt.Sprintf("%s_%d", s.r.Config().CookieNameLoginCSRF(ctx), murmur3.Sum32(session.LoginRequest.Client.ID.Bytes())) + clientSpecificCookieNameLoginCSRF := fmt.Sprintf("%s_%s", s.r.Config().CookieNameLoginCSRF(ctx), session.LoginRequest.Client.CookieSuffix()) if err := validateCsrfSession(r, s.r.Config(), store, clientSpecificCookieNameLoginCSRF, session.LoginRequest.CSRF); err != nil { return nil, err } @@ -409,10 +440,16 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re if !session.LoginRequest.Skip { if time.Time(session.AuthenticatedAt).IsZero() { - return nil, errorsx.WithStack(fosite.ErrServerError.WithHint("Expected the handled login request to contain a valid authenticated_at value but it was zero. This is a bug which should be reported to https://github.com/ory/hydra.")) + return nil, errorsx.WithStack(fosite.ErrServerError.WithHint( + "Expected the handled login request to contain a valid authenticated_at value but it was zero. This is a bug which should be reported to https://github.com/ory/hydra.")) + } + + loginSession := s.loginSessionFromCookie(r) + if loginSession == nil { + return nil, fosite.ErrAccessDenied.WithHint("The login session cookie was not found or malformed.") } - if err := s.r.ConsentManager().ConfirmLoginSession(r.Context(), sessionID, time.Time(session.AuthenticatedAt), session.Subject, session.Remember); err != nil { + if err := s.r.ConsentManager().ConfirmLoginSession(ctx, loginSession, sessionID, time.Time(session.AuthenticatedAt), session.Subject, session.Remember); err != nil { return nil, err } } @@ -429,7 +466,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re // If the user doesn't want to remember the session, we do not store a cookie. // If login was skipped, it means an authentication cookie was present and // we don't want to touch it (in order to preserve its original expiry date) - return session, nil + return f, nil } // Not a skipped login and the user asked to remember its session, store a cookie @@ -453,13 +490,24 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re "cookie_same_site": s.c.CookieSameSiteMode(ctx), "cookie_secure": s.c.CookieSecure(ctx), }).Debug("Authentication session cookie was set.") - return session, nil + + if err = flowctx.SetCookie(ctx, w, s.r, flowctx.FlowCookie(flowctx.SuffixForClient(req.GetClient())), f); err != nil { + return nil, errorsx.WithStack(err) + } + + return f, nil } -func (s *DefaultStrategy) requestConsent(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, authenticationSession *HandledLoginRequest) error { +func (s *DefaultStrategy) requestConsent( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + ar fosite.AuthorizeRequester, + f *flow.Flow, +) error { prompt := stringsx.Splitx(ar.GetRequestForm().Get("prompt"), " ") if stringslice.Has(prompt, "consent") { - return s.forwardConsentRequest(ctx, w, r, ar, authenticationSession, nil) + return s.forwardConsentRequest(ctx, w, r, ar, f, nil) } // https://tools.ietf.org/html/rfc6749 @@ -483,7 +531,7 @@ func (s *DefaultStrategy) requestConsent(ctx context.Context, w http.ResponseWri // This is tracked as issue: https://github.com/ory/hydra/issues/866 // This is also tracked as upstream issue: https://github.com/openid-certification/oidctest/issues/97 if !(ar.GetRedirectURI().Scheme == "https" || (fosite.IsLocalhost(ar.GetRedirectURI()) && ar.GetRedirectURI().Scheme == "http")) { - return s.forwardConsentRequest(ctx, w, r, ar, authenticationSession, nil) + return s.forwardConsentRequest(ctx, w, r, ar, f, nil) } } @@ -494,23 +542,31 @@ func (s *DefaultStrategy) requestConsent(ctx context.Context, w http.ResponseWri // return s.forwardConsentRequest(w, r, ar, authenticationSession, nil) // } - consentSessions, err := s.r.ConsentManager().FindGrantedAndRememberedConsentRequests(r.Context(), ar.GetClient().GetID(), authenticationSession.Subject) + consentSessions, err := s.r.ConsentManager().FindGrantedAndRememberedConsentRequests(ctx, ar.GetClient().GetID(), f.Subject) if errors.Is(err, ErrNoPreviousConsentFound) { - return s.forwardConsentRequest(ctx, w, r, ar, authenticationSession, nil) + return s.forwardConsentRequest(ctx, w, r, ar, f, nil) } else if err != nil { return err } if found := matchScopes(s.r.Config().GetScopeStrategy(ctx), consentSessions, ar.GetRequestedScopes()); found != nil { - return s.forwardConsentRequest(ctx, w, r, ar, authenticationSession, found) + return s.forwardConsentRequest(ctx, w, r, ar, f, found) } - return s.forwardConsentRequest(ctx, w, r, ar, authenticationSession, nil) + return s.forwardConsentRequest(ctx, w, r, ar, f, nil) } -func (s *DefaultStrategy) forwardConsentRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, as *HandledLoginRequest, cs *AcceptOAuth2ConsentRequest) error { +func (s *DefaultStrategy) forwardConsentRequest( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + ar fosite.AuthorizeRequester, + f *flow.Flow, + previousConsent *flow.AcceptOAuth2ConsentRequest, +) error { + as := f.GetHandledLoginRequest() skip := false - if cs != nil { + if previousConsent != nil { skip = true } @@ -525,45 +581,53 @@ func (s *DefaultStrategy) forwardConsentRequest(ctx context.Context, w http.Resp csrf := strings.Replace(uuid.New(), "-", "", -1) cl := sanitizeClientFromRequest(ar) - if err := s.r.ConsentManager().CreateConsentRequest( - r.Context(), - &OAuth2ConsentRequest{ - ID: challenge, - ACR: as.ACR, - AMR: as.AMR, - Verifier: verifier, - CSRF: csrf, - Skip: skip, - RequestedScope: []string(ar.GetRequestedScopes()), - RequestedAudience: []string(ar.GetRequestedAudience()), - Subject: as.Subject, - Client: cl, - RequestURL: as.LoginRequest.RequestURL, - AuthenticatedAt: as.AuthenticatedAt, - RequestedAt: as.RequestedAt, - ForceSubjectIdentifier: as.ForceSubjectIdentifier, - OpenIDConnectContext: as.LoginRequest.OpenIDConnectContext, - LoginSessionID: as.LoginRequest.SessionID, - LoginChallenge: sqlxx.NullString(as.LoginRequest.ID), - Context: as.Context, - }, - ); err != nil { + + consentRequest := &flow.OAuth2ConsentRequest{ + ID: challenge, + ACR: as.ACR, + AMR: as.AMR, + Verifier: verifier, + CSRF: csrf, + Skip: skip, + RequestedScope: []string(ar.GetRequestedScopes()), + RequestedAudience: []string(ar.GetRequestedAudience()), + Subject: as.Subject, + Client: cl, + RequestURL: as.LoginRequest.RequestURL, + AuthenticatedAt: as.AuthenticatedAt, + RequestedAt: as.RequestedAt, + ForceSubjectIdentifier: as.ForceSubjectIdentifier, + OpenIDConnectContext: as.LoginRequest.OpenIDConnectContext, + LoginSessionID: as.LoginRequest.SessionID, + LoginChallenge: sqlxx.NullString(as.LoginRequest.ID), + Context: as.Context, + } + err := s.r.ConsentManager().CreateConsentRequest(ctx, f, consentRequest) + if err != nil { return errorsx.WithStack(err) } + if err := flowctx.SetCookie(ctx, w, s.r, flowctx.FlowCookie(cl), f); err != nil { + return err + } + consentChallenge, err := f.ToConsentChallenge(ctx, s.r) + if err != nil { + return err + } + store, err := s.r.CookieStore(ctx) if err != nil { return err } - clientSpecificCookieNameConsentCSRF := fmt.Sprintf("%s_%d", s.r.Config().CookieNameConsentCSRF(ctx), murmur3.Sum32(cl.ID.Bytes())) + clientSpecificCookieNameConsentCSRF := fmt.Sprintf("%s_%s", s.r.Config().CookieNameConsentCSRF(ctx), cl.CookieSuffix()) if err := createCsrfSession(w, r, s.r.Config(), store, clientSpecificCookieNameConsentCSRF, csrf, s.c.ConsentRequestMaxAge(ctx)); err != nil { return errorsx.WithStack(err) } http.Redirect( w, r, - urlx.SetQuery(s.c.ConsentURL(ctx), url.Values{"consent_challenge": {challenge}}).String(), + urlx.SetQuery(s.c.ConsentURL(ctx), url.Values{"consent_challenge": {consentChallenge}}).String(), http.StatusFound, ) @@ -571,39 +635,51 @@ func (s *DefaultStrategy) forwardConsentRequest(ctx context.Context, w http.Resp return errorsx.WithStack(ErrAbortOAuth2Request) } -func (s *DefaultStrategy) verifyConsent(ctx context.Context, w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester, verifier string) (*AcceptOAuth2ConsentRequest, error) { - session, err := s.r.ConsentManager().VerifyAndInvalidateConsentRequest(r.Context(), verifier) +func (s *DefaultStrategy) verifyConsent(ctx context.Context, w http.ResponseWriter, r *http.Request, verifier string) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) { + f, err := s.flowFromCookie(r) + if err != nil { + return nil, nil, err + } + if f.Client.GetID() != r.URL.Query().Get("client_id") { + return nil, nil, errorsx.WithStack(fosite.ErrInvalidClient.WithHint("The flow cookie client id does not match the authorize request client id.")) + } + + session, err := s.r.ConsentManager().VerifyAndInvalidateConsentRequest(ctx, f, verifier) if errors.Is(err, sqlcon.ErrNoRows) { - return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The consent verifier has already been used, has not been granted, or is invalid.")) + return nil, nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The consent verifier has already been used, has not been granted, or is invalid.")) } else if err != nil { - return nil, err + return nil, nil, err } if session.RequestedAt.Add(s.c.ConsentRequestMaxAge(ctx)).Before(time.Now()) { - return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The consent request has expired, please try again.")) + return nil, nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The consent request has expired, please try again.")) } if session.HasError() { - session.Error.SetDefaults(consentRequestDeniedErrorName) - return nil, errorsx.WithStack(session.Error.toRFCError()) + session.Error.SetDefaults(flow.ConsentRequestDeniedErrorName) + return nil, nil, errorsx.WithStack(session.Error.ToRFCError()) } if time.Time(session.ConsentRequest.AuthenticatedAt).IsZero() { - return nil, errorsx.WithStack(fosite.ErrServerError.WithHint("The authenticatedAt value was not set.")) + return nil, nil, errorsx.WithStack(fosite.ErrServerError.WithHint("The authenticatedAt value was not set.")) } store, err := s.r.CookieStore(ctx) if err != nil { - return nil, err + return nil, nil, err } - clientSpecificCookieNameConsentCSRF := fmt.Sprintf("%s_%d", s.r.Config().CookieNameConsentCSRF(ctx), murmur3.Sum32(session.ConsentRequest.Client.ID.Bytes())) + clientSpecificCookieNameConsentCSRF := fmt.Sprintf("%s_%s", s.r.Config().CookieNameConsentCSRF(ctx), session.ConsentRequest.Client.CookieSuffix()) if err := validateCsrfSession(r, s.r.Config(), store, clientSpecificCookieNameConsentCSRF, session.ConsentRequest.CSRF); err != nil { - return nil, err + return nil, nil, err + } + + if err = flowctx.DeleteCookie(ctx, w, s.r, flowctx.FlowCookie(f.Client)); err != nil { + return nil, nil, err } if session.Session == nil { - session.Session = NewConsentRequestSessionData() + session.Session = flow.NewConsentRequestSessionData() } if session.Session.AccessToken == nil { @@ -615,7 +691,7 @@ func (s *DefaultStrategy) verifyConsent(ctx context.Context, w http.ResponseWrit } session.AuthenticatedAt = session.ConsentRequest.AuthenticatedAt - return session, nil + return session, f, nil } func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) { @@ -711,7 +787,7 @@ func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http. return nil } -func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.ResponseWriter, r *http.Request) (*LogoutResult, error) { +func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) { // There are two types of log out flows: // // - RP initiated logout @@ -758,7 +834,7 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon } challenge := uuid.New() - if err := s.r.ConsentManager().CreateLogoutRequest(r.Context(), &LogoutRequest{ + if err := s.r.ConsentManager().CreateLogoutRequest(r.Context(), &flow.LogoutRequest{ RequestURL: r.URL.String(), ID: challenge, Subject: session.Subject, @@ -869,7 +945,8 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon // We do not really want to verify if the user (from id token hint) has a session here because it doesn't really matter. // Instead, we'll check this when we're actually revoking the cookie! - session, err := s.r.ConsentManager().GetRememberedLoginSession(r.Context(), hintSid) + sessionFromCookie := s.loginSessionFromCookie(r) + session, err := s.r.ConsentManager().GetRememberedLoginSession(r.Context(), sessionFromCookie, hintSid) if errors.Is(err, x.ErrNotFound) { // Such a session does not exist - maybe it has already been revoked? In any case, we can't do much except // leaning back and redirecting back. @@ -880,7 +957,7 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon } challenge := uuid.New() - if err := s.r.ConsentManager().CreateLogoutRequest(r.Context(), &LogoutRequest{ + if err := s.r.ConsentManager().CreateLogoutRequest(r.Context(), &flow.LogoutRequest{ RequestURL: r.URL.String(), ID: challenge, SessionID: hintSid, @@ -899,7 +976,7 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon return nil, errorsx.WithStack(ErrAbortOAuth2Request) } -func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(ctx context.Context, r *http.Request, subject string, sid string) error { +func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(_ context.Context, r *http.Request, subject string, sid string) error { if err := s.executeBackChannelLogout(r.Context(), r, subject, sid); err != nil { return err } @@ -918,7 +995,7 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(ctx context.C return nil } -func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*LogoutResult, error) { +func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) { verifier := r.URL.Query().Get("logout_verifier") lr, err := s.r.ConsentManager().VerifyAndInvalidateLogoutRequest(r.Context(), verifier) @@ -976,13 +1053,13 @@ func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWri WithField("subject", lr.Subject). Info("User logout completed!") - return &LogoutResult{ + return &flow.LogoutResult{ RedirectTo: lr.PostLogoutRedirectURI, FrontChannelLogoutURLs: urls, }, nil } -func (s *DefaultStrategy) HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*LogoutResult, error) { +func (s *DefaultStrategy) HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) { verifier := r.URL.Query().Get("logout_verifier") if verifier == "" { return s.issueLogoutVerifier(ctx, w, r) @@ -991,8 +1068,9 @@ func (s *DefaultStrategy) HandleOpenIDConnectLogout(ctx context.Context, w http. return s.completeLogout(ctx, w, r) } -func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, w http.ResponseWriter, r *http.Request, sid string) error { - loginSession, lsErr := s.r.ConsentManager().GetRememberedLoginSession(ctx, sid) +func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, _ http.ResponseWriter, r *http.Request, sid string) error { + sessionFromCookie := s.loginSessionFromCookie(r) + loginSession, lsErr := s.r.ConsentManager().GetRememberedLoginSession(ctx, sessionFromCookie, sid) if errors.Is(lsErr, x.ErrNotFound) { // This is ok (session probably already revoked), do nothing! @@ -1016,28 +1094,33 @@ func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, w http.Respo return nil } -func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*AcceptOAuth2ConsentRequest, error) { - authenticationVerifier := strings.TrimSpace(req.GetRequestForm().Get("login_verifier")) +func (s *DefaultStrategy) HandleOAuth2AuthorizationRequest( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, + req fosite.AuthorizeRequester, +) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) { + loginVerifier := strings.TrimSpace(req.GetRequestForm().Get("login_verifier")) consentVerifier := strings.TrimSpace(req.GetRequestForm().Get("consent_verifier")) - if authenticationVerifier == "" && consentVerifier == "" { + if loginVerifier == "" && consentVerifier == "" { // ok, we need to process this request and redirect to auth endpoint - return nil, s.requestAuthentication(ctx, w, r, req) - } else if authenticationVerifier != "" { - authSession, err := s.verifyAuthentication(w, r, req, authenticationVerifier) + return nil, nil, s.requestAuthentication(ctx, w, r, req) + } else if loginVerifier != "" { + f, err := s.verifyAuthentication(ctx, w, r, req, loginVerifier) if err != nil { - return nil, err + return nil, nil, err } // ok, we need to process this request and redirect to auth endpoint - return nil, s.requestConsent(ctx, w, r, req, authSession) + return nil, f, s.requestConsent(ctx, w, r, req, f) } - consentSession, err := s.verifyConsent(ctx, w, r, req, consentVerifier) + consentSession, f, err := s.verifyConsent(ctx, w, r, consentVerifier) if err != nil { - return nil, err + return nil, nil, err } - return consentSession, nil + return consentSession, f, nil } func (s *DefaultStrategy) ObfuscateSubjectIdentifier(ctx context.Context, cl fosite.Client, subject, forcedIdentifier string) (string, error) { @@ -1057,3 +1140,22 @@ func (s *DefaultStrategy) ObfuscateSubjectIdentifier(ctx context.Context, cl fos } return subject, nil } + +func (s *DefaultStrategy) flowFromCookie(r *http.Request) (*flow.Flow, error) { + clientID := r.URL.Query().Get("client_id") + if clientID == "" { + return nil, errors.WithStack(fosite.ErrInvalidClient) + } + + return flowctx.FromCookie[flow.Flow](r.Context(), r, s.r.FlowCipher(), flowctx.FlowCookie(flowctx.SuffixFromStatic(clientID))) +} + +func (s *DefaultStrategy) loginSessionFromCookie(r *http.Request) *flow.LoginSession { + clientID := r.URL.Query().Get("client_id") + if clientID == "" { + return nil + } + ls, _ := flowctx.FromCookie[flow.LoginSession](r.Context(), r, s.r.FlowCipher(), flowctx.LoginSessionCookie(flowctx.SuffixFromStatic(clientID))) + + return ls +} diff --git a/consent/strategy_oauth_test.go b/consent/strategy_oauth_test.go index a44446121ee..d1ec766c61c 100644 --- a/consent/strategy_oauth_test.go +++ b/consent/strategy_oauth_test.go @@ -10,14 +10,20 @@ import ( "encoding/json" "fmt" "net/http" + "net/http/cookiejar" "net/url" + "regexp" "testing" "time" + "github.com/ory/hydra/v2/aead" + "github.com/ory/hydra/v2/consent" + "github.com/ory/hydra/v2/flow" + "github.com/ory/hydra/v2/oauth2/flowctx" + "github.com/ory/hydra/v2/x" "github.com/ory/x/ioutilx" - "github.com/twmb/murmur3" - + "golang.org/x/exp/slices" "golang.org/x/oauth2" "github.com/ory/x/pointerx" @@ -113,8 +119,12 @@ func TestStrategyLoginConsentNext(t *testing.T) { t.Run("case=should fail because a login verifier was given that doesn't exist in the store", func(t *testing.T) { testhelpers.NewLoginConsentUI(t, reg.Config(), testhelpers.HTTPServerNoExpectedCallHandler(t), testhelpers.HTTPServerNoExpectedCallHandler(t)) c := createDefaultClient(t) + hc := newHTTPClientWithFlowCookie(t, ctx, reg, c) - makeRequestAndExpectError(t, nil, c, url.Values{"login_verifier": {"does-not-exist"}}, "The login verifier has already been used, has not been granted, or is invalid.") + makeRequestAndExpectError( + t, hc, c, url.Values{"login_verifier": {"does-not-exist"}}, + "The login verifier has already been used, has not been granted, or is invalid.", + ) }) t.Run("case=should fail because a non-existing consent verifier was given", func(t *testing.T) { @@ -123,7 +133,12 @@ func TestStrategyLoginConsentNext(t *testing.T) { // - This should fail because a consent verifier was given but no login verifier testhelpers.NewLoginConsentUI(t, reg.Config(), testhelpers.HTTPServerNoExpectedCallHandler(t), testhelpers.HTTPServerNoExpectedCallHandler(t)) c := createDefaultClient(t) - makeRequestAndExpectError(t, nil, c, url.Values{"consent_verifier": {"does-not-exist"}}, "The consent verifier has already been used, has not been granted, or is invalid.") + hc := newHTTPClientWithFlowCookie(t, ctx, reg, c) + + makeRequestAndExpectError( + t, hc, c, url.Values{"consent_verifier": {"does-not-exist"}}, + "The consent verifier has already been used, has not been granted, or is invalid.", + ) }) t.Run("case=should fail because the request was redirected but the login endpoint doesn't do anything (like redirecting back)", func(t *testing.T) { @@ -169,6 +184,7 @@ func TestStrategyLoginConsentNext(t *testing.T) { testhelpers.HTTPServerNoExpectedCallHandler(t)) hc := new(http.Client) + hc.Jar = DropCookieJar(regexp.MustCompile("ory_hydra_.*_csrf_.*")) makeRequestAndExpectError(t, hc, c, url.Values{}, "No CSRF value available in the session cookie.") }) @@ -332,16 +348,13 @@ func TestStrategyLoginConsentNext(t *testing.T) { loginChallengeRedirect, err := oauthRes.Location() require.NoError(t, err) defer oauthRes.Body.Close() - setCookieHeader := oauthRes.Header.Get("set-cookie") - assert.NotNil(t, setCookieHeader) - - t.Run("login cookie client specific suffix is set", func(t *testing.T) { - assert.Regexp(t, fmt.Sprintf("ory_hydra_login_csrf_dev_%d=.*", murmur3.Sum32(c.ID.Bytes())), setCookieHeader) - }) - t.Run("login cookie max age is set", func(t *testing.T) { - assert.Regexp(t, fmt.Sprintf("ory_hydra_login_csrf_dev_%d=.*Max-Age=%.0f;.*", murmur3.Sum32(c.ID.Bytes()), consentRequestMaxAge), setCookieHeader) + foundLoginCookie := slices.ContainsFunc(oauthRes.Header.Values("set-cookie"), func(sc string) bool { + ok, err := regexp.MatchString(fmt.Sprintf("ory_hydra_login_csrf_dev_%s=.*Max-Age=%.0f;.*", c.CookieSuffix(), consentRequestMaxAge), sc) + require.NoError(t, err) + return ok }) + require.True(t, foundLoginCookie, "client-specific login cookie with max age set") loginChallengeRes, err := hc.Get(loginChallengeRedirect.String()) require.NoError(t, err) @@ -352,16 +365,13 @@ func TestStrategyLoginConsentNext(t *testing.T) { loginVerifierRes, err := hc.Get(loginVerifierRedirect.String()) require.NoError(t, err) defer loginVerifierRes.Body.Close() - setCookieHeader = loginVerifierRes.Header.Values("set-cookie")[1] - assert.NotNil(t, setCookieHeader) - t.Run("consent cookie client specific suffix set", func(t *testing.T) { - assert.Regexp(t, fmt.Sprintf("ory_hydra_consent_csrf_dev_%d=.*", murmur3.Sum32(c.ID.Bytes())), setCookieHeader) - }) - - t.Run("consent cookie max age is set", func(t *testing.T) { - assert.Regexp(t, fmt.Sprintf("ory_hydra_consent_csrf_dev_%d=.*Max-Age=%.0f;.*", murmur3.Sum32(c.ID.Bytes()), consentRequestMaxAge), setCookieHeader) + foundConsentCookie := slices.ContainsFunc(loginVerifierRes.Header.Values("set-cookie"), func(sc string) bool { + ok, err := regexp.MatchString(fmt.Sprintf("ory_hydra_consent_csrf_dev_%s=.*Max-Age=%.0f;.*", c.CookieSuffix(), consentRequestMaxAge), sc) + require.NoError(t, err) + return ok }) + require.True(t, foundConsentCookie, "client-specific consent cookie with max age set") }) t.Run("case=should pass if both login and consent are granted and check remember flows with refresh session cookie", func(t *testing.T) { @@ -432,6 +442,7 @@ func TestStrategyLoginConsentNext(t *testing.T) { require.NoError(t, err) defer loginChallengeRes.Body.Close() loginVerifierRedirect, err := loginChallengeRes.Location() + require.NoError(t, err) loginVerifierRes, err := hc.Get(loginVerifierRedirect.String()) require.NoError(t, err) @@ -580,9 +591,8 @@ func TestStrategyLoginConsentNext(t *testing.T) { hc := testhelpers.NewEmptyJarClient(t) - t.Run("set up initial session", func(t *testing.T) { - makeRequestAndExpectCode(t, hc, c, url.Values{"redirect_uri": {c.RedirectURIs[0]}}) - }) + // set up initial session + makeRequestAndExpectCode(t, hc, c, url.Values{"redirect_uri": {c.RedirectURIs[0]}}) // By not waiting here we ensure that there are no race conditions when it comes to authenticated_at and // requested_at time comparisons: @@ -1017,3 +1027,47 @@ func TestStrategyLoginConsentNext(t *testing.T) { makeRequestAndExpectCode(t, hc, c, url.Values{"redirect_uri": {c.RedirectURIs[0]}}) }) } + +func DropCookieJar(drop *regexp.Regexp) http.CookieJar { + jar, _ := cookiejar.New(nil) + return &dropCSRFCookieJar{ + jar: jar, + drop: drop, + } +} + +type dropCSRFCookieJar struct { + jar *cookiejar.Jar + drop *regexp.Regexp +} + +var _ http.CookieJar = (*dropCSRFCookieJar)(nil) + +func (d *dropCSRFCookieJar) SetCookies(u *url.URL, cookies []*http.Cookie) { + for _, c := range cookies { + if d.drop.MatchString(c.Name) { + continue + } + d.jar.SetCookies(u, []*http.Cookie{c}) + } +} + +func (d *dropCSRFCookieJar) Cookies(u *url.URL) []*http.Cookie { + return d.jar.Cookies(u) +} + +func newHTTPClientWithFlowCookie(t *testing.T, ctx context.Context, reg interface { + ConsentManager() consent.Manager + Config() *config.DefaultProvider + FlowCipher() *aead.XChaCha20Poly1305 +}, c *client.Client) *http.Client { + f, err := reg.ConsentManager().CreateLoginRequest(ctx, &flow.LoginRequest{Client: c}) + require.NoError(t, err) + + hc := testhelpers.NewEmptyJarClient(t) + hc.Jar.SetCookies(reg.Config().OAuth2AuthURL(ctx), []*http.Cookie{ + {Name: flowctx.FlowCookie(c), Value: x.Must(flowctx.Encode(ctx, reg.FlowCipher(), f))}, + }) + + return hc +} diff --git a/docs/flow-cache-design-doc.md b/docs/flow-cache-design-doc.md new file mode 100644 index 00000000000..22916348936 --- /dev/null +++ b/docs/flow-cache-design-doc.md @@ -0,0 +1,167 @@ +# Flow Cache Design Doc + +## Overview + +This design doc outlines the proposed solution for caching the flow object in +the OAuth2 exchange between the Client, Ory Hydra, and the Consent and Login +UIs. The flow object contains the state of the authorization request. + +## Problem Statement + +Currently, the flow object is stored in the database on the Ory Hydra server. +This approach has several drawbacks: + +- Each step of the OAuth2 flow (initialization, consent, login, etc.) requires a + database query to retrieve the flow object, and another to update it. +- Each part of the exchanges supplies different values (login challenge, consent + challenge, etc.) to identify the flow object. This means the database table + has multiple indices that slow down insertions. + +## Proposed Solution + +The proposed solution is to store the flow object in client cookies and URLs. +This way, the flow object is written only once when the flow is completed and +the final authorization code is generated. + +### Requirements + +- The flow object must be stored in client cookies and URLs. +- The flow object must be secure and protect against unauthorized access. +- The flow object must be persistent, so that the flow can be resumed if the + user navigates away from the page or closes the browser. +- The flow object must be scalable and able to handle a large number of + concurrent requests. + +### Architecture + +The proposed architecture for the flow cache is as follows: + +- Store the flow object in an AEAD encrypted cookie. +- Pass a partial flow around in the URL. +- Use a secure connection to protect against unauthorized access. + +```mermaid +sequenceDiagram + actor Client + participant Hydra + participant LoginUI as Login UI + participant ConsentUI as Consent UI + % participant Callback + + autonumber + + Client->>+Hydra: GET /oauth2/auth?client_id=CLIENT_ID&response_type=code&scope=SCOPES&state=STATE + Hydra->>-Client: Redirect to
http://login.local/?login_challenge=LOGIN_CHALLENGE + + Client->>+LoginUI: GET /?login_challenge=LOGIN_CHALLENGE + LoginUI->>Hydra: GET /admin/oauth2/auth/requests/login + Hydra->>LoginUI: oAuth2LoginRequest + alt accept login + LoginUI->>Hydra: PUT /admin/oauth2/auth/requests/login/accept + else reject login + LoginUI->>Hydra: PUT /admin/oauth2/auth/requests/login/reject + end + Hydra->>LoginUI: oAuth2RedirectTo + LoginUI->>-Client: Redirect to
http://hydra.local/oauth2/auth?client_id=CLIENT_ID&login_verifier=LOGIN_VERIFIER&response_type=code&scope=SCOPES&state=STATE + + Client->>+Hydra: GET /oauth2/auth?client_id=CLIENT_ID&login_verifier=LOGIN_VERIFIER&response_type=code&scope=SCOPES&state=STATE + Hydra->>-Client: Redirect to
http://consent.local/?consent_challenge=CONSENT_CHALLENGE + + Client->>+ConsentUI: GET /?consent_challenge=CONSENT_CHALLENGE + ConsentUI->>Hydra: GET /admin/oauth2/auth/requests/consent + Hydra->>ConsentUI: oAuth2ConsentRequest + alt accept login + ConsentUI->>Hydra: PUT /admin/oauth2/auth/requests/consent/accept + else reject login + ConsentUI->>Hydra: PUT /admin/oauth2/auth/requests/consent/reject + end + Hydra->>ConsentUI: oAuth2RedirectTo + ConsentUI->>-Client: Redirect to
http://hydra.local/oauth2/auth?client_id=CLIENT_ID&consent_verifier=CONSENT_VERIFIER&response_type=code&scope=SCOPES&state=STATE + + Client->>+Hydra: GET /oauth2/auth?client_id=CLIENT_ID&consent_verifier=CONSENT_VERIFIER&response_type=code&scope=SCOPES&state=STATE + Hydra->>-Client: Redirect to
http://callback.local/callback?code=AUTH_CODE&scope=SCOPES&state=STATE + Note over Hydra,Client: next, exchange code for token. + + + % Client->>+Callback: GET /callback?code=AUTH_CODE&scope=SCOPES&state=STATE + % Callback->>-Client: Return Authorization Code +``` + +Step 2: + +- Set the whole flow as an AEAD encrypted cookie on the client +- The cookie is keyed by the `state`, so that multiple flows can run in parallel + from one cookie jar +- Set the `LOGIN_CHALLENGE` to the AEAD-encrypted flow + +Step 5: + +- Decrypt the flow from the `LOGIN_CHALLENGE`, return the `oAuth2LoginRequest` + +Step 8: + +- Encode the flow into the redirect URL in `oAuth2RedirectTo` as the + `LOGIN_VERIFIER` + +Step 11 + +- Check that the login challenge in the `LOGIN_VERIFIER` matches the challenge + in the flow cookie. +- Update the flow based on the request from the `LOGIN_VERIFIER` +- Update the cookie +- Set the `CONSENT_CHALLENGE` to the AEAD-encrypted flow + +Step 14: + +- Decrypt the flow from the `CONSENT_CHALLENGE` + +Step 17: + +- Encode the flow into the redirect URL in `oAuth2RedirectTo` as the + `CONSENT_VERIFIER` + +Step 20 + +- Check that the consent challenge in the `CONSENT_VERIFIER` matches the + challenge in the flow cookie. +- Update the flow based on the request from the `CONSENT_VERIFIER` +- Update the cookie +- Write the flow to the database +- Continue the flow as currently implemented (generate the authentication code, + return the code, etc.) + +### Client HTTP requests + +For reference, these HTTP requests are issued by the client: + +``` +GET http://hydra.local/oauth2/auth?client_id=CLIENT_ID&nonce=NONCE&response_type=code&scope=SCOPES&state=STATE +Redirect to http://login.local/?login_challenge=LOGIN_CHALLENGE +GET http://login.local/?login_challenge=LOGIN_CHALLENGE +Redirect to http://hydra.local/oauth2/auth?client_id=CLIENT_ID&login_verifier=LOGIN_VERIFIER&nonce=NONCE&response_type=code&scope=SCOPES&state=STATE +GET http://hydra.local/oauth2/auth?client_id=CLIENT_ID&login_verifier=LOGIN_VERIFIER&nonce=NONCE&response_type=code&scope=SCOPES&state=STATE +Redirect to http://consent.local/?consent_challenge=CONSENT_CHALLENGE +GET http://consent.local/?consent_challenge=CONSENT_CHALLENGE +Redirect to http://hydra.local/oauth2/auth?client_id=CLIENT_ID&consent_verifier=CONSENT_VERIFIER&nonce=NONCE&response_type=code&scope=SCOPES&state=STATE +GET http://hydra.local/oauth2/auth?client_id=CLIENT_ID&consent_verifier=CONSENT_VERIFIER&nonce=NONCE&response_type=code&scope=SCOPES&state=STATE +Redirect to http://callback.local/callback?code=AUTH_CODE&scope=SCOPES&state=STATE +GET http://callback.local/callback?code=AUTH_CODE&scope=SCOPES&state=STATE +``` + +### Implementation + +The implementation of the flow cache will involve the following steps: + +1. Modify the Ory Hydra server to store the flow object in an AEAD encrypted + cookie. +2. Modify the Consent and Login UIs to include the flow object in the URL. +3. Use HTTPS to protect against unauthorized access. + +## Conclusion + +The proposed solution for caching the flow object in the OAuth2 exchange between +the Client, Ory Hydra, and the Consent and Login UIs is to store the flow object +in client cookies and URLs. This approach eliminates the need for a distributed +cache and provides a scalable and secure solution. The flow object will be +stored in an AEAD encrypted cookie and passed around in the URL. HTTPS will be +used to protect against unauthorized access. diff --git a/driver/factory.go b/driver/factory.go index 18a624357b0..2e5fe949c29 100644 --- a/driver/factory.go +++ b/driver/factory.go @@ -8,19 +8,23 @@ import ( "github.com/ory/hydra/v2/driver/config" "github.com/ory/x/configx" - "github.com/ory/x/contextx" "github.com/ory/x/logrusx" + "github.com/ory/x/otelx" "github.com/ory/x/servicelocatorx" ) -type options struct { - preload bool - validate bool - opts []configx.OptionModifier - config *config.DefaultProvider - // The first default refers to determining the NID at startup; the second default referes to the fact that the Contextualizer may dynamically change the NID. - skipNetworkInit bool -} +type ( + options struct { + preload bool + validate bool + opts []configx.OptionModifier + config *config.DefaultProvider + // The first default refers to determining the NID at startup; the second default referes to the fact that the Contextualizer may dynamically change the NID. + skipNetworkInit bool + tracerWrapper TracerWrapper + } + TracerWrapper func(*otelx.Tracer) *otelx.Tracer +) func newOptions() *options { return &options{ @@ -66,6 +70,13 @@ func SkipNetworkInit() OptionsModifier { } } +// WithTracerWrapper sets a function that wraps the tracer. +func WithTracerWrapper(wrapper TracerWrapper) OptionsModifier { + return func(o *options) { + o.tracerWrapper = wrapper + } +} + func New(ctx context.Context, sl *servicelocatorx.Options, opts []OptionsModifier) (Registry, error) { o := newOptions() for _, f := range opts { @@ -94,13 +105,17 @@ func New(ctx context.Context, sl *servicelocatorx.Options, opts []OptionsModifie } } - r, err := NewRegistryFromDSN(ctx, c, l, o.skipNetworkInit, false, ctxter) + r, err := NewRegistryWithoutInit(c, l) if err != nil { l.WithError(err).Error("Unable to create service registry.") return nil, err } - if err = r.Init(ctx, o.skipNetworkInit, false, &contextx.Default{}); err != nil { + if o.tracerWrapper != nil { + r.WithTracerWrapper(o.tracerWrapper) + } + + if err = r.Init(ctx, o.skipNetworkInit, false, ctxter); err != nil { l.WithError(err).Error("Unable to initialize service registry.") return nil, err } diff --git a/driver/registry.go b/driver/registry.go index ccf80db1448..c75213e52c7 100644 --- a/driver/registry.go +++ b/driver/registry.go @@ -5,9 +5,13 @@ package driver import ( "context" + "net/http" + + "go.opentelemetry.io/otel/trace" "github.com/ory/x/httprouterx" + "github.com/ory/hydra/v2/aead" "github.com/ory/hydra/v2/hsm" "github.com/ory/x/contextx" @@ -46,9 +50,12 @@ type Registry interface { WithConfig(c *config.DefaultProvider) Registry WithContextualizer(ctxer contextx.Contextualizer) Registry WithLogger(l *logrusx.Logger) Registry + WithTracer(t trace.Tracer) Registry + WithTracerWrapper(TracerWrapper) Registry x.HTTPClientProvider GetJWKSFetcherStrategy() fosite.JWKSFetcherStrategy + contextx.Provider config.Provider persistence.Provider x.RegistryLogger @@ -61,6 +68,7 @@ type Registry interface { oauth2.Registry PrometheusManager() *prometheus.MetricsManager x.TracingProvider + FlowCipher() *aead.XChaCha20Poly1305 RegisterRoutes(ctx context.Context, admin *httprouterx.RouterAdmin, public *httprouterx.RouterPublic) ClientHandler() *client.Handler @@ -68,6 +76,7 @@ type Registry interface { ConsentHandler() *consent.Handler OAuth2Handler() *oauth2.Handler HealthHandler() *healthx.Handler + OAuth2AwareMiddleware() func(h http.Handler) http.Handler OAuth2HMACStrategy() *foauth2.HMACSHAStrategy WithOAuth2Provider(f fosite.OAuth2Provider) @@ -109,6 +118,7 @@ func CallRegistry(ctx context.Context, r Registry) { r.SubjectIdentifierAlgorithm(ctx) r.KeyManager() r.KeyCipher() + r.FlowCipher() r.OAuth2Storage() r.OAuth2Provider() r.AudienceStrategy() diff --git a/driver/registry_base.go b/driver/registry_base.go index d2d458427a8..24a58e7b848 100644 --- a/driver/registry_base.go +++ b/driver/registry_base.go @@ -15,12 +15,14 @@ import ( "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/cors" + "go.opentelemetry.io/otel/trace" "github.com/ory/fosite" "github.com/ory/fosite/compose" foauth2 "github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/handler/openid" "github.com/ory/herodot" + "github.com/ory/hydra/v2/aead" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/driver/config" @@ -59,7 +61,8 @@ type RegistryBase struct { ctxer contextx.Contextualizer hh *healthx.Handler migrationStatus *popx.MigrationStatuses - kc *jwk.AEAD + kc *aead.AESGCM + flowc *aead.XChaCha20Poly1305 cos consent.Strategy writer herodot.Writer hsm hsm.Context @@ -69,6 +72,7 @@ type RegistryBase struct { oah *oauth2.Handler sia map[string]consent.SubjectIdentifierAlgorithm trc *otelx.Tracer + tracerWrapper func(*otelx.Tracer) *otelx.Tracer pmm *prometheus.MetricsManager oa2mw func(h http.Handler) http.Handler arhs []oauth2.AccessRequestHook @@ -119,9 +123,9 @@ func (m *RegistryBase) WithBuildInfo(version, hash, date string) Registry { return m.r } -func (m *RegistryBase) OAuth2AwareMiddleware(ctx context.Context) func(h http.Handler) http.Handler { +func (m *RegistryBase) OAuth2AwareMiddleware() func(h http.Handler) http.Handler { if m.oa2mw == nil { - m.oa2mw = oauth2cors.Middleware(ctx, m.r) + m.oa2mw = oauth2cors.Middleware(m.r) } return m.oa2mw } @@ -150,9 +154,9 @@ func (m *RegistryBase) RegisterRoutes(ctx context.Context, admin *httprouterx.Ro admin.Handler("GET", prometheus.MetricsPrometheusPath, promhttp.Handler()) m.ConsentHandler().SetRoutes(admin) - m.KeyHandler().SetRoutes(admin, public, m.OAuth2AwareMiddleware(ctx)) + m.KeyHandler().SetRoutes(admin, public, m.OAuth2AwareMiddleware()) m.ClientHandler().SetRoutes(admin, public) - m.OAuth2Handler().SetRoutes(admin, public, m.OAuth2AwareMiddleware(ctx)) + m.OAuth2Handler().SetRoutes(admin, public, m.OAuth2AwareMiddleware()) m.JWTGrantHandler().SetRoutes(admin) } @@ -187,6 +191,16 @@ func (m *RegistryBase) WithLogger(l *logrusx.Logger) Registry { return m.r } +func (m *RegistryBase) WithTracer(t trace.Tracer) Registry { + m.trc = new(otelx.Tracer).WithOTLP(t) + return m.r +} + +func (m *RegistryBase) WithTracerWrapper(wrapper TracerWrapper) Registry { + m.tracerWrapper = wrapper + return m.r +} + func (m *RegistryBase) Logger() *logrusx.Logger { if m.l == nil { m.l = logrusx.New("Ory Hydra", m.BuildVersion()) @@ -282,13 +296,20 @@ func (m *RegistryBase) ConsentStrategy() consent.Strategy { return m.cos } -func (m *RegistryBase) KeyCipher() *jwk.AEAD { +func (m *RegistryBase) KeyCipher() *aead.AESGCM { if m.kc == nil { - m.kc = jwk.NewAEAD(m.Config()) + m.kc = aead.NewAESGCM(m.Config()) } return m.kc } +func (m *RegistryBase) FlowCipher() *aead.XChaCha20Poly1305 { + if m.flowc == nil { + m.flowc = aead.NewXChaCha20Poly1305(m.Config()) + } + return m.flowc +} + func (m *RegistryBase) CookieStore(ctx context.Context) (sessions.Store, error) { var keys [][]byte secrets, err := m.conf.GetCookieSecrets(ctx) @@ -460,12 +481,17 @@ func (m *RegistryBase) SubjectIdentifierAlgorithm(ctx context.Context) map[strin return m.sia } -func (m *RegistryBase) Tracer(ctx context.Context) *otelx.Tracer { +func (m *RegistryBase) Tracer(_ context.Context) *otelx.Tracer { if m.trc == nil { t, err := otelx.New("Ory Hydra", m.l, m.conf.Tracing()) if err != nil { m.Logger().WithError(err).Error("Unable to initialize Tracer.") } else { + // Wrap the tracer if required + if m.tracerWrapper != nil { + t = m.tracerWrapper(t) + } + m.trc = t } } diff --git a/consent/types.go b/flow/consent_types.go similarity index 98% rename from consent/types.go rename to flow/consent_types.go index 6a389e9d8bb..89e56ef8aa7 100644 --- a/consent/types.go +++ b/flow/consent_types.go @@ -1,7 +1,7 @@ // Copyright © 2022 Ory Corp // SPDX-License-Identifier: Apache-2.0 -package consent +package flow import ( "database/sql" @@ -23,8 +23,8 @@ import ( ) const ( - consentRequestDeniedErrorName = "consent request denied" - loginRequestDeniedErrorName = "login request denied" + ConsentRequestDeniedErrorName = "consent request denied" + LoginRequestDeniedErrorName = "login request denied" ) // OAuth 2.0 Redirect Browser To @@ -49,7 +49,7 @@ type LoginSession struct { Remember bool `db:"remember"` } -func (_ LoginSession) TableName() string { +func (LoginSession) TableName() string { return "hydra_oauth2_authentication_session" } @@ -77,11 +77,12 @@ type RequestDeniedError struct { // to the public but only in the server logs. Debug string `json:"error_debug"` - valid bool + // swagger:ignore + Valid bool `json:"valid"` } func (e *RequestDeniedError) IsError() bool { - return e != nil && e.valid + return e != nil && e.Valid } func (e *RequestDeniedError) SetDefaults(name string) { @@ -94,7 +95,7 @@ func (e *RequestDeniedError) SetDefaults(name string) { } } -func (e *RequestDeniedError) toRFCError() *fosite.RFC6749Error { +func (e *RequestDeniedError) ToRFCError() *fosite.RFC6749Error { if e.Name == "" { e.Name = "request_denied" } @@ -112,7 +113,7 @@ func (e *RequestDeniedError) toRFCError() *fosite.RFC6749Error { } } -func (e *RequestDeniedError) Scan(value interface{}) error { +func (e *RequestDeniedError) Scan(value any) error { v := fmt.Sprintf("%s", value) if len(v) == 0 || v == "{}" { return nil @@ -122,7 +123,7 @@ func (e *RequestDeniedError) Scan(value interface{}) error { return errorsx.WithStack(err) } - e.valid = true + e.Valid = true return nil } @@ -188,6 +189,8 @@ func (r *AcceptOAuth2ConsentRequest) HasError() bool { // List of OAuth 2.0 Consent Sessions // // swagger:model oAuth2ConsentSessions +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type oAuth2ConsentSessions []OAuth2ConsentSession // OAuth 2.0 Consent Session @@ -420,7 +423,7 @@ type LogoutRequest struct { Client *client.Client `json:"client" db:"-"` } -func (_ LogoutRequest) TableName() string { +func (LogoutRequest) TableName() string { return "hydra_oauth2_logout_request" } diff --git a/consent/types_test.go b/flow/consent_types_test.go similarity index 89% rename from consent/types_test.go rename to flow/consent_types_test.go index 6366404d9e9..116b0f328bb 100644 --- a/consent/types_test.go +++ b/flow/consent_types_test.go @@ -1,7 +1,7 @@ // Copyright © 2022 Ory Corp // SPDX-License-Identifier: Apache-2.0 -package consent +package flow import ( "fmt" @@ -21,7 +21,7 @@ func TestToRFCError(t *testing.T) { { input: &RequestDeniedError{ Name: "not empty", - valid: true, + Valid: true, }, expect: &fosite.RFC6749Error{ ErrorField: "not empty", @@ -34,7 +34,7 @@ func TestToRFCError(t *testing.T) { input: &RequestDeniedError{ Name: "", Description: "not empty", - valid: true, + Valid: true, }, expect: &fosite.RFC6749Error{ ErrorField: "request_denied", @@ -44,7 +44,7 @@ func TestToRFCError(t *testing.T) { }, }, { - input: &RequestDeniedError{valid: true}, + input: &RequestDeniedError{Valid: true}, expect: &fosite.RFC6749Error{ ErrorField: "request_denied", DescriptionField: "", @@ -55,7 +55,7 @@ func TestToRFCError(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - require.EqualValues(t, tc.input.toRFCError(), tc.expect) + require.EqualValues(t, tc.input.ToRFCError(), tc.expect) }) } } diff --git a/flow/flow.go b/flow/flow.go index bbf2e36fec9..0868e7f5f14 100644 --- a/flow/flow.go +++ b/flow/flow.go @@ -4,15 +4,16 @@ package flow import ( + "context" "time" + "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" "github.com/pkg/errors" - "github.com/gobuffalo/pop/v6" - + "github.com/ory/hydra/v2/aead" "github.com/ory/hydra/v2/client" - "github.com/ory/hydra/v2/consent" + "github.com/ory/hydra/v2/oauth2/flowctx" "github.com/ory/hydra/v2/x" "github.com/ory/x/sqlcon" "github.com/ory/x/sqlxx" @@ -113,7 +114,7 @@ type Flow struct { // OpenIDConnectContext provides context for the (potential) OpenID Connect context. Implementation of these // values in your app are optional but can be useful if you want to be fully compliant with the OpenID Connect spec. - OpenIDConnectContext *consent.OAuth2ConsentRequestOpenIDConnectContext `db:"oidc_context"` + OpenIDConnectContext *OAuth2ConsentRequestOpenIDConnectContext `db:"oidc_context"` // Client is the OAuth 2.0 Client that initiated the request. // @@ -195,8 +196,8 @@ type Flow struct { // recommend redirecting the user to `request_url` to re-initiate the flow. LoginWasUsed bool `db:"login_was_used"` - LoginError *consent.RequestDeniedError `db:"login_error"` - LoginAuthenticatedAt sqlxx.NullTime `db:"login_authenticated_at"` + LoginError *RequestDeniedError `db:"login_error"` + LoginAuthenticatedAt sqlxx.NullTime `db:"login_authenticated_at"` // ConsentChallengeID is the identifier ("authorization challenge") of the consent authorization request. It is used to // identify the session. @@ -231,13 +232,13 @@ type Flow struct { // ConsentWasHandled set to true means that the request was already handled. // This can happen on form double-submit or other errors. If this is set we // recommend redirecting the user to `request_url` to re-initiate the flow. - ConsentWasHandled bool `db:"consent_was_used"` - ConsentError *consent.RequestDeniedError `db:"consent_error"` - SessionIDToken sqlxx.MapStringInterface `db:"session_id_token" faker:"-"` - SessionAccessToken sqlxx.MapStringInterface `db:"session_access_token" faker:"-"` + ConsentWasHandled bool `db:"consent_was_used"` + ConsentError *RequestDeniedError `db:"consent_error"` + SessionIDToken sqlxx.MapStringInterface `db:"session_id_token" faker:"-"` + SessionAccessToken sqlxx.MapStringInterface `db:"session_access_token" faker:"-"` } -func NewFlow(r *consent.LoginRequest) *Flow { +func NewFlow(r *LoginRequest) *Flow { return &Flow{ ID: r.ID, RequestedScope: r.RequestedScope, @@ -259,7 +260,7 @@ func NewFlow(r *consent.LoginRequest) *Flow { } } -func (f *Flow) HandleLoginRequest(h *consent.HandledLoginRequest) error { +func (f *Flow) HandleLoginRequest(h *HandledLoginRequest) error { if f.LoginWasUsed { return errors.WithStack(x.ErrConflict.WithHint("The login request was already used and can no longer be changed.")) } @@ -301,8 +302,8 @@ func (f *Flow) HandleLoginRequest(h *consent.HandledLoginRequest) error { return nil } -func (f *Flow) GetHandledLoginRequest() consent.HandledLoginRequest { - return consent.HandledLoginRequest{ +func (f *Flow) GetHandledLoginRequest() HandledLoginRequest { + return HandledLoginRequest{ ID: f.ID, Remember: f.LoginRemember, RememberFor: f.LoginRememberFor, @@ -320,8 +321,8 @@ func (f *Flow) GetHandledLoginRequest() consent.HandledLoginRequest { } } -func (f *Flow) GetLoginRequest() *consent.LoginRequest { - return &consent.LoginRequest{ +func (f *Flow) GetLoginRequest() *LoginRequest { + return &LoginRequest{ ID: f.ID, RequestedScope: f.RequestedScope, RequestedAudience: f.RequestedAudience, @@ -355,7 +356,7 @@ func (f *Flow) InvalidateLoginRequest() error { return nil } -func (f *Flow) HandleConsentRequest(r *consent.AcceptOAuth2ConsentRequest) error { +func (f *Flow) HandleConsentRequest(r *AcceptOAuth2ConsentRequest) error { if time.Time(r.HandledAt).IsZero() { return errors.New("refusing to handle a consent request with null HandledAt") } @@ -408,8 +409,8 @@ func (f *Flow) InvalidateConsentRequest() error { return nil } -func (f *Flow) GetConsentRequest() *consent.OAuth2ConsentRequest { - return &consent.OAuth2ConsentRequest{ +func (f *Flow) GetConsentRequest() *OAuth2ConsentRequest { + cs := OAuth2ConsentRequest{ ID: f.ConsentChallengeID.String(), RequestedScope: f.RequestedScope, RequestedAudience: f.RequestedAudience, @@ -431,18 +432,22 @@ func (f *Flow) GetConsentRequest() *consent.OAuth2ConsentRequest { AuthenticatedAt: f.LoginAuthenticatedAt, RequestedAt: f.RequestedAt, } + if cs.AMR == nil { + cs.AMR = []string{} + } + return &cs } -func (f *Flow) GetHandledConsentRequest() *consent.AcceptOAuth2ConsentRequest { +func (f *Flow) GetHandledConsentRequest() *AcceptOAuth2ConsentRequest { crf := 0 if f.ConsentRememberFor != nil { crf = *f.ConsentRememberFor } - return &consent.AcceptOAuth2ConsentRequest{ + return &AcceptOAuth2ConsentRequest{ ID: f.ConsentChallengeID.String(), GrantedScope: f.GrantedScope, GrantedAudience: f.GrantedAudience, - Session: &consent.AcceptOAuth2ConsentRequestSession{AccessToken: f.SessionAccessToken, IDToken: f.SessionIDToken}, + Session: &AcceptOAuth2ConsentRequestSession{AccessToken: f.SessionAccessToken, IDToken: f.SessionIDToken}, Remember: f.ConsentRemember, RememberFor: crf, HandledAt: f.ConsentHandledAt, @@ -456,7 +461,7 @@ func (f *Flow) GetHandledConsentRequest() *consent.AcceptOAuth2ConsentRequest { } } -func (_ Flow) TableName() string { +func (Flow) TableName() string { return "hydra_oauth2_flow" } @@ -470,15 +475,15 @@ func (f *Flow) BeforeSave(_ *pop.Connection) error { return nil } -// TODO Populate the client field in FindInDB and FindByConsentChallengeID in -// order to avoid accessing the database twice. func (f *Flow) AfterFind(c *pop.Connection) error { + // TODO Populate the client field in FindInDB and FindByConsentChallengeID in + // order to avoid accessing the database twice. f.AfterSave(c) f.Client = &client.Client{} return sqlcon.HandleError(c.Where("id = ? AND nid = ?", f.ClientID, f.NID).First(f.Client)) } -func (f *Flow) AfterSave(c *pop.Connection) { +func (f *Flow) AfterSave(_ *pop.Connection) { if f.SessionAccessToken == nil { f.SessionAccessToken = make(map[string]interface{}) } @@ -486,3 +491,27 @@ func (f *Flow) AfterSave(c *pop.Connection) { f.SessionIDToken = make(map[string]interface{}) } } + +type CipherProvider interface { + FlowCipher() *aead.XChaCha20Poly1305 +} + +// ToLoginChallenge converts the flow into a login challenge. +func (f *Flow) ToLoginChallenge(ctx context.Context, cipherProvider CipherProvider) (string, error) { + return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsLoginChallenge) +} + +// ToLoginVerifier converts the flow into a login verifier. +func (f *Flow) ToLoginVerifier(ctx context.Context, cipherProvider CipherProvider) (string, error) { + return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsLoginVerifier) +} + +// ToConsentChallenge converts the flow into a consent challenge. +func (f *Flow) ToConsentChallenge(ctx context.Context, cipherProvider CipherProvider) (string, error) { + return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsConsentChallenge) +} + +// ToConsentVerifier converts the flow into a consent verifier. +func (f *Flow) ToConsentVerifier(ctx context.Context, cipherProvider CipherProvider) (string, error) { + return flowctx.Encode(ctx, cipherProvider.FlowCipher(), f, flowctx.AsConsentVerifier) +} diff --git a/flow/flow_test.go b/flow/flow_test.go index c00e7524b2e..7876e9a1a63 100644 --- a/flow/flow_test.go +++ b/flow/flow_test.go @@ -7,17 +7,16 @@ import ( "testing" "time" - "github.com/instana/testify/require" "github.com/mohae/deepcopy" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/bxcodec/faker/v3" - "github.com/ory/hydra/v2/consent" "github.com/ory/x/sqlxx" ) -func (f *Flow) setLoginRequest(r *consent.LoginRequest) { +func (f *Flow) setLoginRequest(r *LoginRequest) { f.ID = r.ID f.RequestedScope = r.RequestedScope f.RequestedAudience = r.RequestedAudience @@ -36,7 +35,7 @@ func (f *Flow) setLoginRequest(r *consent.LoginRequest) { f.RequestedAt = r.RequestedAt } -func (f *Flow) setHandledLoginRequest(r *consent.HandledLoginRequest) { +func (f *Flow) setHandledLoginRequest(r *HandledLoginRequest) { f.ID = r.ID f.LoginRemember = r.Remember f.LoginRememberFor = r.RememberFor @@ -52,7 +51,7 @@ func (f *Flow) setHandledLoginRequest(r *consent.HandledLoginRequest) { f.LoginAuthenticatedAt = r.AuthenticatedAt } -func (f *Flow) setConsentRequest(r consent.OAuth2ConsentRequest) { +func (f *Flow) setConsentRequest(r OAuth2ConsentRequest) { f.ConsentChallengeID = sqlxx.NullString(r.ID) f.RequestedScope = r.RequestedScope f.RequestedAudience = r.RequestedAudience @@ -75,7 +74,7 @@ func (f *Flow) setConsentRequest(r consent.OAuth2ConsentRequest) { f.RequestedAt = r.RequestedAt } -func (f *Flow) setHandledConsentRequest(r consent.AcceptOAuth2ConsentRequest) { +func (f *Flow) setHandledConsentRequest(r AcceptOAuth2ConsentRequest) { f.ConsentChallengeID = sqlxx.NullString(r.ID) f.GrantedScope = r.GrantedScope f.GrantedAudience = r.GrantedAudience @@ -93,7 +92,7 @@ func (f *Flow) setHandledConsentRequest(r consent.AcceptOAuth2ConsentRequest) { func TestFlow_GetLoginRequest(t *testing.T) { t.Run("GetLoginRequest should set all fields on its return value", func(t *testing.T) { f := Flow{} - expected := consent.LoginRequest{} + expected := LoginRequest{} assert.NoError(t, faker.FakeData(&expected)) f.setLoginRequest(&expected) actual := f.GetLoginRequest() @@ -104,7 +103,7 @@ func TestFlow_GetLoginRequest(t *testing.T) { func TestFlow_GetHandledLoginRequest(t *testing.T) { t.Run("GetHandledLoginRequest should set all fields on its return value", func(t *testing.T) { f := Flow{} - expected := consent.HandledLoginRequest{} + expected := HandledLoginRequest{} assert.NoError(t, faker.FakeData(&expected)) f.setHandledLoginRequest(&expected) actual := f.GetHandledLoginRequest() @@ -117,7 +116,7 @@ func TestFlow_GetHandledLoginRequest(t *testing.T) { func TestFlow_NewFlow(t *testing.T) { t.Run("NewFlow and GetLoginRequest should use all LoginRequest fields", func(t *testing.T) { - expected := &consent.LoginRequest{} + expected := &LoginRequest{} assert.NoError(t, faker.FakeData(expected)) actual := NewFlow(expected).GetLoginRequest() assert.Equal(t, expected, actual) @@ -132,7 +131,7 @@ func TestFlow_HandleLoginRequest(t *testing.T) { assert.NoError(t, faker.FakeData(&f)) f.State = FlowStateLoginInitialized - r := consent.HandledLoginRequest{} + r := HandledLoginRequest{} assert.NoError(t, faker.FakeData(&r)) r.ID = f.ID r.Subject = f.Subject @@ -152,12 +151,12 @@ func TestFlow_HandleLoginRequest(t *testing.T) { func TestFlow_InvalidateLoginRequest(t *testing.T) { t.Run("InvalidateLoginRequest should transition the flow into FlowStateLoginUsed", func(t *testing.T) { - f := NewFlow(&consent.LoginRequest{ + f := NewFlow(&LoginRequest{ ID: "t3-id", Subject: "t3-sub", WasHandled: false, }) - assert.NoError(t, f.HandleLoginRequest(&consent.HandledLoginRequest{ + assert.NoError(t, f.HandleLoginRequest(&HandledLoginRequest{ ID: "t3-id", Subject: "t3-sub", WasHandled: false, @@ -167,12 +166,12 @@ func TestFlow_InvalidateLoginRequest(t *testing.T) { assert.Equal(t, true, f.LoginWasUsed) }) t.Run("InvalidateLoginRequest should fail when flow.LoginWasUsed is true", func(t *testing.T) { - f := NewFlow(&consent.LoginRequest{ + f := NewFlow(&LoginRequest{ ID: "t3-id", Subject: "t3-sub", WasHandled: false, }) - assert.NoError(t, f.HandleLoginRequest(&consent.HandledLoginRequest{ + assert.NoError(t, f.HandleLoginRequest(&HandledLoginRequest{ ID: "t3-id", Subject: "t3-sub", WasHandled: true, @@ -186,7 +185,7 @@ func TestFlow_InvalidateLoginRequest(t *testing.T) { func TestFlow_GetConsentRequest(t *testing.T) { t.Run("GetConsentRequest should set all fields on its return value", func(t *testing.T) { f := Flow{} - expected := consent.OAuth2ConsentRequest{} + expected := OAuth2ConsentRequest{} assert.NoError(t, faker.FakeData(&expected)) f.setConsentRequest(expected) actual := f.GetConsentRequest() @@ -198,13 +197,13 @@ func TestFlow_HandleConsentRequest(t *testing.T) { f := Flow{} require.NoError(t, faker.FakeData(&f)) - expected := consent.AcceptOAuth2ConsentRequest{} + expected := AcceptOAuth2ConsentRequest{} require.NoError(t, faker.FakeData(&expected)) expected.ID = string(f.ConsentChallengeID) expected.HandledAt = sqlxx.NullTime(time.Now()) expected.RequestedAt = f.RequestedAt - expected.Session = &consent.AcceptOAuth2ConsentRequestSession{ + expected.Session = &AcceptOAuth2ConsentRequestSession{ IDToken: sqlxx.MapStringInterface{"claim1": "value1", "claim2": "value2"}, AccessToken: sqlxx.MapStringInterface{"claim3": "value3", "claim4": "value4"}, } @@ -215,7 +214,7 @@ func TestFlow_HandleConsentRequest(t *testing.T) { f.ConsentWasHandled = false fGood := deepcopy.Copy(f).(Flow) - eGood := deepcopy.Copy(expected).(consent.AcceptOAuth2ConsentRequest) + eGood := deepcopy.Copy(expected).(AcceptOAuth2ConsentRequest) require.NoError(t, f.HandleConsentRequest(&expected)) t.Run("HandleConsentRequest should fail when already handled", func(t *testing.T) { @@ -232,7 +231,7 @@ func TestFlow_HandleConsentRequest(t *testing.T) { t.Run("HandleConsentRequest should fail when HandledAt in its argument is zero", func(t *testing.T) { f := deepcopy.Copy(fGood).(Flow) - eBad := deepcopy.Copy(eGood).(consent.AcceptOAuth2ConsentRequest) + eBad := deepcopy.Copy(eGood).(AcceptOAuth2ConsentRequest) eBad.HandledAt = sqlxx.NullTime(time.Time{}) require.Error(t, f.HandleConsentRequest(&eBad)) }) @@ -249,11 +248,11 @@ func TestFlow_HandleConsentRequest(t *testing.T) { func TestFlow_GetHandledConsentRequest(t *testing.T) { t.Run("GetHandledConsentRequest should set all fields on its return value", func(t *testing.T) { f := Flow{} - expected := consent.AcceptOAuth2ConsentRequest{} + expected := AcceptOAuth2ConsentRequest{} assert.NoError(t, faker.FakeData(&expected)) expected.ConsentRequest = nil - expected.Session = &consent.AcceptOAuth2ConsentRequestSession{ + expected.Session = &AcceptOAuth2ConsentRequestSession{ IDToken: sqlxx.MapStringInterface{"claim1": "value1", "claim2": "value2"}, AccessToken: sqlxx.MapStringInterface{"claim3": "value3", "claim4": "value4"}, } diff --git a/fositex/token_strategy_test.go b/fositex/token_strategy_test.go index 894b2470317..e308de58ef4 100644 --- a/fositex/token_strategy_test.go +++ b/fositex/token_strategy_test.go @@ -4,6 +4,7 @@ package fositex import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -14,6 +15,8 @@ import ( // Test that the generic signature function implements the same signature as the // HMAC and JWT strategies. func TestAccessTokenSignature(t *testing.T) { + ctx := context.Background() + t.Run("strategy=DefaultJWTStrategy", func(t *testing.T) { strategy := new(oauth2.DefaultJWTStrategy) for _, tc := range []struct{ token string }{ @@ -25,7 +28,7 @@ func TestAccessTokenSignature(t *testing.T) { } { t.Run("case="+tc.token, func(t *testing.T) { assert.Equal(t, - strategy.AccessTokenSignature(nil, tc.token), + strategy.AccessTokenSignature(ctx, tc.token), genericSignature(tc.token)) }) } @@ -41,7 +44,7 @@ func TestAccessTokenSignature(t *testing.T) { } { t.Run("case="+tc.token, func(t *testing.T) { assert.Equal(t, - strategy.AccessTokenSignature(nil, tc.token), + strategy.AccessTokenSignature(ctx, tc.token), genericSignature(tc.token)) }) } diff --git a/go.mod b/go.mod index eea9903722e..d39f63a04a7 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ory/hydra/v2 -go 1.19 +go 1.20 replace ( github.com/bradleyjkemp/cupaloy/v2 => github.com/aeneasr/cupaloy/v2 v2.6.1-0.20210924214125-3dfdd01210a3 @@ -28,9 +28,7 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/securecookie v1.1.1 github.com/gorilla/sessions v1.2.1 - github.com/gtank/cryptopasta v0.0.0-20170601214702-1f550f6f2f69 github.com/hashicorp/go-retryablehttp v0.7.2 - github.com/instana/testify v1.6.2-0.20200721153833-94b1851f4d65 github.com/jackc/pgx/v4 v4.17.2 github.com/julienschmidt/httprouter v1.3.0 github.com/luna-duclos/instrumentedsql v1.1.3 @@ -61,10 +59,16 @@ require ( github.com/toqueteos/webbrowser v1.2.0 github.com/twmb/murmur3 v1.1.6 github.com/urfave/negroni v1.0.0 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 go.opentelemetry.io/otel v1.11.1 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.9.0 + go.opentelemetry.io/otel/sdk v1.11.1 + go.opentelemetry.io/otel/trace v1.11.1 go.uber.org/automaxprocs v1.3.0 + golang.org/x/crypto v0.9.0 + golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 golang.org/x/oauth2 v0.5.0 - golang.org/x/tools v0.7.0 + golang.org/x/tools v0.9.1 gopkg.in/square/go-jose.v2 v2.6.0 ) @@ -215,25 +219,20 @@ require ( github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c // indirect go.mongodb.org/mongo-driver v1.10.3 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.36.4 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 // indirect go.opentelemetry.io/contrib/propagators/b3 v1.11.1 // indirect go.opentelemetry.io/contrib/propagators/jaeger v1.11.1 // indirect go.opentelemetry.io/contrib/samplers/jaegerremote v0.5.2 // indirect go.opentelemetry.io/otel/exporters/jaeger v1.11.1 // indirect go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.11.1 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.9.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.9.0 // indirect go.opentelemetry.io/otel/exporters/zipkin v1.11.1 // indirect go.opentelemetry.io/otel/metric v0.33.0 // indirect - go.opentelemetry.io/otel/sdk v1.11.1 // indirect - go.opentelemetry.io/otel/trace v1.11.1 // indirect go.opentelemetry.io/proto/otlp v0.18.0 // indirect - golang.org/x/crypto v0.1.0 // indirect golang.org/x/mod v0.10.0 // indirect - golang.org/x/net v0.8.0 // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/sys v0.7.0 // indirect - golang.org/x/text v0.8.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sync v0.2.0 // indirect + golang.org/x/sys v0.8.0 // indirect + golang.org/x/text v0.9.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20230403163135-c38d8f061ccd // indirect diff --git a/go.sum b/go.sum index 4e547b545ec..c893e4d9a09 100644 --- a/go.sum +++ b/go.sum @@ -452,8 +452,6 @@ github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFb github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0 h1:kr3j8iIMR4ywO/O0rvksXaJvauGGCMg2zAZIiNZ9uIQ= github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0/go.mod h1:ummNFgdgLhhX7aIiy35vVmQNS0rWXknfPE0qe6fmFXg= -github.com/gtank/cryptopasta v0.0.0-20170601214702-1f550f6f2f69 h1:7xsUJsB2NrdcttQPa7JLEaGzvdbk7KvfrjgHZXOQRo0= -github.com/gtank/cryptopasta v0.0.0-20170601214702-1f550f6f2f69/go.mod h1:YLEMZOtU+AZ7dhN9T/IpGhXVGly2bvkJQ+zxj3WeVQo= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= @@ -480,8 +478,6 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf h1:FtEj8sfIcaaBfAKrE1Cwb61YDtYq9JxChK1c7AKce7s= github.com/inhies/go-bytesize v0.0.0-20220417184213-4913239db9cf/go.mod h1:yrqSXGoD/4EKfF26AOGzscPOgTTJcyAwM2rpixWT+t4= -github.com/instana/testify v1.6.2-0.20200721153833-94b1851f4d65 h1:T25FL3WEzgmKB0m6XCJNZ65nw09/QIp3T1yXr487D+A= -github.com/instana/testify v1.6.2-0.20200721153833-94b1851f4d65/go.mod h1:nYhEREG/B7HUY7P+LKOrqy53TpIqmJ9JyUShcaEKtGw= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -1018,8 +1014,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220517005047-85d78b3ac167/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= -golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -1030,6 +1026,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= +golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -1105,8 +1103,8 @@ golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.0.0-20221002022538-bcab6841153b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1136,8 +1134,9 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= +golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180816055513-1c9583448a9c/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1219,8 +1218,8 @@ golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1234,8 +1233,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1305,8 +1304,8 @@ golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= -golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= +golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= +golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/health/doc.go b/health/doc.go index bad9c42139c..a0b2f45cbe8 100644 --- a/health/doc.go +++ b/health/doc.go @@ -24,6 +24,8 @@ package health // Responses: // 200: healthStatus // 500: errorOAuth2 +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions func swaggerPublicIsInstanceAlive() {} // Alive returns an ok status if the instance is ready to handle HTTP requests. @@ -47,6 +49,8 @@ func swaggerPublicIsInstanceAlive() {} // Responses: // 200: healthStatus // 500: errorOAuth2 +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions func swaggerAdminIsInstanceAlive() {} // Ready returns an ok status if the instance is ready to handle HTTP requests and all ReadyCheckers are ok. @@ -70,6 +74,8 @@ func swaggerAdminIsInstanceAlive() {} // Responses: // 200: healthStatus // 503: healthNotReadyStatus +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions func swaggerAdminIsInstanceReady() {} // Ready returns an ok status if the instance is ready to handle HTTP requests and all ReadyCheckers are ok. @@ -93,6 +99,8 @@ func swaggerAdminIsInstanceReady() {} // Responses: // 200: healthStatus // 503: healthNotReadyStatus +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions func swaggerPublicIsInstanceReady() {} // Version returns this service's versions. @@ -111,4 +119,6 @@ func swaggerPublicIsInstanceReady() {} // // Responses: // 200: version +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions func swaggerGetVersion() {} diff --git a/internal/driver.go b/internal/driver.go index bd8de6ae547..a462f29ace5 100644 --- a/internal/driver.go +++ b/internal/driver.go @@ -5,10 +5,11 @@ package internal import ( "context" - "sync" "testing" + "gopkg.in/square/go-jose.v2" + "github.com/ory/x/configx" "github.com/stretchr/testify/require" @@ -45,19 +46,19 @@ func NewConfigurationWithDefaultsAndHTTPS() *config.DefaultProvider { return p } -func NewRegistryMemory(t *testing.T, c *config.DefaultProvider, ctxer contextx.Contextualizer) driver.Registry { +func NewRegistryMemory(t testing.TB, c *config.DefaultProvider, ctxer contextx.Contextualizer) driver.Registry { return newRegistryDefault(t, "memory", c, true, ctxer) } -func NewMockedRegistry(t *testing.T, ctxer contextx.Contextualizer) driver.Registry { +func NewMockedRegistry(t testing.TB, ctxer contextx.Contextualizer) driver.Registry { return newRegistryDefault(t, "memory", NewConfigurationWithDefaults(), true, ctxer) } -func NewRegistrySQLFromURL(t *testing.T, url string, migrate bool, ctxer contextx.Contextualizer) driver.Registry { +func NewRegistrySQLFromURL(t testing.TB, url string, migrate bool, ctxer contextx.Contextualizer) driver.Registry { return newRegistryDefault(t, url, NewConfigurationWithDefaults(), migrate, ctxer) } -func newRegistryDefault(t *testing.T, url string, c *config.DefaultProvider, migrate bool, ctxer contextx.Contextualizer) driver.Registry { +func newRegistryDefault(t testing.TB, url string, c *config.DefaultProvider, migrate bool, ctxer contextx.Contextualizer) driver.Registry { ctx := context.Background() c.MustSet(ctx, config.KeyLogLevel, "trace") c.MustSet(ctx, config.KeyDSN, url) @@ -77,15 +78,15 @@ func CleanAndMigrate(reg driver.Registry) func(*testing.T) { } } -func ConnectToMySQL(t *testing.T) string { - return dockertest.RunTestMySQLWithVersion(t, "11.8") +func ConnectToMySQL(t testing.TB) string { + return dockertest.RunTestMySQLWithVersion(t, "8.0.26") } -func ConnectToPG(t *testing.T) string { +func ConnectToPG(t testing.TB) string { return dockertest.RunTestPostgreSQLWithVersion(t, "11.8") } -func ConnectToCRDB(t *testing.T) string { +func ConnectToCRDB(t testing.TB) string { return dockertest.RunTestCockroachDBWithVersion(t, "v22.1.2") } @@ -134,8 +135,8 @@ func ConnectDatabases(t *testing.T, migrate bool, ctxer contextx.Contextualizer) return } -func MustEnsureRegistryKeys(r driver.Registry, key string) { - if err := jwk.EnsureAsymmetricKeypairExists(context.Background(), r, "RS256", key); err != nil { +func MustEnsureRegistryKeys(ctx context.Context, r driver.Registry, key string) { + if err := jwk.EnsureAsymmetricKeypairExists(ctx, r, string(jose.ES256), key); err != nil { panic(err) } } diff --git a/internal/testhelpers/janitor_test_helper.go b/internal/testhelpers/janitor_test_helper.go index 9954d013667..334f490364f 100644 --- a/internal/testhelpers/janitor_test_helper.go +++ b/internal/testhelpers/janitor_test_helper.go @@ -20,6 +20,7 @@ import ( "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" + "github.com/ory/hydra/v2/flow" "github.com/ory/hydra/v2/internal" "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/oauth2/trust" @@ -32,8 +33,8 @@ import ( type JanitorConsentTestHelper struct { uniqueName string - flushLoginRequests []*consent.LoginRequest - flushConsentRequests []*consent.OAuth2ConsentRequest + flushLoginRequests []*flow.LoginRequest + flushConsentRequests []*flow.OAuth2ConsentRequest flushAccessRequests []*fosite.Request flushRefreshRequests []*fosite.AccessRequest flushGrants []*createGrantRequest @@ -69,7 +70,7 @@ func NewConsentJanitorTestHelper(uniqueName string) *JanitorConsentTestHelper { } } -func (j *JanitorConsentTestHelper) GetDSN(ctx context.Context) string { +func (j *JanitorConsentTestHelper) GetDSN() string { return j.conf.DSN() } @@ -149,7 +150,7 @@ func (j *JanitorConsentTestHelper) RefreshTokenNotAfterValidate(ctx context.Cont } } -func (j *JanitorConsentTestHelper) GrantNotAfterSetup(ctx context.Context, cl client.Manager, gr trust.GrantManager) func(t *testing.T) { +func (j *JanitorConsentTestHelper) GrantNotAfterSetup(ctx context.Context, gr trust.GrantManager) func(t *testing.T) { return func(t *testing.T) { for _, fg := range j.flushGrants { require.NoError(t, gr.CreateGrant(ctx, fg.grant, fg.pk)) @@ -180,21 +181,29 @@ func (j *JanitorConsentTestHelper) GrantNotAfterValidate(ctx context.Context, no } } -func (j *JanitorConsentTestHelper) LoginRejectionSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) { - return func(t *testing.T) { - var err error +func (j *JanitorConsentTestHelper) LoginRejectionSetup(ctx context.Context, reg interface { + consent.ManagerProvider + client.ManagerProvider + flow.CipherProvider +}) func(t *testing.T) { + cm := reg.ConsentManager() + cl := reg.ClientManager() + return func(t *testing.T) { // Create login requests for _, r := range j.flushLoginRequests { require.NoError(t, cl.CreateClient(ctx, r.Client)) - require.NoError(t, cm.CreateLoginRequest(ctx, r)) - } + f, err := cm.CreateLoginRequest(ctx, r) + require.NoError(t, err) - // Explicit rejection - for _, r := range j.flushLoginRequests { + f.RequestedAt = time.Now() // we won't handle expired flows + f.LoginAuthenticatedAt = r.AuthenticatedAt + challenge := x.Must(f.ToLoginChallenge(ctx, reg)) + + // Explicit rejection if r.ID == j.flushLoginRequests[0].ID { // accept this one - _, err = cm.HandleLoginRequest(ctx, r.ID, consent.NewHandledLoginRequest( + _, err = cm.HandleLoginRequest(ctx, f, challenge, consent.NewHandledLoginRequest( r.ID, false, r.RequestedAt, r.AuthenticatedAt)) require.NoError(t, err) @@ -202,7 +211,7 @@ func (j *JanitorConsentTestHelper) LoginRejectionSetup(ctx context.Context, cm c } // reject flush-login-2 and 3 - _, err = cm.HandleLoginRequest(ctx, r.ID, consent.NewHandledLoginRequest( + _, err = cm.HandleLoginRequest(ctx, f, challenge, consent.NewHandledLoginRequest( r.ID, true, r.RequestedAt, r.AuthenticatedAt)) require.NoError(t, err) } @@ -215,28 +224,38 @@ func (j *JanitorConsentTestHelper) LoginRejectionValidate(ctx context.Context, c for _, r := range j.flushLoginRequests { t.Logf("check login: %s", r.ID) _, err := cm.GetLoginRequest(ctx, r.ID) - if r.ID == j.flushLoginRequests[0].ID { - require.NoError(t, err) - } else { - require.Error(t, err) - } + // Login requests should never be persisted. + require.Error(t, err) } } } -func (j *JanitorConsentTestHelper) LimitSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) { +func (j *JanitorConsentTestHelper) LimitSetup(ctx context.Context, reg interface { + consent.ManagerProvider + client.ManagerProvider + flow.CipherProvider +}) func(t *testing.T) { + cl := reg.ClientManager() + cm := reg.ConsentManager() + return func(t *testing.T) { - var err error + var ( + err error + f *flow.Flow + ) // Create login requests for _, r := range j.flushLoginRequests { require.NoError(t, cl.CreateClient(ctx, r.Client)) - require.NoError(t, cm.CreateLoginRequest(ctx, r)) - } + f, err = cm.CreateLoginRequest(ctx, r) + require.NoError(t, err) - // Reject each request - for _, r := range j.flushLoginRequests { - _, err = cm.HandleLoginRequest(ctx, r.ID, consent.NewHandledLoginRequest( + // Reject each request + f.RequestedAt = time.Now() // we won't handle expired flows + f.LoginAuthenticatedAt = r.AuthenticatedAt + challenge := x.Must(f.ToLoginChallenge(ctx, reg)) + + _, err = cm.HandleLoginRequest(ctx, f, challenge, consent.NewHandledLoginRequest( r.ID, true, r.RequestedAt, r.AuthenticatedAt)) require.NoError(t, err) } @@ -249,41 +268,50 @@ func (j *JanitorConsentTestHelper) LimitValidate(ctx context.Context, cm consent for _, r := range j.flushLoginRequests { t.Logf("check login: %s", r.ID) _, err := cm.GetLoginRequest(ctx, r.ID) - if r.ID == j.flushLoginRequests[0].ID { - require.NoError(t, err) - } else { - require.Error(t, err) - } + // No Requests should have been persisted. + require.Error(t, err) } } } -func (j *JanitorConsentTestHelper) ConsentRejectionSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) { +func (j *JanitorConsentTestHelper) ConsentRejectionSetup(ctx context.Context, reg interface { + consent.ManagerProvider + client.ManagerProvider + flow.CipherProvider +}) func(t *testing.T) { + cl := reg.ClientManager() + cm := reg.ConsentManager() + return func(t *testing.T) { - var err error + var ( + err error + f *flow.Flow + ) // Create login requests - for _, r := range j.flushLoginRequests { - require.NoError(t, cl.CreateClient(ctx, r.Client)) - require.NoError(t, cm.CreateLoginRequest(ctx, r)) - } + for i, loginRequest := range j.flushLoginRequests { + require.NoError(t, cl.CreateClient(ctx, loginRequest.Client)) + f, err = cm.CreateLoginRequest(ctx, loginRequest) + require.NoError(t, err) - // Create consent requests - for _, r := range j.flushConsentRequests { - require.NoError(t, cm.CreateConsentRequest(ctx, r)) - } + // Create consent requests + consentRequest := j.flushConsentRequests[i] + err = cm.CreateConsentRequest(ctx, f, consentRequest) + require.NoError(t, err) - //Reject the consents - for _, r := range j.flushConsentRequests { - if r.ID == j.flushConsentRequests[0].ID { + f.RequestedAt = time.Now() // we won't handle expired flows + f.LoginAuthenticatedAt = consentRequest.AuthenticatedAt + + // Reject the consents + if consentRequest.ID == j.flushConsentRequests[0].ID { // accept this one - _, err = cm.HandleConsentRequest(ctx, consent.NewHandledConsentRequest( - r.ID, false, r.RequestedAt, r.AuthenticatedAt)) + _, err = cm.HandleConsentRequest(ctx, f, consent.NewHandledConsentRequest( + consentRequest.ID, false, consentRequest.RequestedAt, consentRequest.AuthenticatedAt)) require.NoError(t, err) continue } - _, err = cm.HandleConsentRequest(ctx, consent.NewHandledConsentRequest( - r.ID, true, r.RequestedAt, r.AuthenticatedAt)) + _, err = cm.HandleConsentRequest(ctx, f, consent.NewHandledConsentRequest( + consentRequest.ID, true, consentRequest.RequestedAt, consentRequest.AuthenticatedAt)) require.NoError(t, err) } } @@ -295,32 +323,43 @@ func (j *JanitorConsentTestHelper) ConsentRejectionValidate(ctx context.Context, for _, r := range j.flushConsentRequests { t.Logf("check consent: %s", r.ID) _, err = cm.GetConsentRequest(ctx, r.ID) - if r.ID == j.flushConsentRequests[0].ID { - require.NoError(t, err) - } else { - require.Error(t, err) - } + // Consent requests should never be persisted. + require.Error(t, err) } } } -func (j *JanitorConsentTestHelper) LoginTimeoutSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) { +func (j *JanitorConsentTestHelper) LoginTimeoutSetup(ctx context.Context, reg interface { + consent.ManagerProvider + client.ManagerProvider + flow.CipherProvider +}) func(t *testing.T) { + cl := reg.ClientManager() + cm := reg.ConsentManager() + return func(t *testing.T) { - var err error + var ( + err error + f *flow.Flow + ) // Create login requests - for _, r := range j.flushLoginRequests { - require.NoError(t, cl.CreateClient(ctx, r.Client)) - require.NoError(t, cm.CreateLoginRequest(ctx, r)) - } + for i, loginRequest := range j.flushLoginRequests { + require.NoError(t, cl.CreateClient(ctx, loginRequest.Client)) + f, err = cm.CreateLoginRequest(ctx, loginRequest) + require.NoError(t, err) - // Creating at least 1 that has not timed out - _, err = cm.HandleLoginRequest(ctx, j.flushLoginRequests[0].ID, &consent.HandledLoginRequest{ - ID: j.flushLoginRequests[0].ID, - RequestedAt: j.flushLoginRequests[0].RequestedAt, - AuthenticatedAt: j.flushLoginRequests[0].AuthenticatedAt, - WasHandled: true, - }) + if i == 0 { + // Creating at least 1 that has not timed out + challenge := x.Must(f.ToLoginChallenge(ctx, reg)) + _, err = cm.HandleLoginRequest(ctx, f, challenge, &flow.HandledLoginRequest{ + ID: loginRequest.ID, + RequestedAt: loginRequest.RequestedAt, + AuthenticatedAt: loginRequest.AuthenticatedAt, + WasHandled: true, + }) + } + } require.NoError(t, err) } @@ -328,51 +367,56 @@ func (j *JanitorConsentTestHelper) LoginTimeoutSetup(ctx context.Context, cm con func (j *JanitorConsentTestHelper) LoginTimeoutValidate(ctx context.Context, cm consent.Manager) func(t *testing.T) { return func(t *testing.T) { - var err error - for _, r := range j.flushLoginRequests { - _, err = cm.GetLoginRequest(ctx, r.ID) - if r.ID == j.flushLoginRequests[0].ID { - require.NoError(t, err) - } else { - require.Error(t, err) - } + _, err := cm.GetLoginRequest(ctx, r.ID) + // Login requests should never be persisted. + require.Error(t, err) } - } } -func (j *JanitorConsentTestHelper) ConsentTimeoutSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) { - return func(t *testing.T) { - var err error +func (j *JanitorConsentTestHelper) ConsentTimeoutSetup(ctx context.Context, reg interface { + consent.ManagerProvider + client.ManagerProvider + flow.CipherProvider +}) func(t *testing.T) { + cl := reg.ClientManager() + cm := reg.ConsentManager() + return func(t *testing.T) { // Let's reset and accept all login requests to test the consent requests - for _, r := range j.flushLoginRequests { - require.NoError(t, cl.CreateClient(ctx, r.Client)) - require.NoError(t, cm.CreateLoginRequest(ctx, r)) - _, err = cm.HandleLoginRequest(ctx, r.ID, &consent.HandledLoginRequest{ - ID: r.ID, - AuthenticatedAt: r.AuthenticatedAt, - RequestedAt: r.RequestedAt, + for i, loginRequest := range j.flushLoginRequests { + require.NoError(t, cl.CreateClient(ctx, loginRequest.Client)) + f, err := cm.CreateLoginRequest(ctx, loginRequest) + require.NoError(t, err) + f.RequestedAt = time.Now() // we won't handle expired flows + challenge := x.Must(f.ToLoginChallenge(ctx, reg)) + _, err = cm.HandleLoginRequest(ctx, f, challenge, &flow.HandledLoginRequest{ + ID: loginRequest.ID, + AuthenticatedAt: loginRequest.AuthenticatedAt, + RequestedAt: loginRequest.RequestedAt, WasHandled: true, }) require.NoError(t, err) - } - // Create consent requests - for _, r := range j.flushConsentRequests { - require.NoError(t, cm.CreateConsentRequest(ctx, r)) + // Create consent requests + consentRequest := j.flushConsentRequests[i] + err = cm.CreateConsentRequest(ctx, f, consentRequest) + require.NoError(t, err) + + if i == 0 { + // Create at least 1 consent request that has been accepted + _, err = cm.HandleConsentRequest(ctx, f, &flow.AcceptOAuth2ConsentRequest{ + ID: consentRequest.ID, + WasHandled: true, + HandledAt: sqlxx.NullTime(time.Now()), + RequestedAt: consentRequest.RequestedAt, + AuthenticatedAt: consentRequest.AuthenticatedAt, + }) + require.NoError(t, err) + } } - // Create at least 1 consent request that has been accepted - _, err = cm.HandleConsentRequest(ctx, &consent.AcceptOAuth2ConsentRequest{ - ID: j.flushConsentRequests[0].ID, - WasHandled: true, - HandledAt: sqlxx.NullTime(time.Now()), - RequestedAt: j.flushConsentRequests[0].RequestedAt, - AuthenticatedAt: j.flushConsentRequests[0].AuthenticatedAt, - }) - require.NoError(t, err) } } @@ -382,40 +426,58 @@ func (j *JanitorConsentTestHelper) ConsentTimeoutValidate(ctx context.Context, c for _, r := range j.flushConsentRequests { _, err = cm.GetConsentRequest(ctx, r.ID) - if r.ID == j.flushConsentRequests[0].ID { - require.NoError(t, err) - } else { - require.Error(t, err) - } + require.Error(t, err, "Unverified consent requests are never pesisted") } - } } func (j *JanitorConsentTestHelper) LoginConsentNotAfterSetup(ctx context.Context, cm consent.Manager, cl client.Manager) func(t *testing.T) { return func(t *testing.T) { + var ( + f *flow.Flow + err error + ) for _, r := range j.flushLoginRequests { require.NoError(t, cl.CreateClient(ctx, r.Client)) - require.NoError(t, cm.CreateLoginRequest(ctx, r)) + f, err = cm.CreateLoginRequest(ctx, r) + require.NoError(t, err) } for _, r := range j.flushConsentRequests { - require.NoError(t, cm.CreateConsentRequest(ctx, r)) + f.ID = r.LoginChallenge.String() + err = cm.CreateConsentRequest(ctx, f, r) + require.NoError(t, err) } } } -func (j *JanitorConsentTestHelper) LoginConsentNotAfterValidate(ctx context.Context, notAfter time.Time, consentRequestLifespan time.Time, cm consent.Manager) func(t *testing.T) { +func (j *JanitorConsentTestHelper) LoginConsentNotAfterValidate( + ctx context.Context, + notAfter time.Time, + consentRequestLifespan time.Time, + reg interface { + consent.ManagerProvider + flow.CipherProvider + }, +) func(t *testing.T) { return func(t *testing.T) { - var err error + var ( + err error + f *flow.Flow + ) for _, r := range j.flushLoginRequests { - t.Logf("login flush check:\nNotAfter: %s\nConsentRequest: %s\n%+v\n", - notAfter.String(), consentRequestLifespan.String(), r) - _, err = cm.GetLoginRequest(ctx, r.ID) + isExpired := r.RequestedAt.Before(consentRequestLifespan) + t.Logf("login flush check:\nNotAfter: %s\nLoginRequest: %s\nis expired: %v\n%+v\n", + notAfter.String(), consentRequestLifespan.String(), isExpired, r) + + f = x.Must(reg.ConsentManager().CreateLoginRequest(ctx, r)) + loginChallenge := x.Must(f.ToLoginChallenge(ctx, reg)) + + _, err = reg.ConsentManager().GetLoginRequest(ctx, loginChallenge) // if the lowest between notAfter and consent-request-lifespan is greater than requested_at // then the it should expect the value to be deleted. - if j.notAfterCheck(notAfter, consentRequestLifespan, r.RequestedAt) { + if isExpired { // value has been deleted here require.Error(t, err) } else { @@ -425,12 +487,19 @@ func (j *JanitorConsentTestHelper) LoginConsentNotAfterValidate(ctx context.Cont } for _, r := range j.flushConsentRequests { - t.Logf("consent flush check:\nNotAfter: %s\nConsentRequest: %s\n%+v\n", - notAfter.String(), consentRequestLifespan.String(), r) - _, err = cm.GetConsentRequest(ctx, r.ID) + isExpired := r.RequestedAt.Before(consentRequestLifespan) + t.Logf("consent flush check:\nNotAfter: %s\nConsentRequest: %s\nis expired: %v\n%+v\n", + notAfter.String(), consentRequestLifespan.String(), isExpired, r) + + f.ID = r.LoginChallenge.String() + require.NoError(t, reg.ConsentManager().CreateConsentRequest(ctx, f, r)) + f.RequestedAt = r.RequestedAt + consentChallenge := x.Must(f.ToConsentChallenge(ctx, reg)) + + _, err = reg.ConsentManager().GetConsentRequest(ctx, consentChallenge) // if the lowest between notAfter and consent-request-lifespan is greater than requested_at // then the it should expect the value to be deleted. - if j.notAfterCheck(notAfter, consentRequestLifespan, r.RequestedAt) { + if isExpired { // value has been deleted here require.Error(t, err) } else { @@ -470,8 +539,22 @@ func (j *JanitorConsentTestHelper) notAfterCheck(notAfter time.Time, lifespan ti return lesser.Unix() > requestedAt.Unix() } -func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager, clientManager client.Manager, fositeManager x.FositeStorer, network string, parallel bool) func(t *testing.T) { +func JanitorTests( + reg interface { + ConsentManager() consent.Manager + OAuth2Storage() x.FositeStorer + config.Provider + client.ManagerProvider + flow.CipherProvider + }, + network string, + parallel bool, +) func(t *testing.T) { return func(t *testing.T) { + consentManager := reg.ConsentManager() + clientManager := reg.ClientManager() + fositeManager := reg.OAuth2Storage() + if parallel { t.Parallel() } @@ -479,7 +562,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager, jt := NewConsentJanitorTestHelper(network + t.Name()) - conf.MustSet(context.Background(), config.KeyConsentRequestMaxAge, jt.GetConsentRequestLifespan(ctx)) + reg.Config().MustSet(context.Background(), config.KeyConsentRequestMaxAge, jt.GetConsentRequestLifespan(ctx)) t.Run("case=flush-consent-request-not-after", func(t *testing.T) { @@ -500,7 +583,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager, }) // validate test - t.Run("step=validate", jt.LoginConsentNotAfterValidate(ctx, notAfter, consentRequestLifespan, consentManager)) + t.Run("step=validate", jt.LoginConsentNotAfterValidate(ctx, notAfter, consentRequestLifespan, reg)) }) } @@ -511,7 +594,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager, t.Run("case=limit", func(t *testing.T) { // setup - t.Run("step=setup", jt.LimitSetup(ctx, consentManager, clientManager)) + t.Run("step=setup", jt.LimitSetup(ctx, reg)) // cleanup t.Run("step=cleanup", func(t *testing.T) { @@ -528,7 +611,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager, t.Run(fmt.Sprintf("case=%s", "loginRejection"), func(t *testing.T) { // setup - t.Run("step=setup", jt.LoginRejectionSetup(ctx, consentManager, clientManager)) + t.Run("step=setup", jt.LoginRejectionSetup(ctx, reg)) // cleanup t.Run("step=cleanup", func(t *testing.T) { @@ -543,7 +626,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager, t.Run(fmt.Sprintf("case=%s", "consentRejection"), func(t *testing.T) { // setup - t.Run("step=setup", jt.ConsentRejectionSetup(ctx, consentManager, clientManager)) + t.Run("step=setup", jt.ConsentRejectionSetup(ctx, reg)) // cleanup t.Run("step=cleanup", func(t *testing.T) { @@ -562,7 +645,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager, t.Run(fmt.Sprintf("case=%s", "login-timeout"), func(t *testing.T) { // setup - t.Run("step=setup", jt.LoginTimeoutSetup(ctx, consentManager, clientManager)) + t.Run("step=setup", jt.LoginTimeoutSetup(ctx, reg)) // cleanup t.Run("step=cleanup", func(t *testing.T) { @@ -579,7 +662,7 @@ func JanitorTests(conf *config.DefaultProvider, consentManager consent.Manager, t.Run(fmt.Sprintf("case=%s", "consent-timeout"), func(t *testing.T) { // setup - t.Run("step=setup", jt.ConsentTimeoutSetup(ctx, consentManager, clientManager)) + t.Run("step=setup", jt.ConsentTimeoutSetup(ctx, reg)) // cleanup t.Run("step=cleanup", func(t *testing.T) { @@ -627,7 +710,7 @@ func getAccessRequests(uniqueName string, lifespan time.Duration) []*fosite.Requ } func getRefreshRequests(uniqueName string, lifespan time.Duration) []*fosite.AccessRequest { - var tokenSignature = "4c7c7e8b3a77ad0c3ec846a21653c48b45dbfa31" + var tokenSignature = "4c7c7e8b3a77ad0c3ec846a21653c48b45dbfa31" //nolint:gosec return []*fosite.AccessRequest{ { GrantTypes: []string{ @@ -680,8 +763,8 @@ func getRefreshRequests(uniqueName string, lifespan time.Duration) []*fosite.Acc } } -func genLoginRequests(uniqueName string, lifespan time.Duration) []*consent.LoginRequest { - return []*consent.LoginRequest{ +func genLoginRequests(uniqueName string, lifespan time.Duration) []*flow.LoginRequest { + return []*flow.LoginRequest{ { ID: fmt.Sprintf("%s_flush-login-1", uniqueName), RequestedScope: []string{"foo", "bar"}, @@ -704,8 +787,8 @@ func genLoginRequests(uniqueName string, lifespan time.Duration) []*consent.Logi RedirectURIs: []string{"http://redirect"}, }, RequestURL: "http://redirect", - RequestedAt: time.Now().Round(time.Second).Add(-(lifespan + time.Minute)), - AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Second).Add(-(lifespan + time.Minute))), + RequestedAt: time.Now().Round(time.Second).Add(-(lifespan + 10*time.Minute)), + AuthenticatedAt: sqlxx.NullTime(time.Now().Round(time.Second).Add(-(lifespan + 10*time.Minute))), Verifier: fmt.Sprintf("%s_flush-login-2", uniqueName), }, { @@ -724,8 +807,8 @@ func genLoginRequests(uniqueName string, lifespan time.Duration) []*consent.Logi } } -func genConsentRequests(uniqueName string, lifespan time.Duration) []*consent.OAuth2ConsentRequest { - return []*consent.OAuth2ConsentRequest{ +func genConsentRequests(uniqueName string, lifespan time.Duration) []*flow.OAuth2ConsentRequest { + return []*flow.OAuth2ConsentRequest{ { ID: fmt.Sprintf("%s_flush-consent-1", uniqueName), RequestedScope: []string{"foo", "bar"}, diff --git a/internal/testhelpers/oauth2.go b/internal/testhelpers/oauth2.go index d637ee921ad..41f0ddaec8e 100644 --- a/internal/testhelpers/oauth2.go +++ b/internal/testhelpers/oauth2.go @@ -7,14 +7,17 @@ import ( "bytes" "context" "encoding/json" + "errors" "net/http" "net/http/cookiejar" + "net/http/httptest" "net/url" "strings" "testing" "time" "github.com/stretchr/testify/assert" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "github.com/ory/fosite/token/jwt" @@ -26,8 +29,6 @@ import ( "github.com/ory/x/httpx" "github.com/ory/x/ioutilx" - "net/http/httptest" - "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/driver" "github.com/ory/hydra/v2/driver/config" @@ -55,7 +56,7 @@ func NewIDTokenWithClaims(t *testing.T, reg driver.Registry, claims jwt.MapClaim return token } -func NewOAuth2Server(ctx context.Context, t *testing.T, reg driver.Registry) (publicTS, adminTS *httptest.Server) { +func NewOAuth2Server(ctx context.Context, t testing.TB, reg driver.Registry) (publicTS, adminTS *httptest.Server) { // Lifespan is two seconds to avoid time synchronization issues with SQL. reg.Config().MustSet(ctx, config.KeySubjectIdentifierAlgorithmSalt, "76d5d2bf-747f-4592-9fbd-d2b895a54b3a") reg.Config().MustSet(ctx, config.KeyAccessTokenLifespan, time.Second*2) @@ -66,19 +67,22 @@ func NewOAuth2Server(ctx context.Context, t *testing.T, reg driver.Registry) (pu public, admin := x.NewRouterPublic(), x.NewRouterAdmin(reg.Config().AdminURL) - publicTS = httptest.NewServer(public) + internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) + + reg.RegisterRoutes(ctx, admin, public) + + publicTS = httptest.NewServer(otelhttp.NewHandler(public, "public", otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string { + return r.URL.Path + }))) t.Cleanup(publicTS.Close) - adminTS = httptest.NewServer(admin) + adminTS = httptest.NewServer(otelhttp.NewHandler(admin, "admin", otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string { + return r.URL.Path + }))) t.Cleanup(adminTS.Close) reg.Config().MustSet(ctx, config.KeyIssuerURL, publicTS.URL) - // SendDebugMessagesToClients: true, - - internal.MustEnsureRegistryKeys(reg, x.OpenIDConnectKeyName) - internal.MustEnsureRegistryKeys(reg, x.OAuth2JWTKeyName) - - reg.RegisterRoutes(ctx, admin, public) return publicTS, adminTS } @@ -93,7 +97,7 @@ func DecodeIDToken(t *testing.T, token *oauth2.Token) gjson.Result { return gjson.ParseBytes(body) } -func IntrospectToken(t *testing.T, conf *oauth2.Config, token string, adminTS *httptest.Server) gjson.Result { +func IntrospectToken(t testing.TB, conf *oauth2.Config, token string, adminTS *httptest.Server) gjson.Result { require.NotEmpty(t, token) req := httpx.MustNewRequest("POST", adminTS.URL+"/admin/oauth2/introspect", @@ -140,13 +144,13 @@ func HTTPServerNotImplementedHandler(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNotImplemented) } -func HTTPServerNoExpectedCallHandler(t *testing.T) http.HandlerFunc { +func HTTPServerNoExpectedCallHandler(t testing.TB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { t.Fatal("This should not have been called") } } -func NewLoginConsentUI(t *testing.T, c *config.DefaultProvider, login, consent http.HandlerFunc) { +func NewLoginConsentUI(t testing.TB, c *config.DefaultProvider, login, consent http.HandlerFunc) { if login == nil { login = HTTPServerNotImplementedHandler } @@ -165,7 +169,7 @@ func NewLoginConsentUI(t *testing.T, c *config.DefaultProvider, login, consent h c.MustSet(context.Background(), config.KeyConsentURL, ct.URL) } -func NewCallbackURL(t *testing.T, prefix string, h http.HandlerFunc) string { +func NewCallbackURL(t testing.TB, prefix string, h http.HandlerFunc) string { if h == nil { h = HTTPServerNotImplementedHandler } @@ -180,14 +184,35 @@ func NewCallbackURL(t *testing.T, prefix string, h http.HandlerFunc) string { return ts.URL + "/" + prefix } -func NewEmptyCookieJar(t *testing.T) *cookiejar.Jar { +func NewEmptyCookieJar(t testing.TB) *cookiejar.Jar { c, err := cookiejar.New(&cookiejar.Options{}) require.NoError(t, err) return c } -func NewEmptyJarClient(t *testing.T) *http.Client { +func NewEmptyJarClient(t testing.TB) *http.Client { return &http.Client{ - Jar: NewEmptyCookieJar(t), + Jar: NewEmptyCookieJar(t), + Transport: &loggingTransport{t}, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + //t.Logf("Redirect to %s", req.URL.String()) + + if len(via) >= 20 { + for k, v := range via { + t.Logf("Failed with redirect (%d): %s", k, v.URL.String()) + } + return errors.New("stopped after 20 redirects") + } + return nil + }, } } + +type loggingTransport struct{ t testing.TB } + +func (s *loggingTransport) RoundTrip(r *http.Request) (*http.Response, error) { + //s.t.Logf("%s %s", r.Method, r.URL.String()) + //s.t.Logf("%s %s\nWith Cookies: %v", r.Method, r.URL.String(), r.Cookies()) + + return otelhttp.DefaultClient.Transport.RoundTrip(r) +} diff --git a/jwk/aead.go b/jwk/aead.go deleted file mode 100644 index 081f5b87038..00000000000 --- a/jwk/aead.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package jwk - -import ( - "context" - "encoding/base64" - - "github.com/ory/x/errorsx" - - "github.com/ory/hydra/v2/driver/config" - - "github.com/gtank/cryptopasta" - "github.com/pkg/errors" -) - -type AEAD struct { - c *config.DefaultProvider -} - -func NewAEAD(c *config.DefaultProvider) *AEAD { - return &AEAD{c: c} -} - -func aeadKey(key []byte) *[32]byte { - var result [32]byte - copy(result[:], key[:32]) - return &result -} - -func (c *AEAD) Encrypt(ctx context.Context, plaintext []byte) (string, error) { - global, err := c.c.GetGlobalSecret(ctx) - if err != nil { - return "", err - } - - rotated, err := c.c.GetRotatedGlobalSecrets(ctx) - if err != nil { - return "", err - } - - keys := append([][]byte{global}, rotated...) - if len(keys) == 0 { - return "", errors.Errorf("at least one encryption key must be defined but none were") - } - - if len(keys[0]) < 32 { - return "", errors.Errorf("key must be exactly 32 long bytes, got %d bytes", len(keys[0])) - } - - ciphertext, err := cryptopasta.Encrypt(plaintext, aeadKey(keys[0])) - if err != nil { - return "", errorsx.WithStack(err) - } - - return base64.URLEncoding.EncodeToString(ciphertext), nil -} - -func (c *AEAD) Decrypt(ctx context.Context, ciphertext string) (p []byte, err error) { - global, err := c.c.GetGlobalSecret(ctx) - if err != nil { - return nil, err - } - - rotated, err := c.c.GetRotatedGlobalSecrets(ctx) - if err != nil { - return nil, err - } - - keys := append([][]byte{global}, rotated...) - if len(keys) == 0 { - return nil, errors.Errorf("at least one decryption key must be defined but none were") - } - - for _, key := range keys { - if p, err = c.decrypt(ciphertext, key); err == nil { - return p, nil - } - } - - return nil, err -} - -func (c *AEAD) decrypt(ciphertext string, key []byte) ([]byte, error) { - if len(key) != 32 { - return nil, errors.Errorf("key must be exactly 32 long bytes, got %d bytes", len(key)) - } - - raw, err := base64.URLEncoding.DecodeString(ciphertext) - if err != nil { - return nil, errorsx.WithStack(err) - } - - plaintext, err := cryptopasta.Decrypt(raw, aeadKey(key)) - if err != nil { - return nil, errorsx.WithStack(err) - } - - return plaintext, nil -} diff --git a/jwk/aead_test.go b/jwk/aead_test.go deleted file mode 100644 index 890918dde72..00000000000 --- a/jwk/aead_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package jwk_test - -import ( - "context" - "crypto/rand" - "fmt" - "io" - "testing" - - "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/internal" - . "github.com/ory/hydra/v2/jwk" - - "github.com/pborman/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func secret(t *testing.T) string { - bytes := make([]byte, 32) - _, err := io.ReadFull(rand.Reader, bytes) - require.NoError(t, err) - return fmt.Sprintf("%X", bytes) -} - -func TestAEAD(t *testing.T) { - ctx := context.Background() - c := internal.NewConfigurationWithDefaults() - t.Run("case=without-rotation", func(t *testing.T) { - c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) - a := NewAEAD(c) - - plain := []byte(uuid.New()) - ct, err := a.Encrypt(ctx, plain) - assert.NoError(t, err) - - res, err := a.Decrypt(ctx, ct) - assert.NoError(t, err) - assert.Equal(t, plain, res) - }) - - t.Run("case=wrong-secret", func(t *testing.T) { - c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) - a := NewAEAD(c) - - ct, err := a.Encrypt(ctx, []byte(uuid.New())) - require.NoError(t, err) - - c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) - _, err = a.Decrypt(ctx, ct) - require.Error(t, err) - }) - - t.Run("case=with-rotation", func(t *testing.T) { - old := secret(t) - c.MustSet(ctx, config.KeyGetSystemSecret, []string{old}) - a := NewAEAD(c) - - plain := []byte(uuid.New()) - ct, err := a.Encrypt(ctx, plain) - require.NoError(t, err) - - // Sets the old secret as a rotated secret and creates a new one. - c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t), old}) - res, err := a.Decrypt(ctx, ct) - require.NoError(t, err) - assert.Equal(t, plain, res) - - // THis should also work when we re-encrypt the same plain text. - ct2, err := a.Encrypt(ctx, plain) - require.NoError(t, err) - assert.NotEqual(t, ct2, ct) - - res, err = a.Decrypt(ctx, ct) - require.NoError(t, err) - assert.Equal(t, plain, res) - }) - - t.Run("case=with-rotation-wrong-secret", func(t *testing.T) { - c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t)}) - a := NewAEAD(c) - - plain := []byte(uuid.New()) - ct, err := a.Encrypt(ctx, plain) - require.NoError(t, err) - - // When the secrets do not match, an error should be thrown during decryption. - c.MustSet(ctx, config.KeyGetSystemSecret, []string{secret(t), secret(t)}) - _, err = a.Decrypt(ctx, ct) - require.Error(t, err) - }) -} diff --git a/jwk/cast_test.go b/jwk/cast_test.go index d55b81ba518..2e78283719d 100644 --- a/jwk/cast_test.go +++ b/jwk/cast_test.go @@ -14,6 +14,7 @@ import ( ) func TestMustRSAPrivate(t *testing.T) { + t.Parallel() keys, err := GenerateJWK(context.Background(), jose.RS256, "foo", "sig") require.NoError(t, err) diff --git a/jwk/generate_test.go b/jwk/generate_test.go index 01a47d4ec67..544a8de9d23 100644 --- a/jwk/generate_test.go +++ b/jwk/generate_test.go @@ -13,6 +13,7 @@ import ( ) func TestGenerateJWK(t *testing.T) { + t.Parallel() jwks, err := GenerateJWK(context.Background(), jose.RS256, "", "") require.NoError(t, err) assert.NotEmpty(t, jwks.Keys[0].KeyID) diff --git a/jwk/handler.go b/jwk/handler.go index d5e87ea29fc..1bbfb9ecb81 100644 --- a/jwk/handler.go +++ b/jwk/handler.go @@ -36,6 +36,8 @@ type Handler struct { // JSON Web Key Set // // swagger:model jsonWebKeySet +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type jsonWebKeySet struct { // List of JSON Web Keys // @@ -114,6 +116,8 @@ func (h *Handler) discoverJsonWebKeys(w http.ResponseWriter, r *http.Request) { // Get JSON Web Key Request // // swagger:parameters getJsonWebKey +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type getJsonWebKey struct { // JSON Web Key Set ID // @@ -162,6 +166,8 @@ func (h *Handler) getJsonWebKey(w http.ResponseWriter, r *http.Request, ps httpr // Get JSON Web Key Set Parameters // // swagger:parameters getJsonWebKeySet +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type getJsonWebKeySet struct { // JSON Web Key Set ID // @@ -205,6 +211,8 @@ func (h *Handler) getJsonWebKeySet(w http.ResponseWriter, r *http.Request, ps ht // Create JSON Web Key Set Request // // swagger:parameters createJsonWebKeySet +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type adminCreateJsonWebKeySet struct { // The JSON Web Key Set ID // @@ -283,6 +291,8 @@ func (h *Handler) createJsonWebKeySet(w http.ResponseWriter, r *http.Request, ps // Set JSON Web Key Set Request // // swagger:parameters setJsonWebKeySet +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type setJsonWebKeySet struct { // The JSON Web Key Set ID // @@ -333,6 +343,8 @@ func (h *Handler) setJsonWebKeySet(w http.ResponseWriter, r *http.Request, ps ht // Set JSON Web Key Request // // swagger:parameters setJsonWebKey +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type setJsonWebKey struct { // The JSON Web Key Set ID // @@ -389,6 +401,8 @@ func (h *Handler) adminUpdateJsonWebKey(w http.ResponseWriter, r *http.Request, // Delete JSON Web Key Set Parameters // // swagger:parameters deleteJsonWebKeySet +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type deleteJsonWebKeySet struct { // The JSON Web Key Set // in: path @@ -429,6 +443,8 @@ func (h *Handler) adminDeleteJsonWebKeySet(w http.ResponseWriter, r *http.Reques // Delete JSON Web Key Parameters // // swagger:parameters deleteJsonWebKey +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type deleteJsonWebKey struct { // The JSON Web Key Set // in: path diff --git a/jwk/handler_test.go b/jwk/handler_test.go index 09613ca42b9..c9040a37a4c 100644 --- a/jwk/handler_test.go +++ b/jwk/handler_test.go @@ -25,6 +25,8 @@ import ( ) func TestHandlerWellKnown(t *testing.T) { + t.Parallel() + conf := internal.NewConfigurationWithDefaults() reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) conf.MustSet(context.Background(), config.KeyWellKnownKeys, []string{x.OpenIDConnectKeyName, x.OpenIDConnectKeyName}) @@ -37,6 +39,7 @@ func TestHandlerWellKnown(t *testing.T) { JWKPath := "/.well-known/jwks.json" t.Run("Test_Handler_WellKnown/Run_public_key_With_public_prefix", func(t *testing.T) { + t.Parallel() if conf.HSMEnabled() { t.Skip("Skipping test. Not applicable when Hardware Security Module is enabled. Public/private keys on HSM are generated with equal key id's and are not using prefixes") } @@ -62,6 +65,7 @@ func TestHandlerWellKnown(t *testing.T) { }) t.Run("Test_Handler_WellKnown/Run_public_key_Without_public_prefix", func(t *testing.T) { + t.Parallel() var IDKS *jose.JSONWebKeySet if conf.HSMEnabled() { diff --git a/jwk/helper_test.go b/jwk/helper_test.go index b724349d515..d0ce928c3ec 100644 --- a/jwk/helper_test.go +++ b/jwk/helper_test.go @@ -6,7 +6,7 @@ package jwk_test import ( "context" "crypto" - "crypto/dsa" + "crypto/dsa" //lint:ignore SA1019 used for testing invalid key types "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" @@ -46,7 +46,10 @@ func (f *fakeSigner) Public() crypto.PublicKey { } func TestHandlerFindPublicKey(t *testing.T) { + t.Parallel() + t.Run("Test_Helper/Run_FindPublicKey_With_RSA", func(t *testing.T) { + t.Parallel() RSIDKS, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig") require.NoError(t, err) keys, err := jwk.FindPublicKey(RSIDKS) @@ -56,6 +59,7 @@ func TestHandlerFindPublicKey(t *testing.T) { }) t.Run("Test_Helper/Run_FindPublicKey_With_Opaque", func(t *testing.T) { + t.Parallel() key, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig") RSIDKS := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{ Algorithm: "RS256", @@ -82,6 +86,7 @@ func TestHandlerFindPublicKey(t *testing.T) { }) t.Run("Test_Helper/Run_FindPublicKey_With_ECDSA", func(t *testing.T) { + t.Parallel() ECDSAIDKS, err := jwk.GenerateJWK(context.Background(), jose.ES256, "test-id-2", "sig") require.NoError(t, err) keys, err := jwk.FindPublicKey(ECDSAIDKS) @@ -91,6 +96,7 @@ func TestHandlerFindPublicKey(t *testing.T) { }) t.Run("Test_Helper/Run_FindPublicKey_With_EdDSA", func(t *testing.T) { + t.Parallel() EdDSAIDKS, err := jwk.GenerateJWK(context.Background(), jose.EdDSA, "test-id-3", "sig") require.NoError(t, err) keys, err := jwk.FindPublicKey(EdDSAIDKS) @@ -100,6 +106,7 @@ func TestHandlerFindPublicKey(t *testing.T) { }) t.Run("Test_Helper/Run_FindPublicKey_With_KeyNotFound", func(t *testing.T) { + t.Parallel() keySet := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{}} _, err := jwk.FindPublicKey(keySet) require.Error(t, err) @@ -108,6 +115,7 @@ func TestHandlerFindPublicKey(t *testing.T) { } func TestHandlerFindPrivateKey(t *testing.T) { + t.Parallel() t.Run("Test_Helper/Run_FindPrivateKey_With_RSA", func(t *testing.T) { RSIDKS, _ := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig") keys, err := jwk.FindPrivateKey(RSIDKS) @@ -143,6 +151,7 @@ func TestHandlerFindPrivateKey(t *testing.T) { } func TestPEMBlockForKey(t *testing.T) { + t.Parallel() t.Run("Test_Helper/Run_PEMBlockForKey_With_RSA", func(t *testing.T) { RSIDKS, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig") require.NoError(t, err) @@ -185,6 +194,7 @@ func TestPEMBlockForKey(t *testing.T) { } func TestExcludeOpaquePrivateKeys(t *testing.T) { + t.Parallel() opaqueKeys, err := jwk.GenerateJWK(context.Background(), jose.RS256, "test-id-1", "sig") assert.NoError(t, err) require.Len(t, opaqueKeys.Keys, 1) @@ -199,6 +209,7 @@ func TestExcludeOpaquePrivateKeys(t *testing.T) { } func TestGetOrGenerateKeys(t *testing.T) { + t.Parallel() reg := internal.NewMockedRegistry(t, &contextx.Default{}) setId := uuid.NewUUID().String() diff --git a/jwk/manager_strategy_test.go b/jwk/manager_strategy_test.go index e138f30e072..6fb8db03bbb 100644 --- a/jwk/manager_strategy_test.go +++ b/jwk/manager_strategy_test.go @@ -17,6 +17,7 @@ import ( ) func TestKeyManagerStrategy(t *testing.T) { + t.Parallel() ctrl := gomock.NewController(t) softwareKeyManager := NewMockManager(ctrl) hardwareKeyManager := NewMockManager(ctrl) diff --git a/jwk/registry.go b/jwk/registry.go index 1d7b4355f8c..b5c3ea8d811 100644 --- a/jwk/registry.go +++ b/jwk/registry.go @@ -4,6 +4,7 @@ package jwk import ( + "github.com/ory/hydra/v2/aead" "github.com/ory/hydra/v2/driver/config" "github.com/ory/hydra/v2/x" ) @@ -18,5 +19,5 @@ type Registry interface { config.Provider KeyManager() Manager SoftwareKeyManager() Manager - KeyCipher() *AEAD + KeyCipher() *aead.AESGCM } diff --git a/jwk/registry_mock_test.go b/jwk/registry_mock_test.go index d6295f11b31..c305fd18167 100644 --- a/jwk/registry_mock_test.go +++ b/jwk/registry_mock_test.go @@ -13,6 +13,7 @@ import ( gomock "github.com/golang/mock/gomock" herodot "github.com/ory/herodot" + "github.com/ory/hydra/v2/aead" config "github.com/ory/hydra/v2/driver/config" jwk "github.com/ory/hydra/v2/jwk" logrusx "github.com/ory/x/logrusx" @@ -70,10 +71,10 @@ func (mr *MockInternalRegistryMockRecorder) Config() *gomock.Call { } // KeyCipher mocks base method. -func (m *MockInternalRegistry) KeyCipher() *jwk.AEAD { +func (m *MockInternalRegistry) KeyCipher() *aead.AESGCM { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "KeyCipher") - ret0, _ := ret[0].(*jwk.AEAD) + ret0, _ := ret[0].(*aead.AESGCM) return ret0 } @@ -177,10 +178,10 @@ func (mr *MockRegistryMockRecorder) Config() *gomock.Call { } // KeyCipher mocks base method. -func (m *MockRegistry) KeyCipher() *jwk.AEAD { +func (m *MockRegistry) KeyCipher() *aead.AESGCM { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "KeyCipher") - ret0, _ := ret[0].(*jwk.AEAD) + ret0, _ := ret[0].(*aead.AESGCM) return ret0 } diff --git a/jwk/sdk_test.go b/jwk/sdk_test.go index 571e8ce5b05..46d1cc81448 100644 --- a/jwk/sdk_test.go +++ b/jwk/sdk_test.go @@ -24,6 +24,7 @@ import ( ) func TestJWKSDK(t *testing.T) { + t.Parallel() ctx := context.Background() conf := internal.NewConfigurationWithDefaults() reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) @@ -41,6 +42,7 @@ func TestJWKSDK(t *testing.T) { expectedKid := "key-bar" t.Run("JSON Web Key", func(t *testing.T) { + t.Parallel() t.Run("CreateJwkSetKey", func(t *testing.T) { // Create a key called set-foo resultKeys, _, err := sdk.JwkApi.CreateJsonWebKeySet(context.Background(), "set-foo").CreateJsonWebKeySet(hydra.CreateJsonWebKeySet{ @@ -93,6 +95,7 @@ func TestJWKSDK(t *testing.T) { }) t.Run("JWK Set", func(t *testing.T) { + t.Parallel() t.Run("CreateJwkSetKey", func(t *testing.T) { resultKeys, _, err := sdk.JwkApi.CreateJsonWebKeySet(ctx, "set-foo2").CreateJsonWebKeySet(hydra.CreateJsonWebKeySet{ Alg: "RS256", diff --git a/oauth2/flowctx/cookies.go b/oauth2/flowctx/cookies.go new file mode 100644 index 00000000000..00ae91aeef0 --- /dev/null +++ b/oauth2/flowctx/cookies.go @@ -0,0 +1,38 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package flowctx + +import "github.com/ory/hydra/v2/client" + +type ( + CookieSuffixer interface { + CookieSuffix() string + } + + StaticSuffix string + clientID string +) + +func (s StaticSuffix) CookieSuffix() string { return string(s) } +func (s clientID) GetID() string { return string(s) } + +const ( + flowCookie = "ory_hydra_flow" + loginSessionCookie = "ory_hydra_loginsession" +) + +func FlowCookie(suffix CookieSuffixer) string { + return flowCookie + "_" + suffix.CookieSuffix() +} +func LoginSessionCookie(suffix CookieSuffixer) string { + return loginSessionCookie + "_" + suffix.CookieSuffix() +} + +func SuffixForClient(c client.IDer) StaticSuffix { + return StaticSuffix(client.CookieSuffix(c)) +} + +func SuffixFromStatic(id string) StaticSuffix { + return SuffixForClient(clientID(id)) +} diff --git a/oauth2/flowctx/encoding.go b/oauth2/flowctx/encoding.go new file mode 100644 index 00000000000..5b01b5ec1cd --- /dev/null +++ b/oauth2/flowctx/encoding.go @@ -0,0 +1,150 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package flowctx + +import ( + "bytes" + "compress/gzip" + "context" + "encoding/json" + "net/http" + + "github.com/pkg/errors" + + "github.com/ory/fosite" + "github.com/ory/hydra/v2/aead" + "github.com/ory/hydra/v2/driver/config" +) + +type ( + data struct { + Purpose purpose `json:"p,omitempty"` + } + purpose int + CodecOption func(ad *data) +) + +const ( + loginChallenge purpose = iota + loginVerifier + consentChallenge + consentVerifier +) + +func withPurpose(purpose purpose) CodecOption { return func(ad *data) { ad.Purpose = purpose } } + +var ( + AsLoginChallenge = withPurpose(loginChallenge) + AsLoginVerifier = withPurpose(loginVerifier) + AsConsentChallenge = withPurpose(consentChallenge) + AsConsentVerifier = withPurpose(consentVerifier) +) + +func additionalDataFromOpts(opts ...CodecOption) []byte { + if len(opts) == 0 { + return nil + } + ad := &data{} + for _, o := range opts { + o(ad) + } + b, err := json.Marshal(ad) + if err != nil { + // Panic is OK here because the struct and the parameters are all known. + panic("failed to marshal additional data: " + errors.WithStack(err).Error()) + } + + return b +} + +// Decode decodes the given string to a value. +func Decode[T any](ctx context.Context, cipher aead.Cipher, encoded string, opts ...CodecOption) (*T, error) { + plaintext, err := cipher.Decrypt(ctx, encoded, additionalDataFromOpts(opts...)) + if err != nil { + return nil, err + } + + rawBytes, err := gzip.NewReader(bytes.NewReader(plaintext)) + if err != nil { + return nil, err + } + defer func() { _ = rawBytes.Close() }() + + var val T + if err = json.NewDecoder(rawBytes).Decode(&val); err != nil { + return nil, err + } + + return &val, nil +} + +// Encode encodes the given value to a string. +func Encode(ctx context.Context, cipher aead.Cipher, val any, opts ...CodecOption) (s string, err error) { + // Steps: + // 1. Encode to JSON + // 2. GZIP + // 3. Encrypt with AEAD (AES-GCM) + Base64 URL-encode + var b bytes.Buffer + + gz := gzip.NewWriter(&b) + + if err = json.NewEncoder(gz).Encode(val); err != nil { + return "", err + } + if err = gz.Close(); err != nil { + return "", err + } + + return cipher.Encrypt(ctx, b.Bytes(), additionalDataFromOpts(opts...)) +} + +// SetCookie encrypts the given value and sets it in a cookie. +func SetCookie(ctx context.Context, w http.ResponseWriter, reg interface { + FlowCipher() *aead.XChaCha20Poly1305 + config.Provider +}, cookieName string, value any, opts ...CodecOption) error { + cipher := reg.FlowCipher() + cookie, err := Encode(ctx, cipher, value, opts...) + if err != nil { + return err + } + + http.SetCookie(w, &http.Cookie{ + Name: cookieName, + Value: cookie, + HttpOnly: true, + Domain: reg.Config().CookieDomain(ctx), + Secure: reg.Config().CookieSecure(ctx), + SameSite: reg.Config().CookieSameSiteMode(ctx), + }) + + return nil +} + +// DeleteCookie deletes the flow cookie. +func DeleteCookie(ctx context.Context, w http.ResponseWriter, reg interface { + config.Provider +}, cookieName string) error { + http.SetCookie(w, &http.Cookie{ + Name: cookieName, + Value: "", + MaxAge: -1, + HttpOnly: true, + Domain: reg.Config().CookieDomain(ctx), + Secure: reg.Config().CookieSecure(ctx), + SameSite: reg.Config().CookieSameSiteMode(ctx), + }) + + return nil +} + +// FromCookie looks up the value stored in the cookie and decodes it. +func FromCookie[T any](ctx context.Context, r *http.Request, cipher aead.Cipher, cookieName string, opts ...CodecOption) (*T, error) { + cookie, err := r.Cookie(cookieName) + if err != nil { + return nil, errors.WithStack(fosite.ErrInvalidClient.WithHint("No cookie found for this request. Please initiate a new flow and retry.")) + } + + return Decode[T](ctx, cipher, cookie.Value, opts...) +} diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers.go index a885b73bb16..0a2c670fe03 100644 --- a/oauth2/fosite_store_helpers.go +++ b/oauth2/fosite_store_helpers.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/flow" "github.com/ory/hydra/v2/jwk" "github.com/gobuffalo/pop/v6" @@ -36,7 +37,6 @@ import ( "github.com/ory/x/sqlcon" "github.com/ory/hydra/v2/client" - "github.com/ory/hydra/v2/consent" ) func signatureFromJTI(jti string) string { @@ -121,9 +121,9 @@ var flushRequests = []*fosite.Request{ func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry, createClient bool) { cl := &client.Client{LegacyClientID: "foobar"} - cr := &consent.OAuth2ConsentRequest{ + cr := &flow.OAuth2ConsentRequest{ Client: cl, - OpenIDConnectContext: new(consent.OAuth2ConsentRequestOpenIDConnectContext), + OpenIDConnectContext: new(flow.OAuth2ConsentRequestOpenIDConnectContext), LoginChallenge: sqlxx.NullString(id), ID: id, Verifier: id, @@ -132,18 +132,36 @@ func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry, createCl RequestedAt: time.Now(), } + ctx := context.Background() if createClient { - require.NoError(t, x.ClientManager().CreateClient(context.Background(), cl)) + require.NoError(t, x.ClientManager().CreateClient(ctx, cl)) } - require.NoError(t, x.ConsentManager().CreateLoginRequest(context.Background(), &consent.LoginRequest{Client: cl, OpenIDConnectContext: new(consent.OAuth2ConsentRequestOpenIDConnectContext), ID: id, Verifier: id, AuthenticatedAt: sqlxx.NullTime(time.Now()), RequestedAt: time.Now()})) - require.NoError(t, x.ConsentManager().CreateConsentRequest(context.Background(), cr)) - _, err := x.ConsentManager().HandleConsentRequest(context.Background(), &consent.AcceptOAuth2ConsentRequest{ - ConsentRequest: cr, Session: new(consent.AcceptOAuth2ConsentRequestSession), AuthenticatedAt: sqlxx.NullTime(time.Now()), - ID: id, - RequestedAt: time.Now(), - HandledAt: sqlxx.NullTime(time.Now()), + f, err := x.ConsentManager().CreateLoginRequest( + ctx, &flow.LoginRequest{ + Client: cl, + OpenIDConnectContext: new(flow.OAuth2ConsentRequestOpenIDConnectContext), + ID: id, + Verifier: id, + AuthenticatedAt: sqlxx.NullTime(time.Now()), + RequestedAt: time.Now(), + }) + require.NoError(t, err) + err = x.ConsentManager().CreateConsentRequest(ctx, f, cr) + require.NoError(t, err) + + encodedFlow, err := f.ToConsentVerifier(ctx, x) + require.NoError(t, err) + + _, err = x.ConsentManager().HandleConsentRequest(ctx, f, &flow.AcceptOAuth2ConsentRequest{ + ConsentRequest: cr, + Session: new(flow.AcceptOAuth2ConsentRequestSession), + AuthenticatedAt: sqlxx.NullTime(time.Now()), + ID: encodedFlow, + RequestedAt: time.Now(), + HandledAt: sqlxx.NullTime(time.Now()), }) + require.NoError(t, err) } @@ -270,10 +288,18 @@ func testHelperRevokeRefreshToken(x InternalRegistry) func(t *testing.T) { mockRequestForeignKey(t, reqIdOne, x, false) mockRequestForeignKey(t, reqIdTwo, x, false) - err = m.CreateRefreshTokenSession(ctx, "1111", &fosite.Request{ID: reqIdOne, Client: &client.Client{LegacyClientID: "foobar"}, RequestedAt: time.Now().UTC().Round(time.Second), Session: &Session{}}) + err = m.CreateRefreshTokenSession(ctx, "1111", &fosite.Request{ + ID: reqIdOne, + Client: &client.Client{LegacyClientID: "foobar"}, + RequestedAt: time.Now().UTC().Round(time.Second), + Session: &Session{}}) require.NoError(t, err) - err = m.CreateRefreshTokenSession(ctx, "1122", &fosite.Request{ID: reqIdTwo, Client: &client.Client{LegacyClientID: "foobar"}, RequestedAt: time.Now().UTC().Round(time.Second), Session: &Session{}}) + err = m.CreateRefreshTokenSession(ctx, "1122", &fosite.Request{ + ID: reqIdTwo, + Client: &client.Client{LegacyClientID: "foobar"}, + RequestedAt: time.Now().UTC().Round(time.Second), + Session: &Session{}}) require.NoError(t, err) _, err = m.GetRefreshTokenSession(ctx, "1111", &Session{}) diff --git a/oauth2/fosite_store_test.go b/oauth2/fosite_store_test.go index 70efbc3f4b2..767f33444d5 100644 --- a/oauth2/fosite_store_test.go +++ b/oauth2/fosite_store_test.go @@ -23,8 +23,8 @@ import ( func TestMain(m *testing.M) { flag.Parse() - runner := dockertest.Register() - runner.Exit(m.Run()) + defer dockertest.KillAllTestDatabases() + m.Run() } var registries = make(map[string]driver.Registry) diff --git a/oauth2/handler.go b/oauth2/handler.go index c9eaaaa0c62..0cd6d53cb34 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/ory/hydra/v2/x/events" "github.com/ory/x/httprouterx" "github.com/pborman/uuid" @@ -58,7 +59,10 @@ type Handler struct { } func NewHandler(r InternalRegistry, c *config.DefaultProvider) *Handler { - return &Handler{r: r, c: c} + return &Handler{ + r: r, + c: c, + } } func (h *Handler) SetRoutes(admin *httprouterx.RouterAdmin, public *httprouterx.RouterPublic, corsMiddleware func(http.Handler) http.Handler) { @@ -460,6 +464,8 @@ func (h *Handler) discoverOidcConfiguration(w http.ResponseWriter, r *http.Reque // OpenID Connect Userinfo // // swagger:model oidcUserInfo +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type oidcUserInfo struct { // Subject - Identifier for the End-User at the IssuerURL. Subject string `json:"sub"` @@ -623,6 +629,8 @@ func (h *Handler) getOidcUserInfo(w http.ResponseWriter, r *http.Request) { // Revoke OAuth 2.0 Access or Refresh Token Request // // swagger:parameters revokeOAuth2Token +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type revokeOAuth2Token struct { // in: formData // required: true @@ -656,6 +664,7 @@ type revokeOAuth2Token struct { // default: errorOAuth2 func (h *Handler) revokeOAuth2Token(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + events.Trace(r.Context(), events.AccessTokenRevoked) err := h.r.OAuth2Provider().NewRevocationRequest(ctx, r) if err != nil { @@ -668,6 +677,8 @@ func (h *Handler) revokeOAuth2Token(w http.ResponseWriter, r *http.Request) { // Introspect OAuth 2.0 Access or Refresh Token Request // // swagger:parameters introspectOAuth2Token +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type introspectOAuth2Token struct { // The string value of the token. For access tokens, this // is the "access_token" value returned from the token endpoint @@ -791,11 +802,19 @@ func (h *Handler) introspectOAuth2Token(w http.ResponseWriter, r *http.Request, }); err != nil { x.LogError(r, errorsx.WithStack(err), h.r.Logger()) } + + events.Trace(ctx, + events.AccessTokenInspected, + events.WithSubject(session.GetSubject()), + events.WithClientID(resp.GetAccessRequester().GetClient().GetID()), + ) } // OAuth 2.0 Token Exchange Parameters // // swagger:parameters oauth2TokenExchange +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type performOAuth2TokenFlow struct { // in: formData // required: true @@ -817,6 +836,8 @@ type performOAuth2TokenFlow struct { // OAuth2 Token Exchange Result // // swagger:model oAuth2TokenExchange +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type oAuth2TokenExchange struct { // The lifetime in seconds of the access token. For // example, the value "3600" denotes that the access token will @@ -865,23 +886,26 @@ type oAuth2TokenExchange struct { // 200: oAuth2TokenExchange // default: errorOAuth2 func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) { - var session = NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims(r.Context())) - var ctx = r.Context() + session := NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims(r.Context())) + ctx := r.Context() accessRequest, err := h.r.OAuth2Provider().NewAccessRequest(ctx, r, session) if err != nil { h.logOrAudit(err, r) h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err) + events.Trace(ctx, events.TokenExchangeError) return } - if accessRequest.GetGrantTypes().ExactOne("client_credentials") || accessRequest.GetGrantTypes().ExactOne("urn:ietf:params:oauth:grant-type:jwt-bearer") { + if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeClientCredentials)) || + accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeJWTBearer)) { var accessTokenKeyID string if h.c.AccessTokenStrategy(ctx, client.AccessTokenStrategySource(accessRequest.GetClient())) == "jwt" { accessTokenKeyID, err = h.r.AccessTokenJWTStrategy().GetPublicKeyID(ctx) if err != nil { x.LogError(r, err, h.r.Logger()) h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err) + events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest)) return } } @@ -895,7 +919,7 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) { session.DefaultSession.Claims.Issuer = h.c.IssuerURL(r.Context()).String() session.DefaultSession.Claims.IssuedAt = time.Now().UTC() - var scopes = accessRequest.GetRequestedScopes() + scopes := accessRequest.GetRequestedScopes() // Added for compatibility with MITREid if h.c.GrantAllClientCredentialsScopesPerDefault(r.Context()) && len(scopes) == 0 { @@ -921,6 +945,7 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) { if err := hook(ctx, accessRequest); err != nil { h.logOrAudit(err, r) h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err) + events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest)) return } } @@ -929,6 +954,7 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) { if err != nil { h.logOrAudit(err, r) h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err) + events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest)) return } @@ -962,7 +988,7 @@ func (h *Handler) oAuth2Authorize(w http.ResponseWriter, r *http.Request, _ http return } - session, err := h.r.ConsentStrategy().HandleOAuth2AuthorizationRequest(ctx, w, r, authorizeRequest) + session, flow, err := h.r.ConsentStrategy().HandleOAuth2AuthorizationRequest(ctx, w, r, authorizeRequest) if errors.Is(err, consent.ErrAbortOAuth2Request) { x.LogAudit(r, nil, h.r.AuditLogger()) // do nothing @@ -1049,6 +1075,7 @@ func (h *Handler) oAuth2Authorize(w http.ResponseWriter, r *http.Request, _ http ConsentChallenge: session.ID, ExcludeNotBeforeClaim: h.c.ExcludeNotBeforeClaim(ctx), AllowedTopLevelClaims: h.c.AllowedTopLevelClaims(ctx), + Flow: flow, }) if err != nil { x.LogError(r, err, h.r.Logger()) @@ -1062,6 +1089,8 @@ func (h *Handler) oAuth2Authorize(w http.ResponseWriter, r *http.Request, _ http // Delete OAuth 2.0 Access Token Parameters // // swagger:parameters deleteOAuth2Token +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type deleteOAuth2Token struct { // OAuth 2.0 Client ID // diff --git a/oauth2/handler_test.go b/oauth2/handler_test.go index 74a8751a5eb..cc10d429127 100644 --- a/oauth2/handler_test.go +++ b/oauth2/handler_test.go @@ -93,7 +93,7 @@ func TestUserinfo(t *testing.T) { conf.MustSet(ctx, config.KeyAuthCodeLifespan, lifespan) conf.MustSet(ctx, config.KeyIssuerURL, "http://hydra.localhost") reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(reg, x.OpenIDConnectKeyName) + internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) ctrl := gomock.NewController(t) op := NewMockOAuth2Provider(ctrl) @@ -147,8 +147,8 @@ func TestUserinfo(t *testing.T) { setup: func(t *testing.T) { op.EXPECT(). IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()). - DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, session fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) { - session = &oauth2.Session{ + DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, _ fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) { + session := &oauth2.Session{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "alice", @@ -180,8 +180,8 @@ func TestUserinfo(t *testing.T) { setup: func(t *testing.T) { op.EXPECT(). IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()). - DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, session fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) { - session = &oauth2.Session{ + DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, _ fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) { + session := &oauth2.Session{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "another-alice", @@ -215,8 +215,8 @@ func TestUserinfo(t *testing.T) { setup: func(t *testing.T) { op.EXPECT(). IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()). - DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, session fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) { - session = &oauth2.Session{ + DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, _ fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) { + session := &oauth2.Session{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "alice", @@ -250,8 +250,8 @@ func TestUserinfo(t *testing.T) { setup: func(t *testing.T) { op.EXPECT(). IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()). - DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, session fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) { - session = &oauth2.Session{ + DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, _ fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) { + session := &oauth2.Session{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "alice", @@ -278,8 +278,8 @@ func TestUserinfo(t *testing.T) { setup: func(t *testing.T) { op.EXPECT(). IntrospectToken(gomock.Any(), gomock.Eq("access-token"), gomock.Eq(fosite.AccessToken), gomock.Any()). - DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, session fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) { - session = &oauth2.Session{ + DoAndReturn(func(_ context.Context, _ string, _ fosite.TokenType, _ fosite.Session, _ ...string) (fosite.TokenType, fosite.AccessRequester, error) { + session := &oauth2.Session{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: "alice", diff --git a/oauth2/introspector_test.go b/oauth2/introspector_test.go index 1905476b9f6..6511a77e33e 100644 --- a/oauth2/introspector_test.go +++ b/oauth2/introspector_test.go @@ -35,7 +35,7 @@ func TestIntrospectorSDK(t *testing.T) { conf.MustSet(ctx, config.KeyIssuerURL, "https://foobariss") reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(reg, x.OpenIDConnectKeyName) + internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) internal.AddFositeExamples(reg) tokens := Tokens(reg.OAuth2ProviderConfig(), 4) diff --git a/oauth2/oauth2_auth_code_bench_test.go b/oauth2/oauth2_auth_code_bench_test.go new file mode 100644 index 00000000000..92b54f8ff81 --- /dev/null +++ b/oauth2/oauth2_auth_code_bench_test.go @@ -0,0 +1,305 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oauth2_test + +import ( + "context" + "flag" + "net/http" + "os" + "runtime" + "runtime/pprof" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/sdk/resource" + "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + semconv "go.opentelemetry.io/otel/semconv/v1.12.0" + "golang.org/x/oauth2" + "gopkg.in/square/go-jose.v2" + + hydra "github.com/ory/hydra-client-go/v2" + hc "github.com/ory/hydra/v2/client" + "github.com/ory/hydra/v2/driver/config" + "github.com/ory/hydra/v2/internal" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/hydra/v2/jwk" + "github.com/ory/hydra/v2/x" + "github.com/ory/x/contextx" + "github.com/ory/x/pointerx" + "github.com/ory/x/stringsx" +) + +var ( + prof = flag.String("profile", "", "write a CPU profile to this filename") + conc = flag.Int("conc", 100, "dispatch this many requests concurrently") + tracing = flag.Bool("tracing", false, "send OpenTelemetry traces to localhost:4318") +) + +func BenchmarkAuthCode(b *testing.B) { + flag.Parse() + + ctx := context.Background() + + spans := tracetest.NewSpanRecorder() + opts := []trace.TracerProviderOption{ + trace.WithSpanProcessor(spans), + trace.WithResource(resource.NewWithAttributes( + semconv.SchemaURL, attribute.String(string(semconv.ServiceNameKey), "BenchmarkAuthCode"), + )), + } + if *tracing { + exporter, err := otlptracehttp.New(ctx, otlptracehttp.WithInsecure(), otlptracehttp.WithEndpoint("localhost:4318")) + require.NoError(b, err) + opts = append(opts, trace.WithSpanProcessor(trace.NewSimpleSpanProcessor(exporter))) + } + provider := trace.NewTracerProvider(opts...) + + tracer := provider.Tracer("BenchmarkAuthCode") + otel.SetTextMapPropagator(propagation.TraceContext{}) + otel.SetTracerProvider(provider) + + ctx, span := tracer.Start(ctx, "BenchmarkAuthCode") + defer span.End() + + ctx = context.WithValue(ctx, oauth2.HTTPClient, otelhttp.DefaultClient) + + dsn := stringsx.Coalesce(os.Getenv("DSN"), "postgres://postgres:secret@127.0.0.1:3445/postgres?sslmode=disable&max_conns=20&max_idle_conns=20") + // dsn := "mysql://root:secret@tcp(localhost:3444)/mysql?max_conns=16&max_idle_conns=16" + // dsn := "cockroach://root@localhost:3446/defaultdb?sslmode=disable&max_conns=16&max_idle_conns=16" + reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer) + reg.Config().MustSet(ctx, config.KeyLogLevel, "error") + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + reg.Config().MustSet(ctx, config.KeyRefreshTokenHookURL, "") + oauth2Keys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OAuth2JWTKeyName, "sig") + require.NoError(b, err) + oidcKeys, err := jwk.GenerateJWK(ctx, jose.ES256, x.OpenIDConnectKeyName, "sig") + require.NoError(b, err) + _, _ = oauth2Keys, oidcKeys + require.NoError(b, reg.KeyManager().UpdateKeySet(ctx, x.OAuth2JWTKeyName, oauth2Keys)) + require.NoError(b, reg.KeyManager().UpdateKeySet(ctx, x.OpenIDConnectKeyName, oidcKeys)) + _, adminTS := testhelpers.NewOAuth2Server(ctx, b, reg) + var ( + authURL = reg.Config().OAuth2AuthURL(ctx).String() + tokenURL = reg.Config().OAuth2TokenURL(ctx).String() + nonce = uuid.New() + ) + + newOAuth2Client := func(b *testing.B, cb string) (*hc.Client, *oauth2.Config) { + secret := uuid.New() + c := &hc.Client{ + Secret: secret, + RedirectURIs: []string{cb}, + ResponseTypes: []string{"id_token", "code", "token"}, + GrantTypes: []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"}, + Scope: "hydra offline openid", + Audience: []string{"https://api.ory.sh/"}, + } + require.NoError(b, reg.ClientManager().CreateClient(ctx, c)) + return c, &oauth2.Config{ + ClientID: c.GetID(), + ClientSecret: secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authURL, + TokenURL: tokenURL, + AuthStyle: oauth2.AuthStyleInHeader, + }, + Scopes: strings.Split(c.Scope, " "), + } + } + + cfg := hydra.NewConfiguration() + cfg.HTTPClient = otelhttp.DefaultClient + adminClient := hydra.NewAPIClient(cfg) + adminClient.GetConfig().Servers = hydra.ServerConfigurations{{URL: adminTS.URL}} + + getAuthorizeCode := func(ctx context.Context, b *testing.B, conf *oauth2.Config, c *http.Client, params ...oauth2.AuthCodeOption) (string, *http.Response) { + if c == nil { + c = testhelpers.NewEmptyJarClient(b) + } + + state := uuid.New() + + req, err := http.NewRequestWithContext(ctx, "GET", conf.AuthCodeURL(state, params...), nil) + require.NoError(b, err) + resp, err := c.Do(req) + require.NoError(b, err) + defer resp.Body.Close() + + q := resp.Request.URL.Query() + require.EqualValues(b, state, q.Get("state")) + return q.Get("code"), resp + } + + acceptLoginHandler := func(b *testing.B, c *hc.Client, checkRequestPayload func(request *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest) http.HandlerFunc { + return otelhttp.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + rr, _, err := adminClient.OAuth2Api.GetOAuth2LoginRequest(ctx).LoginChallenge(r.URL.Query().Get("login_challenge")).Execute() + require.NoError(b, err) + + assert.EqualValues(b, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(b, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(b, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(b, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(b, c.RedirectURIs, rr.Client.RedirectUris) + assert.EqualValues(b, r.URL.Query().Get("login_challenge"), rr.Challenge) + assert.EqualValues(b, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.Contains(b, rr.RequestUrl, authURL) + + acceptBody := hydra.AcceptOAuth2LoginRequest{ + Subject: uuid.New(), + Remember: pointerx.Ptr(!rr.Skip), + Acr: pointerx.Ptr("1"), + Amr: []string{"pwd"}, + Context: map[string]interface{}{"context": "bar"}, + } + if checkRequestPayload != nil { + if b := checkRequestPayload(rr); b != nil { + acceptBody = *b + } + } + + v, _, err := adminClient.OAuth2Api.AcceptOAuth2LoginRequest(ctx). + LoginChallenge(r.URL.Query().Get("login_challenge")). + AcceptOAuth2LoginRequest(acceptBody). + Execute() + require.NoError(b, err) + require.NotEmpty(b, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + }), "acceptLoginHandler").ServeHTTP + } + + acceptConsentHandler := func(b *testing.B, c *hc.Client, checkRequestPayload func(*hydra.OAuth2ConsentRequest)) http.HandlerFunc { + return otelhttp.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + rr, _, err := adminClient.OAuth2Api.GetOAuth2ConsentRequest(ctx).ConsentChallenge(r.URL.Query().Get("consent_challenge")).Execute() + require.NoError(b, err) + + assert.EqualValues(b, c.GetID(), pointerx.Deref(rr.Client.ClientId)) + assert.Empty(b, pointerx.Deref(rr.Client.ClientSecret)) + assert.EqualValues(b, c.GrantTypes, rr.Client.GrantTypes) + assert.EqualValues(b, c.LogoURI, pointerx.Deref(rr.Client.LogoUri)) + assert.EqualValues(b, c.RedirectURIs, rr.Client.RedirectUris) + // assert.EqualValues(b, subject, pointerx.Deref(rr.Subject)) + assert.EqualValues(b, []string{"hydra", "offline", "openid"}, rr.RequestedScope) + assert.EqualValues(b, r.URL.Query().Get("consent_challenge"), rr.Challenge) + assert.Contains(b, *rr.RequestUrl, authURL) + if checkRequestPayload != nil { + checkRequestPayload(rr) + } + + assert.Equal(b, map[string]interface{}{"context": "bar"}, rr.Context) + v, _, err := adminClient.OAuth2Api.AcceptOAuth2ConsentRequest(ctx). + ConsentChallenge(r.URL.Query().Get("consent_challenge")). + AcceptOAuth2ConsentRequest(hydra.AcceptOAuth2ConsentRequest{ + GrantScope: []string{"hydra", "offline", "openid"}, Remember: pointerx.Ptr(true), RememberFor: pointerx.Ptr[int64](0), + GrantAccessTokenAudience: rr.RequestedAccessTokenAudience, + Session: &hydra.AcceptOAuth2ConsentRequestSession{ + AccessToken: map[string]interface{}{"foo": "bar"}, + IdToken: map[string]interface{}{"bar": "baz"}, + }, + }). + Execute() + require.NoError(b, err) + require.NotEmpty(b, v.RedirectTo) + http.Redirect(w, r, v.RedirectTo, http.StatusFound) + }), "acceptConsentHandler").ServeHTTP + } + + run := func(b *testing.B, strategy string) func(*testing.B) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + c, conf := newOAuth2Client(b, testhelpers.NewCallbackURL(b, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(b, reg.Config(), + acceptLoginHandler(b, c, nil), + acceptConsentHandler(b, c, nil), + ) + + return func(b *testing.B) { + //pop.Debug = true + code, _ := getAuthorizeCode(ctx, b, conf, nil, oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(b, code) + + _, err := conf.Exchange(ctx, code) + //pop.Debug = false + require.NoError(b, err) + } + } + + b.ResetTimer() + + b.SetParallelism(*conc / runtime.GOMAXPROCS(0)) + + b.Run("strategy=jwt", func(b *testing.B) { + initialDBSpans := dbSpans(spans) + B := run(b, "jwt") + + stop := profile(b) + defer stop() + + var totalMS int64 = 0 + b.RunParallel(func(p *testing.PB) { + defer func(t0 time.Time) { + atomic.AddInt64(&totalMS, int64(time.Since(t0).Milliseconds())) + }(time.Now()) + for p.Next() { + B(b) + } + }) + + b.ReportMetric(0, "ns/op") + b.ReportMetric(float64(atomic.LoadInt64(&totalMS))/float64(b.N), "ms/op") + b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op") + b.ReportMetric(float64(b.N)/b.Elapsed().Seconds(), "ops/s") + }) + + b.Run("strategy=opaque", func(b *testing.B) { + initialDBSpans := dbSpans(spans) + B := run(b, "opaque") + + stop := profile(b) + defer stop() + + var totalMS int64 = 0 + b.RunParallel(func(p *testing.PB) { + defer func(t0 time.Time) { + atomic.AddInt64(&totalMS, int64(time.Since(t0).Milliseconds())) + }(time.Now()) + for p.Next() { + B(b) + } + }) + + b.ReportMetric(0, "ns/op") + b.ReportMetric(float64(atomic.LoadInt64(&totalMS))/float64(b.N), "ms/op") + b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op") + b.ReportMetric(float64(b.N)/b.Elapsed().Seconds(), "ops/s") + }) + +} + +func profile(t testing.TB) (stop func()) { + t.Helper() + if *prof == "" { + return func() {} // noop + } + f, err := os.Create(*prof) + require.NoError(t, err) + require.NoError(t, pprof.StartCPUProfile(f)) + return func() { + pprof.StopCPUProfile() + require.NoError(t, f.Close()) + t.Log("Wrote profile to", f.Name()) + } +} diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index b4811fead10..93349df73ea 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "github.com/ory/hydra/v2/flow" "github.com/ory/x/ioutilx" "github.com/ory/x/requirex" @@ -30,7 +31,6 @@ import ( "github.com/pborman/uuid" "github.com/tidwall/gjson" - "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/internal/testhelpers" "github.com/ory/x/contextx" @@ -50,7 +50,7 @@ import ( "github.com/ory/x/snapshotx" ) -func noopHandler(t *testing.T) httprouter.Handle { +func noopHandler(*testing.T) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { w.WriteHeader(http.StatusNotImplemented) } @@ -347,6 +347,108 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { }) }) + t.Run("suite=invalid query params", func(t *testing.T) { + c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + otherClient, _ := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, nil), + acceptConsentHandler(t, c, subject, nil), + ) + + withWrongClientAfterLogin := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("login_verifier") { + return nil + } + q.Set("client_id", otherClient.ID.String()) + req.URL.RawQuery = q.Encode() + return nil + }, + } + withWrongClientAfterConsent := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("consent_verifier") { + return nil + } + q.Set("client_id", otherClient.ID.String()) + req.URL.RawQuery = q.Encode() + return nil + }, + } + + withWrongScopeAfterLogin := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("login_verifier") { + return nil + } + q.Set("scope", "invalid scope") + req.URL.RawQuery = q.Encode() + return nil + }, + } + + withWrongScopeAfterConsent := &http.Client{ + Jar: testhelpers.NewEmptyCookieJar(t), + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.Path != "/oauth2/auth" { + return nil + } + q := req.URL.Query() + if !q.Has("consent_verifier") { + return nil + } + q.Set("scope", "invalid scope") + req.URL.RawQuery = q.Encode() + return nil + }, + } + + for _, tc := range []struct { + name string + client *http.Client + expectedResponse string + }{{ + name: "fails with wrong client ID after login", + client: withWrongClientAfterLogin, + expectedResponse: "access_denied", + }, { + name: "fails with wrong client ID after consent", + client: withWrongClientAfterConsent, + expectedResponse: "invalid_client", + }, { + name: "fails with wrong scopes after login", + client: withWrongScopeAfterLogin, + expectedResponse: "invalid_scope", + }, { + name: "fails with wrong scopes after consent", + client: withWrongScopeAfterConsent, + expectedResponse: "invalid_scope", + }} { + t.Run("case="+tc.name, func(t *testing.T) { + state := uuid.New() + resp, err := tc.client.Get(conf.AuthCodeURL(state)) + require.NoError(t, err) + assert.Equal(t, tc.expectedResponse, resp.Request.URL.Query().Get("error"), "%s", resp.Request.URL.RawQuery) + resp.Body.Close() + }) + } + }) + t.Run("case=checks if request fails when subject is empty", func(t *testing.T) { testhelpers.NewLoginConsentUI(t, reg.Config(), func(w http.ResponseWriter, r *http.Request) { _, res, err := adminClient.OAuth2Api.AcceptOAuth2LoginRequest(ctx). @@ -702,7 +804,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { } hookResp := hydraoauth2.TokenHookResponse{ - Session: consent.AcceptOAuth2ConsentRequestSession{ + Session: flow.AcceptOAuth2ConsentRequestSession{ AccessToken: claims, IDToken: claims, }, @@ -894,8 +996,8 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { conf.MustSet(ctx, config.KeyScopeStrategy, "DEPRECATED_HIERARCHICAL_SCOPE_STRATEGY") conf.MustSet(ctx, config.KeyAccessTokenStrategy, strat.d) reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(reg, x.OpenIDConnectKeyName) - internal.MustEnsureRegistryKeys(reg, x.OAuth2JWTKeyName) + internal.MustEnsureRegistryKeys(ctx, reg, x.OpenIDConnectKeyName) + internal.MustEnsureRegistryKeys(ctx, reg, x.OAuth2JWTKeyName) consentStrategy := &consentMock{} router := x.NewRouterPublic() @@ -1102,7 +1204,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { require.NotEmpty(t, code) - token, err := oauthConfig.Exchange(oauth2.NoContext, code) + token, err := oauthConfig.Exchange(context.TODO(), code) if tc.expectOAuthTokenError { require.Error(t, err) return @@ -1263,7 +1365,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { } hookResp := hydraoauth2.TokenHookResponse{ - Session: consent.AcceptOAuth2ConsentRequestSession{ + Session: flow.AcceptOAuth2ConsentRequestSession{ AccessToken: claims, IDToken: claims, }, @@ -1446,7 +1548,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { }) t.Run("duplicate code exchange fails", func(t *testing.T) { - token, err := oauthConfig.Exchange(oauth2.NoContext, code) + token, err := oauthConfig.Exchange(context.TODO(), code) require.Error(t, err) require.Nil(t, token) }) diff --git a/oauth2/oauth2_client_credentials_bench_test.go b/oauth2/oauth2_client_credentials_bench_test.go new file mode 100644 index 00000000000..310727f34cc --- /dev/null +++ b/oauth2/oauth2_client_credentials_bench_test.go @@ -0,0 +1,162 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oauth2_test + +import ( + "context" + "encoding/json" + "net/url" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + goauth2 "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" + + hc "github.com/ory/hydra/v2/client" + "github.com/ory/hydra/v2/driver/config" + "github.com/ory/hydra/v2/internal" + "github.com/ory/hydra/v2/internal/testhelpers" + "github.com/ory/hydra/v2/x" + "github.com/ory/x/contextx" + "github.com/ory/x/requirex" +) + +func BenchmarkClientCredentials(b *testing.B) { + ctx := context.Background() + + spans := tracetest.NewSpanRecorder() + tracer := trace.NewTracerProvider(trace.WithSpanProcessor(spans)).Tracer("") + + dsn := "postgres://postgres:secret@127.0.0.1:3445/postgres?sslmode=disable" + reg := internal.NewRegistrySQLFromURL(b, dsn, true, new(contextx.Default)).WithTracer(tracer) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + public, admin := testhelpers.NewOAuth2Server(ctx, b, reg) + + var newCustomClient = func(b *testing.B, c *hc.Client) (*hc.Client, clientcredentials.Config) { + unhashedSecret := c.Secret + require.NoError(b, reg.ClientManager().CreateClient(ctx, c)) + return c, clientcredentials.Config{ + ClientID: c.GetID(), + ClientSecret: unhashedSecret, + TokenURL: reg.Config().OAuth2TokenURL(ctx).String(), + Scopes: strings.Split(c.Scope, " "), + EndpointParams: url.Values{"audience": c.Audience}, + } + } + + var newClient = func(b *testing.B) (*hc.Client, clientcredentials.Config) { + cc, config := newCustomClient(b, &hc.Client{ + Secret: uuid.New().String(), + RedirectURIs: []string{public.URL + "/callback"}, + ResponseTypes: []string{"token"}, + GrantTypes: []string{"client_credentials"}, + Scope: "foobar", + Audience: []string{"https://api.ory.sh/"}, + }) + return cc, config + } + + var getToken = func(t *testing.B, conf clientcredentials.Config) (*goauth2.Token, error) { + conf.AuthStyle = goauth2.AuthStyleInHeader + return conf.Token(context.Background()) + } + + var encodeOr = func(b *testing.B, val interface{}, or string) string { + out, err := json.Marshal(val) + require.NoError(b, err) + if string(out) == "null" { + return or + } + + return string(out) + } + + var inspectToken = func(b *testing.B, token *goauth2.Token, cl *hc.Client, conf clientcredentials.Config, strategy string, expectedExp time.Time, checkExtraClaims bool) { + introspection := testhelpers.IntrospectToken(b, &goauth2.Config{ClientID: cl.GetID(), ClientSecret: conf.ClientSecret}, token.AccessToken, admin) + + check := func(res gjson.Result) { + assert.EqualValues(b, cl.GetID(), res.Get("client_id").String(), "%s", res.Raw) + assert.EqualValues(b, cl.GetID(), res.Get("sub").String(), "%s", res.Raw) + assert.EqualValues(b, reg.Config().IssuerURL(ctx).String(), res.Get("iss").String(), "%s", res.Raw) + + assert.EqualValues(b, res.Get("nbf").Int(), res.Get("iat").Int(), "%s", res.Raw) + requirex.EqualTime(b, expectedExp, time.Unix(res.Get("exp").Int(), 0), time.Second) + + assert.EqualValues(b, encodeOr(b, conf.EndpointParams["audience"], "[]"), res.Get("aud").Raw, "%s", res.Raw) + + if checkExtraClaims { + require.True(b, res.Get("ext.hooked").Bool()) + } + } + + check(introspection) + assert.True(b, introspection.Get("active").Bool()) + assert.EqualValues(b, "access_token", introspection.Get("token_use").String()) + assert.EqualValues(b, "Bearer", introspection.Get("token_type").String()) + assert.EqualValues(b, strings.Join(conf.Scopes, " "), introspection.Get("scope").String(), "%s", introspection.Raw) + + if strategy != "jwt" { + return + } + + body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1]) + require.NoError(b, err) + + jwtClaims := gjson.ParseBytes(body) + assert.NotEmpty(b, jwtClaims.Get("jti").String()) + assert.EqualValues(b, encodeOr(b, conf.Scopes, "[]"), jwtClaims.Get("scp").Raw, "%s", introspection.Raw) + check(jwtClaims) + } + + var getAndInspectToken = func(b *testing.B, cl *hc.Client, conf clientcredentials.Config, strategy string, expectedExp time.Time, checkExtraClaims bool) { + token, err := getToken(b, conf) + require.NoError(b, err) + inspectToken(b, token, cl, conf, strategy, expectedExp, checkExtraClaims) + } + + run := func(strategy string) func(b *testing.B) { + return func(t *testing.B) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + + cl, conf := newClient(b) + getAndInspectToken(b, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), false) + } + } + + b.Run("strategy=jwt", func(b *testing.B) { + initialDBSpans := dbSpans(spans) + for i := 0; i < b.N; i++ { + run("jwt")(b) + } + b.ReportMetric(0, "ns/op") + b.ReportMetric(float64(b.Elapsed().Milliseconds())/float64(b.N), "ms/op") + b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op") + }) + + b.Run("strategy=opaque", func(b *testing.B) { + initialDBSpans := dbSpans(spans) + for i := 0; i < b.N; i++ { + run("opaque")(b) + } + b.ReportMetric(0, "ns/op") + b.ReportMetric(float64(b.Elapsed().Milliseconds())/float64(b.N), "ms/op") + b.ReportMetric((float64(dbSpans(spans)-initialDBSpans))/float64(b.N), "queries/op") + }) +} + +func dbSpans(spans *tracetest.SpanRecorder) (count int) { + for _, s := range spans.Started() { + if strings.HasPrefix(s.Name(), "sql-") { + count++ + } + } + return +} diff --git a/oauth2/oauth2_client_credentials_test.go b/oauth2/oauth2_client_credentials_test.go index 1703bda9c38..059036a2700 100644 --- a/oauth2/oauth2_client_credentials_test.go +++ b/oauth2/oauth2_client_credentials_test.go @@ -22,7 +22,7 @@ import ( goauth2 "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" - "github.com/ory/hydra/v2/consent" + "github.com/ory/hydra/v2/flow" "github.com/ory/hydra/v2/internal/testhelpers" hydraoauth2 "github.com/ory/hydra/v2/oauth2" "github.com/ory/x/contextx" @@ -276,7 +276,7 @@ func TestClientCredentials(t *testing.T) { } hookResp := hydraoauth2.TokenHookResponse{ - Session: consent.AcceptOAuth2ConsentRequestSession{ + Session: flow.AcceptOAuth2ConsentRequestSession{ AccessToken: claims, IDToken: claims, }, diff --git a/oauth2/oauth2_helper_test.go b/oauth2/oauth2_helper_test.go index ea679c24189..52a30e5975e 100644 --- a/oauth2/oauth2_helper_test.go +++ b/oauth2/oauth2_helper_test.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" "github.com/ory/fosite" + "github.com/ory/hydra/v2/flow" "github.com/ory/x/sqlxx" "github.com/ory/hydra/v2/client" @@ -25,27 +26,27 @@ type consentMock struct { requestTime time.Time } -func (c *consentMock) HandleOAuth2AuthorizationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*consent.AcceptOAuth2ConsentRequest, error) { +func (c *consentMock) HandleOAuth2AuthorizationRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*flow.AcceptOAuth2ConsentRequest, *flow.Flow, error) { if c.deny { - return nil, fosite.ErrRequestForbidden + return nil, nil, fosite.ErrRequestForbidden } - return &consent.AcceptOAuth2ConsentRequest{ - ConsentRequest: &consent.OAuth2ConsentRequest{ + return &flow.AcceptOAuth2ConsentRequest{ + ConsentRequest: &flow.OAuth2ConsentRequest{ Subject: "foo", ACR: "1", }, AuthenticatedAt: sqlxx.NullTime(c.authTime), GrantedScope: []string{"offline", "openid", "hydra.*"}, - Session: &consent.AcceptOAuth2ConsentRequestSession{ + Session: &flow.AcceptOAuth2ConsentRequestSession{ AccessToken: map[string]interface{}{}, IDToken: map[string]interface{}{}, }, RequestedAt: c.requestTime, - }, nil + }, nil, nil } -func (c *consentMock) HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*consent.LogoutResult, error) { +func (c *consentMock) HandleOpenIDConnectLogout(ctx context.Context, w http.ResponseWriter, r *http.Request) (*flow.LogoutResult, error) { panic("not implemented") } diff --git a/oauth2/oauth2_jwt_bearer_test.go b/oauth2/oauth2_jwt_bearer_test.go index 1aa1f8179ff..b975af21c72 100644 --- a/oauth2/oauth2_jwt_bearer_test.go +++ b/oauth2/oauth2_jwt_bearer_test.go @@ -20,7 +20,7 @@ import ( "gopkg.in/square/go-jose.v2" "github.com/ory/fosite/token/jwt" - "github.com/ory/hydra/v2/consent" + "github.com/ory/hydra/v2/flow" "github.com/ory/hydra/v2/jwk" hydraoauth2 "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/oauth2/trust" @@ -342,7 +342,7 @@ func TestJWTBearer(t *testing.T) { } hookResp := hydraoauth2.TokenHookResponse{ - Session: consent.AcceptOAuth2ConsentRequestSession{ + Session: flow.AcceptOAuth2ConsentRequestSession{ AccessToken: claims, IDToken: claims, }, @@ -417,7 +417,7 @@ func TestJWTBearer(t *testing.T) { } hookResp := hydraoauth2.TokenHookResponse{ - Session: consent.AcceptOAuth2ConsentRequestSession{ + Session: flow.AcceptOAuth2ConsentRequestSession{ AccessToken: claims, IDToken: claims, }, diff --git a/oauth2/oauth2_refresh_token_test.go b/oauth2/oauth2_refresh_token_test.go index 94b1ae26f8d..208fb20f78b 100644 --- a/oauth2/oauth2_refresh_token_test.go +++ b/oauth2/oauth2_refresh_token_test.go @@ -49,8 +49,8 @@ func TestCreateRefreshTokenSessionStress(t *testing.T) { // first read. workers := 10 - token := "234c678fed33c1d2025537ae464a1ebf7d23fc4a" - tokenSignature := "4c7c7e8b3a77ad0c3ec846a21653c48b45dbfa31" + token := "234c678fed33c1d2025537ae464a1ebf7d23fc4a" //nolint:gosec + tokenSignature := "4c7c7e8b3a77ad0c3ec846a21653c48b45dbfa31" //nolint:gosec testClient := hc.Client{ ID: uuid.Must(uuid.NewV4()), Secret: "secret", @@ -112,7 +112,7 @@ func TestCreateRefreshTokenSessionStress(t *testing.T) { defer wg.Done() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) + time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) //nolint:gosec // all workers will block here until the for loop above has launched all the worker go-routines // this is to ensure we fire all the workers off at the same <-barrier diff --git a/oauth2/registry.go b/oauth2/registry.go index 38ac335bb11..52f9f7bb9bf 100644 --- a/oauth2/registry.go +++ b/oauth2/registry.go @@ -6,6 +6,7 @@ package oauth2 import ( "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" + "github.com/ory/hydra/v2/aead" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/jwk" @@ -21,6 +22,7 @@ type InternalRegistry interface { x.RegistryLogger consent.Registry Registry + FlowCipher() *aead.XChaCha20Poly1305 } type Registry interface { diff --git a/oauth2/revocator_test.go b/oauth2/revocator_test.go index a2eb5f3d4b3..71b85e63ea2 100644 --- a/oauth2/revocator_test.go +++ b/oauth2/revocator_test.go @@ -63,7 +63,7 @@ func TestRevoke(t *testing.T) { conf := internal.NewConfigurationWithDefaults() reg := internal.NewRegistryMemory(t, conf, &contextx.Default{}) - internal.MustEnsureRegistryKeys(reg, x.OpenIDConnectKeyName) + internal.MustEnsureRegistryKeys(context.Background(), reg, x.OpenIDConnectKeyName) internal.AddFositeExamples(reg) tokens := Tokens(reg.OAuth2ProviderConfig(), 4) diff --git a/oauth2/session.go b/oauth2/session.go index e543a1e123f..3032925fe20 100644 --- a/oauth2/session.go +++ b/oauth2/session.go @@ -16,6 +16,7 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" + "github.com/ory/hydra/v2/flow" "github.com/ory/x/stringslice" ) @@ -29,6 +30,8 @@ type Session struct { ConsentChallenge string `json:"consent_challenge"` ExcludeNotBeforeClaim bool `json:"exclude_not_before_claim"` AllowedTopLevelClaims []string `json:"allowed_top_level_claims"` + + Flow *flow.Flow `json:"-"` } func NewSession(subject string) *Session { diff --git a/oauth2/token_hook.go b/oauth2/token_hook.go index f7ca4416a71..fc33fe3813d 100644 --- a/oauth2/token_hook.go +++ b/oauth2/token_hook.go @@ -12,10 +12,10 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/hydra/v2/flow" "github.com/ory/hydra/v2/x" "github.com/ory/fosite" - "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/driver/config" "github.com/ory/x/errorsx" ) @@ -54,7 +54,7 @@ type TokenHookRequest struct { // swagger:ignore type TokenHookResponse struct { // Session is the session data returned by the hook. - Session consent.AcceptOAuth2ConsentRequestSession `json:"session"` + Session flow.AcceptOAuth2ConsentRequestSession `json:"session"` } func executeHookAndUpdateSession(ctx context.Context, reg x.HTTPClientProvider, hookURL *url.URL, reqBodyBytes []byte, session *Session) error { diff --git a/oauth2/trust/doc.go b/oauth2/trust/doc.go index c30e9521ac0..16de4977dd3 100644 --- a/oauth2/trust/doc.go +++ b/oauth2/trust/doc.go @@ -14,11 +14,15 @@ import ( // OAuth2 JWT Bearer Grant Type Issuer Trust Relationships // // swagger:model trustedOAuth2JwtGrantIssuers +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type trustedOAuth2JwtGrantIssuers []trustedOAuth2JwtGrantIssuer // OAuth2 JWT Bearer Grant Type Issuer Trust Relationship // // swagger:model trustedOAuth2JwtGrantIssuer +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type trustedOAuth2JwtGrantIssuer struct { // example: 9edc811f-4e28-453c-9b46-4de65f00217f ID string `json:"id"` @@ -51,6 +55,8 @@ type trustedOAuth2JwtGrantIssuer struct { // OAuth2 JWT Bearer Grant Type Issuer Trusted JSON Web Key // // swagger:model trustedOAuth2JwtGrantJsonWebKey +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type trustedOAuth2JwtGrantJsonWebKey struct { // The "set" is basically a name for a group(set) of keys. Will be the same as "issuer" in grant. // example: https://jwt-idp.example.com diff --git a/oauth2/trust/handler.go b/oauth2/trust/handler.go index 7bc622e95c0..453ab376975 100644 --- a/oauth2/trust/handler.go +++ b/oauth2/trust/handler.go @@ -43,6 +43,8 @@ func (h *Handler) SetRoutes(admin *httprouterx.RouterAdmin) { // Trust OAuth2 JWT Bearer Grant Type Issuer Request Body // // swagger:model trustOAuth2JwtGrantIssuer +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type trustOAuth2JwtGrantIssuerBody struct { // The "issuer" identifies the principal that issued the JWT assertion (same as "iss" claim in JWT). // @@ -78,6 +80,8 @@ type trustOAuth2JwtGrantIssuerBody struct { // Trust OAuth2 JWT Bearer Grant Type Issuer Request // // swagger:parameters trustOAuth2JwtGrantIssuer +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type trustOAuth2JwtGrantIssuer struct { // in: body Body trustOAuth2JwtGrantIssuerBody @@ -140,6 +144,8 @@ func (h *Handler) trustOAuth2JwtGrantIssuer(w http.ResponseWriter, r *http.Reque // Get Trusted OAuth2 JWT Bearer Grant Type Issuer Request // // swagger:parameters getTrustedOAuth2JwtGrantIssuer +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type getTrustedOAuth2JwtGrantIssuer struct { // The id of the desired grant // @@ -181,6 +187,8 @@ func (h *Handler) getTrustedOAuth2JwtGrantIssuer(w http.ResponseWriter, r *http. // Delete Trusted OAuth2 JWT Bearer Grant Type Issuer Request // // swagger:parameters deleteTrustedOAuth2JwtGrantIssuer +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type deleteTrustedOAuth2JwtGrantIssuer struct { // The id of the desired grant // in: path @@ -223,6 +231,8 @@ func (h *Handler) deleteTrustedOAuth2JwtGrantIssuer(w http.ResponseWriter, r *ht // List Trusted OAuth2 JWT Bearer Grant Type Issuers Request // // swagger:parameters listTrustedOAuth2JwtGrantIssuers +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type listTrustedOAuth2JwtGrantIssuers struct { // If optional "issuer" is supplied, only jwt-bearer grants with this issuer will be returned. // diff --git a/persistence/definitions.go b/persistence/definitions.go index f3a9230ba0a..27a8b0fa037 100644 --- a/persistence/definitions.go +++ b/persistence/definitions.go @@ -6,6 +6,10 @@ package persistence import ( "context" + "github.com/gofrs/uuid" + + "github.com/ory/x/networkx" + "github.com/gobuffalo/pop/v6" "github.com/ory/hydra/v2/client" @@ -30,8 +34,14 @@ type ( PrepareMigration(context.Context) error Connection(context.Context) *pop.Connection Ping() error + Networker } Provider interface { Persister() Persister } + + Networker interface { + NetworkID(ctx context.Context) uuid.UUID + DetermineNetwork(ctx context.Context) (*networkx.Network, error) + } ) diff --git a/persistence/sql/migratest/assertion_helpers.go b/persistence/sql/migratest/assertion_helpers.go index 242f2460891..36f512a2cca 100644 --- a/persistence/sql/migratest/assertion_helpers.go +++ b/persistence/sql/migratest/assertion_helpers.go @@ -8,7 +8,7 @@ import ( "time" "github.com/gofrs/uuid" - "github.com/instana/testify/require" + "github.com/stretchr/testify/require" "github.com/ory/hydra/v2/flow" testhelpersuuid "github.com/ory/hydra/v2/internal/testhelpers/uuid" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0001.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0001.json index 0d4b588349d..d89f26c7d42 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0001.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0001.json @@ -32,7 +32,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0001", @@ -52,7 +53,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0001": "0001" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0002.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0002.json index da822717327..369ba83ba25 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0002.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0002.json @@ -32,7 +32,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0002", @@ -52,7 +53,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0002": "0002" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0003.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0003.json index 0c8587a0383..66718c0ba27 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0003.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0003.json @@ -32,7 +32,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0003", @@ -52,7 +53,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0003": "0003" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0004.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0004.json index 08fbbf88023..e707616aa87 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0004.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0004.json @@ -34,7 +34,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0004", @@ -56,7 +57,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0004": "0004" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0005.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0005.json index 1bebff1778d..fcc4760db32 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0005.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0005.json @@ -34,7 +34,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0005", @@ -56,7 +57,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0005": "0005" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0006.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0006.json index af35899c259..825ca5b9b00 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0006.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0006.json @@ -34,7 +34,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0006", @@ -56,7 +57,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0006": "0006" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0007.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0007.json index 509653dbf89..1d20de4190f 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0007.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0007.json @@ -34,7 +34,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0007", @@ -56,7 +57,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0007": "0007" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0008.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0008.json index 7da6b5b2c10..3ed3dad5245 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0008.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0008.json @@ -36,7 +36,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0008", @@ -58,7 +59,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0008": "0008" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0009.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0009.json index f59ac706aaa..61f8bbabf0c 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0009.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0009.json @@ -36,7 +36,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0009", @@ -58,7 +59,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0009": "0009" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0010.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0010.json index 99135f5f763..a886dd0aefe 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0010.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0010.json @@ -36,7 +36,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0010", @@ -58,7 +59,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0010": "0010" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0011.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0011.json index ab8c93003b7..dda3212a8d7 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0011.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0011.json @@ -36,7 +36,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0011", @@ -58,7 +59,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0011": "0011" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0012.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0012.json index 53c58242a1a..d6491837a10 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0012.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0012.json @@ -36,7 +36,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0012", @@ -58,7 +59,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0012": "0012" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0013.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0013.json index b39ef9aca29..89ca9f7daf4 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0013.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0013.json @@ -36,7 +36,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0013", @@ -58,7 +59,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0013": "0013" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0014.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0014.json index fff06cbd01d..d020259b581 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0014.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0014.json @@ -36,7 +36,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0014", @@ -58,7 +59,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0014": "0014" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0015.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0015.json index 4a013571bed..78ee82f16d5 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0015.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0015.json @@ -41,7 +41,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0015", @@ -65,7 +66,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0015": "0015" diff --git a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0016.json b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0016.json index 803bab67ce6..e3bddee39a1 100644 --- a/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0016.json +++ b/persistence/sql/migratest/fixtures/hydra_oauth2_flow/challenge-0016.json @@ -41,7 +41,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "LoginAuthenticatedAt": null, "ConsentChallengeID": "challenge-0016", @@ -65,7 +66,8 @@ "error_description": "", "error_hint": "", "status_code": 0, - "error_debug": "" + "error_debug": "", + "valid": false }, "SessionIDToken": { "session_id_token-0016": "0016" diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 460ffb910fd..7c4db0c81d2 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -18,8 +18,8 @@ import ( "github.com/bradleyjkemp/cupaloy/v2" "github.com/fatih/structs" "github.com/gofrs/uuid" - "github.com/instana/testify/assert" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/gobuffalo/pop/v6" @@ -143,7 +143,7 @@ func TestMigrations(t *testing.T) { }) t.Run("case=hydra_oauth2_authentication_session", func(t *testing.T) { - ss := []consent.LoginSession{} + ss := []flow.LoginSession{} c.All(&ss) require.Equal(t, 16, len(ss)) @@ -168,7 +168,7 @@ func TestMigrations(t *testing.T) { }) t.Run("case=hydra_oauth2_logout_request", func(t *testing.T) { - lrs := []consent.LogoutRequest{} + lrs := []flow.LogoutRequest{} c.All(&lrs) require.Equal(t, 6, len(lrs)) diff --git a/persistence/sql/migrations/20230606112801000001_remove_flow_indices.down.sql b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.down.sql new file mode 100644 index 00000000000..a391920ba8f --- /dev/null +++ b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.down.sql @@ -0,0 +1,12 @@ +CREATE UNIQUE INDEX hydra_oauth2_flow_login_verifier_idx ON hydra_oauth2_flow (login_verifier); +CREATE UNIQUE INDEX hydra_oauth2_flow_consent_verifier_idx ON hydra_oauth2_flow (consent_verifier); + +CREATE INDEX hydra_oauth2_flow_multi_query_idx + ON hydra_oauth2_flow + ( + consent_error ASC, state ASC, subject ASC, + client_id ASC, consent_skip ASC, consent_remember + ASC, nid ASC + ); + +DROP INDEX hydra_oauth2_flow_previous_consents_idx; diff --git a/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.down.sql b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.down.sql new file mode 100644 index 00000000000..16d4e470dae --- /dev/null +++ b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.down.sql @@ -0,0 +1,12 @@ +CREATE UNIQUE INDEX hydra_oauth2_flow_login_verifier_idx ON hydra_oauth2_flow (login_verifier); +CREATE UNIQUE INDEX hydra_oauth2_flow_consent_verifier_idx ON hydra_oauth2_flow (consent_verifier); + +CREATE INDEX hydra_oauth2_flow_multi_query_idx + ON hydra_oauth2_flow + ( + consent_error(2) ASC, state ASC, subject ASC, + client_id ASC, consent_skip ASC, consent_remember + ASC, nid ASC + ); + +DROP INDEX hydra_oauth2_flow_previous_consents_idx ON hydra_oauth2_flow; diff --git a/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.up.sql b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.up.sql new file mode 100644 index 00000000000..d7f86b61f94 --- /dev/null +++ b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.mysql.up.sql @@ -0,0 +1,6 @@ +DROP INDEX hydra_oauth2_flow_login_verifier_idx ON hydra_oauth2_flow; +DROP INDEX hydra_oauth2_flow_consent_verifier_idx ON hydra_oauth2_flow; +DROP INDEX hydra_oauth2_flow_multi_query_idx ON hydra_oauth2_flow; + +CREATE INDEX hydra_oauth2_flow_previous_consents_idx + ON hydra_oauth2_flow (subject, client_id, nid, consent_skip, consent_error(2), consent_remember); diff --git a/persistence/sql/migrations/20230606112801000001_remove_flow_indices.up.sql b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.up.sql new file mode 100644 index 00000000000..d522d3482f5 --- /dev/null +++ b/persistence/sql/migrations/20230606112801000001_remove_flow_indices.up.sql @@ -0,0 +1,6 @@ +DROP INDEX hydra_oauth2_flow_login_verifier_idx; +DROP INDEX hydra_oauth2_flow_consent_verifier_idx; +DROP INDEX hydra_oauth2_flow_multi_query_idx; + +CREATE INDEX IF NOT EXISTS hydra_oauth2_flow_previous_consents_idx + ON hydra_oauth2_flow (subject, client_id, nid, consent_skip, consent_error, consent_remember); diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index c9f9678ab9a..82971db2a2c 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -16,14 +16,15 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/storage" + "github.com/ory/hydra/v2/aead" "github.com/ory/hydra/v2/driver/config" - "github.com/ory/hydra/v2/jwk" "github.com/ory/hydra/v2/persistence" "github.com/ory/hydra/v2/x" "github.com/ory/x/contextx" "github.com/ory/x/errorsx" "github.com/ory/x/logrusx" "github.com/ory/x/networkx" + "github.com/ory/x/otelx" "github.com/ory/x/popx" ) @@ -48,16 +49,17 @@ type ( } Dependencies interface { ClientHasher() fosite.Hasher - KeyCipher() *jwk.AEAD + KeyCipher() *aead.AESGCM + FlowCipher() *aead.XChaCha20Poly1305 contextx.Provider x.RegistryLogger x.TracingProvider } ) -func (p *Persister) BeginTX(ctx context.Context) (context.Context, error) { +func (p *Persister) BeginTX(ctx context.Context) (_ context.Context, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.BeginTX") - defer span.End() + defer otelx.End(span, &err) fallback := &pop.Connection{TX: &pop.Tx{}} if popx.GetConnection(ctx, fallback).TX != fallback.TX { @@ -77,9 +79,9 @@ func (p *Persister) BeginTX(ctx context.Context) (context.Context, error) { return popx.WithTransaction(ctx, c), err } -func (p *Persister) Commit(ctx context.Context) error { +func (p *Persister) Commit(ctx context.Context) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Commit") - defer span.End() + defer otelx.End(span, &err) fallback := &pop.Connection{TX: &pop.Tx{}} tx := popx.GetConnection(ctx, fallback) @@ -90,9 +92,9 @@ func (p *Persister) Commit(ctx context.Context) error { return errorsx.WithStack(tx.TX.Commit()) } -func (p *Persister) Rollback(ctx context.Context) error { +func (p *Persister) Rollback(ctx context.Context) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Rollback") - defer span.End() + defer otelx.End(span, &err) fallback := &pop.Connection{TX: &pop.Tx{}} tx := popx.GetConnection(ctx, fallback) diff --git a/persistence/sql/persister_client.go b/persistence/sql/persister_client.go index 8e45870dab5..482f7126a88 100644 --- a/persistence/sql/persister_client.go +++ b/persistence/sql/persister_client.go @@ -6,20 +6,20 @@ package sql import ( "context" - "github.com/gofrs/uuid" - "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" "github.com/ory/x/errorsx" + "github.com/ory/x/otelx" "github.com/ory/fosite" "github.com/ory/hydra/v2/client" "github.com/ory/x/sqlcon" ) -func (p *Persister) GetConcreteClient(ctx context.Context, id string) (*client.Client, error) { +func (p *Persister) GetConcreteClient(ctx context.Context, id string) (c *client.Client, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteClient") - defer span.End() + defer otelx.End(span, &err) var cl client.Client if err := p.QueryWithNetwork(ctx).Where("id = ?", id).First(&cl); err != nil { @@ -32,9 +32,9 @@ func (p *Persister) GetClient(ctx context.Context, id string) (fosite.Client, er return p.GetConcreteClient(ctx, id) } -func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) error { +func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UpdateClient") - defer span.End() + defer otelx.End(span, &err) return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { o, err := p.GetConcreteClient(ctx, cl.GetID()) @@ -71,9 +71,9 @@ func (p *Persister) UpdateClient(ctx context.Context, cl *client.Client) error { }) } -func (p *Persister) Authenticate(ctx context.Context, id string, secret []byte) (*client.Client, error) { +func (p *Persister) Authenticate(ctx context.Context, id string, secret []byte) (_ *client.Client, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.Authenticate") - defer span.End() + defer otelx.End(span, &err) c, err := p.GetConcreteClient(ctx, id) if err != nil { @@ -87,9 +87,9 @@ func (p *Persister) Authenticate(ctx context.Context, id string, secret []byte) return c, nil } -func (p *Persister) CreateClient(ctx context.Context, c *client.Client) error { +func (p *Persister) CreateClient(ctx context.Context, c *client.Client) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateClient") - defer span.End() + defer otelx.End(span, &err) h, err := p.r.ClientHasher().Hash(ctx, []byte(c.Secret)) if err != nil { @@ -106,11 +106,11 @@ func (p *Persister) CreateClient(ctx context.Context, c *client.Client) error { return sqlcon.HandleError(p.CreateWithNetwork(ctx, c)) } -func (p *Persister) DeleteClient(ctx context.Context, id string) error { +func (p *Persister) DeleteClient(ctx context.Context, id string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteClient") - defer span.End() + defer otelx.End(span, &err) - _, err := p.GetConcreteClient(ctx, id) + _, err = p.GetConcreteClient(ctx, id) if err != nil { return err } @@ -118,9 +118,9 @@ func (p *Persister) DeleteClient(ctx context.Context, id string) error { return sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("id = ?", id).Delete(&client.Client{})) } -func (p *Persister) GetClients(ctx context.Context, filters client.Filter) ([]client.Client, error) { +func (p *Persister) GetClients(ctx context.Context, filters client.Filter) (_ []client.Client, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetClients") - defer span.End() + defer otelx.End(span, &err) cs := make([]client.Client, 0) @@ -141,10 +141,10 @@ func (p *Persister) GetClients(ctx context.Context, filters client.Filter) ([]cl return cs, nil } -func (p *Persister) CountClients(ctx context.Context) (int, error) { +func (p *Persister) CountClients(ctx context.Context) (n int, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CountClients") - defer span.End() + defer otelx.End(span, &err) - n, err := p.QueryWithNetwork(ctx).Count(&client.Client{}) + n, err = p.QueryWithNetwork(ctx).Count(&client.Client{}) return n, sqlcon.HandleError(err) } diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go index 8f1fca3d490..bd401d56423 100644 --- a/persistence/sql/persister_consent.go +++ b/persistence/sql/persister_consent.go @@ -11,7 +11,9 @@ import ( "time" "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/ory/hydra/v2/oauth2/flowctx" "github.com/ory/x/sqlxx" "github.com/ory/x/errorsx" @@ -95,7 +97,7 @@ func (p *Persister) RevokeSubjectLoginSession(ctx context.Context, subject strin ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.RevokeSubjectLoginSession") defer span.End() - err := p.QueryWithNetwork(ctx).Where("subject = ?", subject).Delete(&consent.LoginSession{}) + err := p.QueryWithNetwork(ctx).Where("subject = ?", subject).Delete(&flow.LoginSession{}) if err != nil { return sqlcon.HandleError(err) } @@ -158,34 +160,22 @@ func (p *Persister) GetForcedObfuscatedLoginSession(ctx context.Context, client, // CreateConsentRequest configures fields that are introduced or changed in the // consent request. It doesn't touch fields that would be copied from the login // request. -func (p *Persister) CreateConsentRequest(ctx context.Context, req *consent.OAuth2ConsentRequest) error { +func (p *Persister) CreateConsentRequest(ctx context.Context, f *flow.Flow, req *flow.OAuth2ConsentRequest) error { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateConsentRequest") defer span.End() - c, err := p.Connection(ctx).RawQuery(` -UPDATE hydra_oauth2_flow -SET - state = ?, - consent_challenge_id = ?, - consent_skip = ?, - consent_verifier = ?, - consent_csrf = ? -WHERE login_challenge = ? AND nid = ?; -`, - flow.FlowStateConsentInitialized, - sqlxx.NullString(req.ID), - req.Skip, - req.Verifier, - req.CSRF, - req.LoginChallenge.String(), - p.NetworkID(ctx), - ).ExecWithCount() - if err != nil { - return sqlcon.HandleError(err) + if f == nil { + return errorsx.WithStack(x.ErrNotFound.WithDebug("Flow is nil")) } - if c != 1 { + if f.ID != req.LoginChallenge.String() || f.NID != p.NetworkID(ctx) { return errorsx.WithStack(x.ErrNotFound) } + f.State = flow.FlowStateConsentInitialized + f.ConsentChallengeID = sqlxx.NullString(req.ID) + f.ConsentSkip = req.Skip + f.ConsentVerifier = sqlxx.NullString(req.Verifier) + f.ConsentCSRF = sqlxx.NullString(req.CSRF) + return nil } @@ -193,16 +183,22 @@ func (p *Persister) GetFlowByConsentChallenge(ctx context.Context, challenge str ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetFlowByConsentChallenge") defer span.End() - f := &flow.Flow{} - - if err := sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("consent_challenge_id = ?", challenge).First(f)); err != nil { - return nil, err + // challenge contains the flow. + f, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), challenge, flowctx.AsConsentChallenge) + if err != nil { + return nil, errorsx.WithStack(x.ErrNotFound) + } + if f.NID != p.NetworkID(ctx) { + return nil, errorsx.WithStack(x.ErrNotFound) + } + if f.RequestedAt.Add(p.config.ConsentRequestMaxAge(ctx)).Before(time.Now()) { + return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The consent request has expired, please try again.")) } return f, nil } -func (p *Persister) GetConsentRequest(ctx context.Context, challenge string) (*consent.OAuth2ConsentRequest, error) { +func (p *Persister) GetConsentRequest(ctx context.Context, challenge string) (*flow.OAuth2ConsentRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConsentRequest") defer span.End() @@ -214,15 +210,24 @@ func (p *Persister) GetConsentRequest(ctx context.Context, challenge string) (*c return nil, err } + // We need to overwrite the ID with the encoded flow (challenge) so that the client is not confused. + f.ConsentChallengeID = sqlxx.NullString(challenge) + return f.GetConsentRequest(), nil } -func (p *Persister) CreateLoginRequest(ctx context.Context, req *consent.LoginRequest) error { +func (p *Persister) CreateLoginRequest(ctx context.Context, req *flow.LoginRequest) (*flow.Flow, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginRequest") defer span.End() f := flow.NewFlow(req) - return sqlcon.HandleError(p.CreateWithNetwork(ctx, f)) + nid := p.NetworkID(ctx) + if nid == uuid.Nil { + return nil, errorsx.WithStack(x.ErrNotFound) + } + f.NID = nid + + return f, nil } func (p *Persister) GetFlow(ctx context.Context, loginChallenge string) (*flow.Flow, error) { @@ -230,130 +235,166 @@ func (p *Persister) GetFlow(ctx context.Context, loginChallenge string) (*flow.F defer span.End() var f flow.Flow - return &f, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - if err := p.QueryWithNetwork(ctx).Where("login_challenge = ?", loginChallenge).First(&f); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return errorsx.WithStack(x.ErrNotFound) - } - return sqlcon.HandleError(err) + if err := p.QueryWithNetwork(ctx).Where("login_challenge = ?", loginChallenge).First(&f); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(x.ErrNotFound) } - - return nil - }) + return nil, sqlcon.HandleError(err) + } + return &f, nil } -func (p *Persister) GetLoginRequest(ctx context.Context, loginChallenge string) (*consent.LoginRequest, error) { +func (p *Persister) GetLoginRequest(ctx context.Context, loginChallenge string) (*flow.LoginRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetLoginRequest") defer span.End() - var lr *consent.LoginRequest - return lr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - var f flow.Flow - if err := p.QueryWithNetwork(ctx).Where("login_challenge = ?", loginChallenge).First(&f); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return errorsx.WithStack(x.ErrNotFound) - } - return sqlcon.HandleError(err) - } - lr = f.GetLoginRequest() + f, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), loginChallenge, flowctx.AsLoginChallenge) + if err != nil { + return nil, errorsx.WithStack(x.ErrNotFound.WithWrap(err)) + } + if f.NID != p.NetworkID(ctx) { + return nil, errorsx.WithStack(x.ErrNotFound) + } + if f.RequestedAt.Add(p.config.ConsentRequestMaxAge(ctx)).Before(time.Now()) { + return nil, errorsx.WithStack(fosite.ErrRequestUnauthorized.WithHint("The login request has expired, please try again.")) + } + lr := f.GetLoginRequest() + // Restore the short challenge ID, which was previously sent to the encoded flow, + // to make sure that the challenge ID in the returned flow matches the param. + lr.ID = loginChallenge - return nil - }) + return lr, nil } -func (p *Persister) HandleConsentRequest(ctx context.Context, r *consent.AcceptOAuth2ConsentRequest) (*consent.OAuth2ConsentRequest, error) { +func (p *Persister) HandleConsentRequest(ctx context.Context, f *flow.Flow, r *flow.AcceptOAuth2ConsentRequest) (*flow.OAuth2ConsentRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.HandleConsentRequest") defer span.End() - f := &flow.Flow{} - - if err := sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("consent_challenge_id = ?", r.ID).First(f)); errors.Is(err, sqlcon.ErrNoRows) { - return nil, err + if f == nil { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Flow was nil")) } - + if f.NID != p.NetworkID(ctx) { + return nil, errorsx.WithStack(x.ErrNotFound) + } + // Restore the short challenge ID, which was previously sent to the encoded flow, + // to make sure that the challenge ID in the returned flow matches the param. + r.ID = f.ConsentChallengeID.String() if err := f.HandleConsentRequest(r); err != nil { return nil, errorsx.WithStack(err) } - _, err := p.UpdateWithNetwork(ctx, f) - if err != nil { - return nil, sqlcon.HandleError(err) - } - - return p.GetConsentRequest(ctx, r.ID) + return f.GetConsentRequest(), nil } -func (p *Persister) VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*consent.AcceptOAuth2ConsentRequest, error) { +func (p *Persister) VerifyAndInvalidateConsentRequest(ctx context.Context, f *flow.Flow, verifier string) (*flow.AcceptOAuth2ConsentRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateConsentRequest") defer span.End() - var r consent.AcceptOAuth2ConsentRequest - return &r, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - var f flow.Flow - if err := p.QueryWithNetwork(ctx).Where("consent_verifier = ?", verifier).First(&f); err != nil { - return sqlcon.HandleError(err) - } + if f == nil { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Flow was nil")) + } + if f.NID != p.NetworkID(ctx) { + return nil, errorsx.WithStack(sqlcon.ErrNoRows) + } - if err := f.InvalidateConsentRequest(); err != nil { - return errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error())) - } + updatedFlow, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), verifier, flowctx.AsConsentVerifier) + if err != nil { + return nil, errorsx.WithStack(fosite.ErrAccessDenied.WithHint("The consent verifier has already been used, has not been granted, or is invalid.")) + } + if updatedFlow.ID != f.ID { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Consent verifier does not match login request.")) + } + if updatedFlow.NID != p.NetworkID(ctx) { + return nil, errorsx.WithStack(sqlcon.ErrNoRows) + } - r = *f.GetHandledConsentRequest() - _, err := p.UpdateWithNetwork(ctx, &f) - return err - }) + // Update flow from login request, but keep requested at. + updatedFlow.NID = f.NID + updatedFlow.ConsentCSRF = f.ConsentCSRF + updatedFlow.ConsentVerifier = f.ConsentVerifier + *f = *updatedFlow + + if err = f.InvalidateConsentRequest(); err != nil { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error())) + } + + // We set the consent challenge ID to a new UUID that we can use as a foreign key in the database + // without encoding the whole flow. + f.ConsentChallengeID = sqlxx.NullString(uuid.Must(uuid.NewV4()).String()) + + if err = p.Connection(ctx).Create(f); err != nil { + return nil, sqlcon.HandleError(err) + } + + return f.GetHandledConsentRequest(), nil } -func (p *Persister) HandleLoginRequest(ctx context.Context, challenge string, r *consent.HandledLoginRequest) (lr *consent.LoginRequest, err error) { +func (p *Persister) HandleLoginRequest(ctx context.Context, f *flow.Flow, challenge string, r *flow.HandledLoginRequest) (lr *flow.LoginRequest, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.HandleLoginRequest") defer span.End() - return lr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - f, err := p.GetFlow(ctx, challenge) - if err != nil { - return sqlcon.HandleError(err) - } - err = f.HandleLoginRequest(r) - if err != nil { - return err - } - - _, err = p.UpdateWithNetwork(ctx, f) - if err != nil { - return sqlcon.HandleError(err) - } + if f == nil { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Flow was nil")) + } + if f.NID != p.NetworkID(ctx) { + return nil, errorsx.WithStack(x.ErrNotFound) + } + r.ID = f.ID + err = f.HandleLoginRequest(r) + if err != nil { + return nil, err + } - lr, err = p.GetLoginRequest(ctx, challenge) - return sqlcon.HandleError(err) - }) + return p.GetLoginRequest(ctx, challenge) } -func (p *Persister) VerifyAndInvalidateLoginRequest(ctx context.Context, verifier string) (*consent.HandledLoginRequest, error) { +func (p *Persister) VerifyAndInvalidateLoginRequest(ctx context.Context, f *flow.Flow, verifier string) (*flow.HandledLoginRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateLoginRequest") defer span.End() - var d consent.HandledLoginRequest - return &d, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - var f flow.Flow - if err := p.QueryWithNetwork(ctx).Where("login_verifier = ?", verifier).First(&f); err != nil { - return sqlcon.HandleError(err) - } + if f == nil { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Flow was nil")) + } + if f.NID != p.NetworkID(ctx) { + return nil, errorsx.WithStack(sqlcon.ErrNoRows) + } - if err := f.InvalidateLoginRequest(); err != nil { - return errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error())) - } + updatedFlow, err := flowctx.Decode[flow.Flow](ctx, p.r.FlowCipher(), verifier, flowctx.AsLoginVerifier) + if err != nil { + return nil, errorsx.WithStack(sqlcon.ErrNoRows) + } + if f.NID != updatedFlow.NID { + return nil, errorsx.WithStack(sqlcon.ErrNoRows) + } - d = f.GetHandledLoginRequest() - _, err := p.UpdateWithNetwork(ctx, &f) - return sqlcon.HandleError(err) - }) + if updatedFlow.ID != f.ID { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug("Login verifier does not match login request.")) + } + + // Update flow from login request, but keep requested at. + updatedFlow.NID = f.NID + updatedFlow.RequestedAt = f.RequestedAt + updatedFlow.LoginCSRF = f.LoginCSRF + updatedFlow.LoginVerifier = f.LoginVerifier + *f = *updatedFlow + + if err := f.InvalidateLoginRequest(); err != nil { + return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error())) + } + d := f.GetHandledLoginRequest() + + return &d, nil } -func (p *Persister) GetRememberedLoginSession(ctx context.Context, id string) (*consent.LoginSession, error) { +func (p *Persister) GetRememberedLoginSession(ctx context.Context, loginSessionFromCookie *flow.LoginSession, id string) (*flow.LoginSession, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRememberedLoginSession") defer span.End() - var s consent.LoginSession + if s := loginSessionFromCookie; s != nil && s.NID == p.NetworkID(ctx) && s.ID == id && s.Remember { + return s, nil + } + + var s flow.LoginSession if err := p.QueryWithNetwork(ctx).Where("remember = TRUE").Find(&s, id); errors.Is(err, sql.ErrNoRows) { return nil, errorsx.WithStack(x.ErrNotFound) @@ -364,30 +405,56 @@ func (p *Persister) GetRememberedLoginSession(ctx context.Context, id string) (* return &s, nil } -func (p *Persister) ConfirmLoginSession(ctx context.Context, id string, authenticatedAt time.Time, subject string, remember bool) error { +func (p *Persister) ConfirmLoginSession(ctx context.Context, session *flow.LoginSession, id string, authenticatedAt time.Time, subject string, remember bool) error { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ConfirmLoginSession") defer span.End() - _, err := p.Connection(ctx).Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).UpdateQuery(&consent.LoginSession{ + // Since we previously cached the login session, we now need to persist it to db. + if session != nil { + if session.NID != p.NetworkID(ctx) || session.ID != id { + return errorsx.WithStack(x.ErrNotFound) + } + session.AuthenticatedAt = sqlxx.NullTime(authenticatedAt.Truncate(time.Second)) + session.Subject = subject + session.Remember = remember + + return p.CreateWithNetwork(ctx, session) + } + + // In some unit tests, we still confirm the login session without data from the cookie. We can remove this case + // once all tests are fixed. + n, err := p.Connection(ctx).Where("id = ? AND nid = ?", id, p.NetworkID(ctx)).UpdateQuery(&flow.LoginSession{ AuthenticatedAt: sqlxx.NullTime(authenticatedAt), Subject: subject, Remember: remember, }, "authenticated_at", "subject", "remember") - return sqlcon.HandleError(err) + if err != nil { + return sqlcon.HandleError(err) + } + if n == 0 { + return errorsx.WithStack(x.ErrNotFound) + } + return nil } -func (p *Persister) CreateLoginSession(ctx context.Context, session *consent.LoginSession) error { +func (p *Persister) CreateLoginSession(ctx context.Context, session *flow.LoginSession) error { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLoginSession") defer span.End() - return sqlcon.HandleError(p.CreateWithNetwork(ctx, session)) + nid := p.NetworkID(ctx) + if nid == uuid.Nil { + return errorsx.WithStack(x.ErrNotFound) + } + session.NID = nid + + return nil } func (p *Persister) DeleteLoginSession(ctx context.Context, id string) error { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteLoginSession") defer span.End() - count, err := p.Connection(ctx).RawQuery("DELETE FROM hydra_oauth2_authentication_session WHERE id=? AND nid = ?", id, p.NetworkID(ctx)).ExecWithCount() + count, err := p.Connection(ctx).RawQuery("DELETE FROM hydra_oauth2_authentication_session WHERE id=? AND nid=?", id, p.NetworkID(ctx)).ExecWithCount() if count == 0 { return errorsx.WithStack(x.ErrNotFound) } else { @@ -395,18 +462,14 @@ func (p *Persister) DeleteLoginSession(ctx context.Context, id string) error { } } -func (p *Persister) FindGrantedAndRememberedConsentRequests(ctx context.Context, client, subject string) ([]consent.AcceptOAuth2ConsentRequest, error) { +func (p *Persister) FindGrantedAndRememberedConsentRequests(ctx context.Context, client, subject string) (rs []flow.AcceptOAuth2ConsentRequest, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindGrantedAndRememberedConsentRequests") defer span.End() - rs := make([]consent.AcceptOAuth2ConsentRequest, 0) - - return rs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - f := &flow.Flow{} - - if err := c. - Where( - strings.TrimSpace(fmt.Sprintf(` + var f flow.Flow + if err = p.Connection(ctx). + Where( + strings.TrimSpace(fmt.Sprintf(` (state = %d OR state = %d) AND subject = ? AND client_id = ? AND @@ -414,24 +477,21 @@ consent_skip=FALSE AND consent_error='{}' AND consent_remember=TRUE AND nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused, - )), - subject, client, p.NetworkID(ctx)). - Order("requested_at DESC"). - Limit(1). - First(f); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return errorsx.WithStack(consent.ErrNoPreviousConsentFound) - } - return sqlcon.HandleError(err) + )), + subject, client, p.NetworkID(ctx)). + Order("requested_at DESC"). + Limit(1). + First(&f); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(consent.ErrNoPreviousConsentFound) } + return nil, sqlcon.HandleError(err) + } - var err error - rs, err = p.filterExpiredConsentRequests(ctx, []consent.AcceptOAuth2ConsentRequest{*f.GetHandledConsentRequest()}) - return err - }) + return p.filterExpiredConsentRequests(ctx, []flow.AcceptOAuth2ConsentRequest{*f.GetHandledConsentRequest()}) } -func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, limit, offset int) ([]consent.AcceptOAuth2ConsentRequest, error) { +func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, limit, offset int) ([]flow.AcceptOAuth2ConsentRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsGrantedConsentRequests") defer span.End() @@ -457,7 +517,7 @@ nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused, return nil, sqlcon.HandleError(err) } - var rs []consent.AcceptOAuth2ConsentRequest + var rs []flow.AcceptOAuth2ConsentRequest for _, f := range fs { rs = append(rs, *f.GetHandledConsentRequest()) } @@ -465,7 +525,7 @@ nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused, return p.filterExpiredConsentRequests(ctx, rs) } -func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Context, subject, sid string, limit, offset int) ([]consent.AcceptOAuth2ConsentRequest, error) { +func (p *Persister) FindSubjectsSessionGrantedConsentRequests(ctx context.Context, subject, sid string, limit, offset int) ([]flow.AcceptOAuth2ConsentRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FindSubjectsSessionGrantedConsentRequests") defer span.End() @@ -492,7 +552,7 @@ nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused, return nil, sqlcon.HandleError(err) } - var rs []consent.AcceptOAuth2ConsentRequest + var rs []flow.AcceptOAuth2ConsentRequest for _, f := range fs { rs = append(rs, *f.GetHandledConsentRequest()) } @@ -518,11 +578,11 @@ nid = ?`, flow.FlowStateConsentUsed, flow.FlowStateConsentUnused, return n, sqlcon.HandleError(err) } -func (p *Persister) filterExpiredConsentRequests(ctx context.Context, requests []consent.AcceptOAuth2ConsentRequest) ([]consent.AcceptOAuth2ConsentRequest, error) { +func (p *Persister) filterExpiredConsentRequests(ctx context.Context, requests []flow.AcceptOAuth2ConsentRequest) ([]flow.AcceptOAuth2ConsentRequest, error) { _, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.filterExpiredConsentRequests") defer span.End() - var result []consent.AcceptOAuth2ConsentRequest + var result []flow.AcceptOAuth2ConsentRequest for _, v := range requests { if v.RememberFor > 0 && v.RequestedAt.Add(time.Duration(v.RememberFor)*time.Second).Before(time.Now().UTC()) { continue @@ -553,10 +613,9 @@ func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, s defer span.End() var cs []client.Client - return cs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - if err := c.RawQuery( - /* #nosec G201 - channel can either be "front" or "back" */ - fmt.Sprintf(` + if err := p.Connection(ctx).RawQuery( + /* #nosec G201 - channel can either be "front" or "back" */ + fmt.Sprintf(` SELECT DISTINCT c.* FROM hydra_client as c JOIN hydra_oauth2_flow as f ON (c.id = f.client_id) WHERE @@ -566,29 +625,28 @@ WHERE f.login_session_id = ? AND f.nid = ? AND c.nid = ?`, - channel, - channel, - ), - subject, - sid, - p.NetworkID(ctx), - p.NetworkID(ctx), - ).All(&cs); err != nil { - return sqlcon.HandleError(err) - } + channel, + channel, + ), + subject, + sid, + p.NetworkID(ctx), + p.NetworkID(ctx), + ).All(&cs); err != nil { + return nil, sqlcon.HandleError(err) + } - return nil - }) + return cs, nil } -func (p *Persister) CreateLogoutRequest(ctx context.Context, request *consent.LogoutRequest) error { +func (p *Persister) CreateLogoutRequest(ctx context.Context, request *flow.LogoutRequest) error { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateLogoutRequest") defer span.End() return errorsx.WithStack(p.CreateWithNetwork(ctx, request)) } -func (p *Persister) AcceptLogoutRequest(ctx context.Context, challenge string) (*consent.LogoutRequest, error) { +func (p *Persister) AcceptLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.AcceptLogoutRequest") defer span.End() @@ -613,37 +671,35 @@ func (p *Persister) RejectLogoutRequest(ctx context.Context, challenge string) e } } -func (p *Persister) GetLogoutRequest(ctx context.Context, challenge string) (*consent.LogoutRequest, error) { +func (p *Persister) GetLogoutRequest(ctx context.Context, challenge string) (*flow.LogoutRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetLogoutRequest") defer span.End() - var lr consent.LogoutRequest + var lr flow.LogoutRequest return &lr, sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("challenge = ? AND rejected = FALSE", challenge).First(&lr)) } -func (p *Persister) VerifyAndInvalidateLogoutRequest(ctx context.Context, verifier string) (*consent.LogoutRequest, error) { +func (p *Persister) VerifyAndInvalidateLogoutRequest(ctx context.Context, verifier string) (*flow.LogoutRequest, error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.VerifyAndInvalidateLogoutRequest") defer span.End() - var lr consent.LogoutRequest - return &lr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - if count, err := c.RawQuery( - "UPDATE hydra_oauth2_logout_request SET was_used=TRUE WHERE nid = ? AND verifier=? AND was_used=FALSE AND accepted=TRUE AND rejected=FALSE", - p.NetworkID(ctx), - verifier, - ).ExecWithCount(); count == 0 && err == nil { - return errorsx.WithStack(x.ErrNotFound) - } else if err != nil { - return sqlcon.HandleError(err) - } + var lr flow.LogoutRequest + if count, err := p.Connection(ctx).RawQuery( + "UPDATE hydra_oauth2_logout_request SET was_used=TRUE WHERE nid = ? AND verifier=? AND was_used=FALSE AND accepted=TRUE AND rejected=FALSE", + p.NetworkID(ctx), + verifier, + ).ExecWithCount(); count == 0 && err == nil { + return nil, errorsx.WithStack(x.ErrNotFound) + } else if err != nil { + return nil, sqlcon.HandleError(err) + } - err := sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("verifier=?", verifier).First(&lr)) - if err != nil { - return err - } + err := sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("verifier=?", verifier).First(&lr)) + if err != nil { + return nil, err + } - return nil - }) + return &lr, nil } func (p *Persister) FlushInactiveLoginConsentRequests(ctx context.Context, notAfter time.Time, limit int, batchSize int) error { diff --git a/persistence/sql/persister_jwk.go b/persistence/sql/persister_jwk.go index fe8041326b8..92eb3cf9cea 100644 --- a/persistence/sql/persister_jwk.go +++ b/persistence/sql/persister_jwk.go @@ -47,7 +47,7 @@ func (p *Persister) AddKey(ctx context.Context, set string, key *jose.JSONWebKey return errorsx.WithStack(err) } - encrypted, err := p.r.KeyCipher().Encrypt(ctx, out) + encrypted, err := p.r.KeyCipher().Encrypt(ctx, out, nil) if err != nil { return errorsx.WithStack(err) } @@ -71,7 +71,7 @@ func (p *Persister) AddKeySet(ctx context.Context, set string, keys *jose.JSONWe return errorsx.WithStack(err) } - encrypted, err := p.r.KeyCipher().Encrypt(ctx, out) + encrypted, err := p.r.KeyCipher().Encrypt(ctx, out, nil) if err != nil { return err } @@ -133,7 +133,7 @@ func (p *Persister) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebK return nil, sqlcon.HandleError(err) } - key, err := p.r.KeyCipher().Decrypt(ctx, j.Key) + key, err := p.r.KeyCipher().Decrypt(ctx, j.Key, nil) if err != nil { return nil, errorsx.WithStack(err) } @@ -148,7 +148,7 @@ func (p *Persister) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebK }, nil } -func (p *Persister) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKeySet, error) { +func (p *Persister) GetKeySet(ctx context.Context, set string) (keys *jose.JSONWebKeySet, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetKeySet") defer span.End() @@ -164,9 +164,9 @@ func (p *Persister) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKey return nil, errors.Wrap(x.ErrNotFound, "") } - keys := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{}} + keys = &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{}} for _, d := range js { - key, err := p.r.KeyCipher().Decrypt(ctx, d.Key) + key, err := p.r.KeyCipher().Decrypt(ctx, d.Key, nil) if err != nil { return nil, errorsx.WithStack(err) } diff --git a/persistence/sql/persister_nid_test.go b/persistence/sql/persister_nid_test.go index 29f4e613f72..cf36ba61c45 100644 --- a/persistence/sql/persister_nid_test.go +++ b/persistence/sql/persister_nid_test.go @@ -6,15 +6,17 @@ package sql_test import ( "context" "database/sql" + "encoding/json" "testing" "time" + "github.com/ory/hydra/v2/persistence" "github.com/ory/x/uuidx" "github.com/ory/x/assertx" "github.com/gofrs/uuid" - "github.com/instana/testify/require" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "gopkg.in/square/go-jose.v2" @@ -88,7 +90,7 @@ func (s *PersisterTestSuite) TestAcceptLogoutRequest() { lrAccepted, err := r.ConsentManager().AcceptLogoutRequest(s.t2, lr.ID) require.Error(t, err) - require.Equal(t, &consent.LogoutRequest{}, lrAccepted) + require.Equal(t, &flow.LogoutRequest{}, lrAccepted) actual, err := r.ConsentManager().GetLogoutRequest(s.t1, lr.ID) require.NoError(t, err) @@ -182,17 +184,20 @@ func (s *PersisterTestSuite) TestConfirmLoginSession() { for k, r := range s.registries { t.Run(k, func(t *testing.T) { require.NoError(t, r.Persister().CreateLoginSession(s.t1, ls)) - expected := &consent.LoginSession{} - require.NoError(t, r.Persister().Connection(context.Background()).Find(expected, ls.ID)) - require.NoError(t, r.Persister().ConfirmLoginSession(s.t2, expected.ID, time.Now(), expected.Subject, !expected.Remember)) - actual := &consent.LoginSession{} + // Expects the login session to be confirmed in the correct context. + require.NoError(t, r.Persister().ConfirmLoginSession(s.t1, ls, ls.ID, time.Now(), ls.Subject, !ls.Remember)) + actual := &flow.LoginSession{} require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, ls.ID)) - require.Equal(t, expected, actual) + exp, _ := json.Marshal(ls) + act, _ := json.Marshal(actual) + require.JSONEq(t, string(exp), string(act)) - require.NoError(t, r.Persister().ConfirmLoginSession(s.t1, expected.ID, time.Now(), expected.Subject, !expected.Remember)) - require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, ls.ID)) - require.NotEqual(t, expected, actual) + // Can't find the login session in the wrong context. + require.ErrorIs(t, + r.Persister().ConfirmLoginSession(s.t2, ls, ls.ID, time.Now(), ls.Subject, !ls.Remember), + x.ErrNotFound, + ) }) } } @@ -202,8 +207,8 @@ func (s *PersisterTestSuite) TestCreateSession() { ls := newLoginSession() for k, r := range s.registries { t.Run(k, func(t *testing.T) { - require.NoError(t, r.Persister().CreateLoginSession(s.t1, ls)) - actual := &consent.LoginSession{} + persistLoginSession(s.t1, t, r.Persister(), ls) + actual := &flow.LoginSession{} require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, ls.ID)) require.Equal(t, s.t1NID, actual.NID) ls.NID = actual.NID @@ -280,12 +285,12 @@ func (s *PersisterTestSuite) TestCountSubjectsGrantedConsentRequests() { require.Equal(t, 0, count) sessionID := uuid.Must(uuid.NewV4()).String() - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) client := &client.Client{LegacyClientID: "client-id"} require.NoError(t, r.Persister().CreateClient(s.t1, client)) f := newFlow(s.t1NID, client.LegacyClientID, sub, sqlxx.NullString(sessionID)) f.ConsentSkip = false - f.ConsentError = &consent.RequestDeniedError{} + f.ConsentError = &flow.RequestDeniedError{} f.State = flow.FlowStateConsentUnused require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) @@ -359,18 +364,18 @@ func (s *PersisterTestSuite) TestCreateConsentRequest() { sessionID := uuid.Must(uuid.NewV4()).String() client := &client.Client{LegacyClientID: "client-id"} f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID)) - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) require.NoError(t, r.Persister().CreateClient(s.t1, client)) require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) - req := &consent.OAuth2ConsentRequest{ + req := &flow.OAuth2ConsentRequest{ ID: "consent-request-id", LoginChallenge: sqlxx.NullString(f.ID), Skip: false, Verifier: "verifier", CSRF: "csrf", } - require.NoError(t, r.Persister().CreateConsentRequest(s.t1, req)) + require.NoError(t, r.Persister().CreateConsentRequest(s.t1, f, req)) actual := flow.Flow{} require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, f.ID)) @@ -418,12 +423,11 @@ func (s *PersisterTestSuite) TestCreateLoginRequest() { for k, r := range s.registries { t.Run(k, func(t *testing.T) { client := &client.Client{LegacyClientID: "client-id"} - lr := consent.LoginRequest{ID: "lr-id", ClientID: client.LegacyClientID, RequestedAt: time.Now()} + lr := flow.LoginRequest{ID: "lr-id", ClientID: client.LegacyClientID, RequestedAt: time.Now()} require.NoError(t, r.Persister().CreateClient(s.t1, client)) - require.NoError(t, r.ConsentManager().CreateLoginRequest(s.t1, &lr)) - f := flow.Flow{} - require.NoError(t, r.Persister().Connection(context.Background()).Find(&f, lr.ID)) + f, err := r.ConsentManager().CreateLoginRequest(s.t1, &lr) + require.NoError(t, err) require.Equal(t, s.t1NID, f.NID) }) } @@ -433,9 +437,9 @@ func (s *PersisterTestSuite) TestCreateLoginSession() { t := s.T() for k, r := range s.registries { t.Run(k, func(t *testing.T) { - ls := consent.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true} + ls := flow.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true} require.NoError(t, r.Persister().CreateLoginSession(s.t1, &ls)) - actual, err := r.Persister().GetRememberedLoginSession(s.t1, ls.ID) + actual, err := r.Persister().GetRememberedLoginSession(s.t1, &ls, ls.ID) require.NoError(t, err) require.Equal(t, s.t1NID, actual.NID) }) @@ -447,7 +451,7 @@ func (s *PersisterTestSuite) TestCreateLogoutRequest() { for k, r := range s.registries { t.Run(k, func(t *testing.T) { client := &client.Client{LegacyClientID: "client-id"} - lr := consent.LogoutRequest{ + lr := flow.LogoutRequest{ // TODO there is not FK for SessionID so we don't need it here; TODO make sure the missing FK is intentional ID: uuid.Must(uuid.NewV4()).String(), ClientID: sql.NullString{Valid: true, String: client.LegacyClientID}, @@ -626,15 +630,15 @@ func (s *PersisterTestSuite) TestDeleteLoginSession() { t := s.T() for k, r := range s.registries { t.Run(k, func(t *testing.T) { - ls := consent.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true} - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &ls)) + ls := flow.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true} + persistLoginSession(s.t1, t, r.Persister(), &ls) require.Error(t, r.Persister().DeleteLoginSession(s.t2, ls.ID)) - _, err := r.Persister().GetRememberedLoginSession(s.t1, ls.ID) + _, err := r.Persister().GetRememberedLoginSession(s.t1, nil, ls.ID) require.NoError(t, err) require.NoError(t, r.Persister().DeleteLoginSession(s.t1, ls.ID)) - _, err = r.Persister().GetRememberedLoginSession(s.t1, ls.ID) + _, err = r.Persister().GetRememberedLoginSession(s.t1, nil, ls.ID) require.Error(t, err) }) } @@ -734,11 +738,10 @@ func (s *PersisterTestSuite) TestFindGrantedAndRememberedConsentRequests() { sessionID := uuid.Must(uuid.NewV4()).String() client := &client.Client{LegacyClientID: "client-id"} f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID)) - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) require.NoError(t, r.Persister().CreateClient(s.t1, client)) - require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) - req := &consent.OAuth2ConsentRequest{ + req := &flow.OAuth2ConsentRequest{ ID: "consent-request-id", LoginChallenge: sqlxx.NullString(f.ID), Skip: false, @@ -746,14 +749,15 @@ func (s *PersisterTestSuite) TestFindGrantedAndRememberedConsentRequests() { CSRF: "csrf", } - hcr := &consent.AcceptOAuth2ConsentRequest{ + hcr := &flow.AcceptOAuth2ConsentRequest{ ID: req.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true, } - require.NoError(t, r.Persister().CreateConsentRequest(s.t1, req)) - _, err := r.Persister().HandleConsentRequest(s.t1, hcr) + require.NoError(t, r.Persister().CreateConsentRequest(s.t1, f, req)) + _, err := r.Persister().HandleConsentRequest(s.t1, f, hcr) require.NoError(t, err) + require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) actual, err := r.Persister().FindGrantedAndRememberedConsentRequests(s.t2, client.LegacyClientID, f.Subject) require.Error(t, err) @@ -773,11 +777,11 @@ func (s *PersisterTestSuite) TestFindSubjectsGrantedConsentRequests() { sessionID := uuid.Must(uuid.NewV4()).String() client := &client.Client{LegacyClientID: "client-id"} f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID)) - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) require.NoError(t, r.Persister().CreateClient(s.t1, client)) require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) - req := &consent.OAuth2ConsentRequest{ + req := &flow.OAuth2ConsentRequest{ ID: "consent-request-id", LoginChallenge: sqlxx.NullString(f.ID), Skip: false, @@ -785,13 +789,13 @@ func (s *PersisterTestSuite) TestFindSubjectsGrantedConsentRequests() { CSRF: "csrf", } - hcr := &consent.AcceptOAuth2ConsentRequest{ + hcr := &flow.AcceptOAuth2ConsentRequest{ ID: req.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true, } - require.NoError(t, r.Persister().CreateConsentRequest(s.t1, req)) - _, err := r.Persister().HandleConsentRequest(s.t1, hcr) + require.NoError(t, r.Persister().CreateConsentRequest(s.t1, f, req)) + _, err := r.Persister().HandleConsentRequest(s.t1, f, hcr) require.NoError(t, err) actual, err := r.Persister().FindSubjectsGrantedConsentRequests(s.t2, f.Subject, 100, 0) @@ -876,7 +880,7 @@ func (s *PersisterTestSuite) TestFlushInactiveLoginConsentRequests() { client := &client.Client{LegacyClientID: "client-id"} f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID)) f.RequestedAt = time.Now().Add(-24 * time.Hour) - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) require.NoError(t, r.Persister().CreateClient(s.t1, client)) require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) @@ -1056,18 +1060,18 @@ func (s *PersisterTestSuite) TestGetConsentRequest() { sessionID := uuid.Must(uuid.NewV4()).String() client := &client.Client{LegacyClientID: "client-id"} f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID)) - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) require.NoError(t, r.Persister().CreateClient(s.t1, client)) require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) - req := &consent.OAuth2ConsentRequest{ - ID: "consent-request-id", + req := &flow.OAuth2ConsentRequest{ + ID: x.Must(f.ToConsentChallenge(s.t1, r)), LoginChallenge: sqlxx.NullString(f.ID), Skip: false, Verifier: "verifier", CSRF: "csrf", } - require.NoError(t, r.Persister().CreateConsentRequest(s.t1, req)) + require.NoError(t, r.Persister().CreateConsentRequest(s.t1, f, req)) actual, err := r.Persister().GetConsentRequest(s.t2, req.ID) require.Error(t, err) @@ -1087,7 +1091,7 @@ func (s *PersisterTestSuite) TestGetFlow() { sessionID := uuid.Must(uuid.NewV4()).String() client := &client.Client{LegacyClientID: "client-id"} f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID)) - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) require.NoError(t, r.Persister().CreateClient(s.t1, client)) require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) @@ -1112,19 +1116,20 @@ func (s *PersisterTestSuite) TestGetFlowByConsentChallenge() { sessionID := uuid.Must(uuid.NewV4()).String() client := &client.Client{LegacyClientID: "client-id"} f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID)) - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + require.NoError(t, r.Persister().CreateLoginSession(s.t1, &flow.LoginSession{ID: sessionID})) require.NoError(t, r.Persister().CreateClient(s.t1, client)) - require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) store, ok := r.Persister().(*persistencesql.Persister) if !ok { t.Fatal("type assertion failed") } - _, err := store.GetFlowByConsentChallenge(s.t2, f.ConsentChallengeID.String()) + challenge := x.Must(f.ToConsentChallenge(s.t1, r)) + + _, err := store.GetFlowByConsentChallenge(s.t2, challenge) require.Error(t, err) - _, err = store.GetFlowByConsentChallenge(s.t1, f.ConsentChallengeID.String()) + _, err = store.GetFlowByConsentChallenge(s.t1, challenge) require.NoError(t, err) }) } @@ -1179,19 +1184,20 @@ func (s *PersisterTestSuite) TestGetLoginRequest() { for k, r := range s.registries { t.Run(k, func(t *testing.T) { client := &client.Client{LegacyClientID: "client-id"} - lr := consent.LoginRequest{ID: "lr-id", ClientID: client.LegacyClientID, RequestedAt: time.Now()} + lr := flow.LoginRequest{ID: "lr-id", ClientID: client.LegacyClientID, RequestedAt: time.Now()} require.NoError(t, r.Persister().CreateClient(s.t1, client)) - require.NoError(t, r.ConsentManager().CreateLoginRequest(s.t1, &lr)) - f := flow.Flow{} - require.NoError(t, r.Persister().Connection(context.Background()).Find(&f, lr.ID)) + f, err := r.ConsentManager().CreateLoginRequest(s.t1, &lr) + require.NoError(t, err) require.Equal(t, s.t1NID, f.NID) - actual, err := r.Persister().GetLoginRequest(s.t2, lr.ID) + challenge := x.Must(f.ToLoginChallenge(s.t1, r)) + + actual, err := r.Persister().GetLoginRequest(s.t2, challenge) require.Error(t, err) require.Nil(t, actual) - actual, err = r.Persister().GetLoginRequest(s.t1, lr.ID) + actual, err = r.Persister().GetLoginRequest(s.t1, challenge) require.NoError(t, err) require.NotNil(t, actual) }) @@ -1203,7 +1209,7 @@ func (s *PersisterTestSuite) TestGetLogoutRequest() { for k, r := range s.registries { t.Run(k, func(t *testing.T) { client := &client.Client{LegacyClientID: "client-id"} - lr := consent.LogoutRequest{ + lr := flow.LogoutRequest{ ID: uuid.Must(uuid.NewV4()).String(), ClientID: sql.NullString{Valid: true, String: client.LegacyClientID}, } @@ -1213,11 +1219,11 @@ func (s *PersisterTestSuite) TestGetLogoutRequest() { actual, err := r.Persister().GetLogoutRequest(s.t2, lr.ID) require.Error(t, err) - require.Equal(t, &consent.LogoutRequest{}, actual) + require.Equal(t, &flow.LogoutRequest{}, actual) actual, err = r.Persister().GetLogoutRequest(s.t1, lr.ID) require.NoError(t, err) - require.NotEqual(t, &consent.LogoutRequest{}, actual) + require.NotEqual(t, &flow.LogoutRequest{}, actual) }) } } @@ -1368,14 +1374,14 @@ func (s *PersisterTestSuite) TestGetRememberedLoginSession() { t := s.T() for k, r := range s.registries { t.Run(k, func(t *testing.T) { - ls := consent.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true} + ls := flow.LoginSession{ID: uuid.Must(uuid.NewV4()).String(), Remember: true} require.NoError(t, r.Persister().CreateLoginSession(s.t1, &ls)) - actual, err := r.Persister().GetRememberedLoginSession(s.t2, ls.ID) + actual, err := r.Persister().GetRememberedLoginSession(s.t2, &ls, ls.ID) require.Error(t, err) require.Nil(t, actual) - actual, err = r.Persister().GetRememberedLoginSession(s.t1, ls.ID) + actual, err = r.Persister().GetRememberedLoginSession(s.t1, &ls, ls.ID) require.NoError(t, err) require.NotNil(t, actual) }) @@ -1389,13 +1395,12 @@ func (s *PersisterTestSuite) TestHandleConsentRequest() { sessionID := uuid.Must(uuid.NewV4()).String() c1 := &client.Client{LegacyClientID: uuidx.NewV4().String()} f := newFlow(s.t1NID, c1.LegacyClientID, "sub", sqlxx.NullString(sessionID)) - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) require.NoError(t, r.Persister().CreateClient(s.t1, c1)) c1.ID = uuid.Nil require.NoError(t, r.Persister().CreateClient(s.t2, c1)) - require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) - req := &consent.OAuth2ConsentRequest{ + req := &flow.OAuth2ConsentRequest{ ID: "consent-request-id", LoginChallenge: sqlxx.NullString(f.ID), Skip: false, @@ -1403,23 +1408,24 @@ func (s *PersisterTestSuite) TestHandleConsentRequest() { CSRF: "csrf", } - hcr := &consent.AcceptOAuth2ConsentRequest{ + hcr := &flow.AcceptOAuth2ConsentRequest{ ID: req.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true, } - require.NoError(t, r.Persister().CreateConsentRequest(s.t1, req)) + require.NoError(t, r.Persister().CreateConsentRequest(s.t1, f, req)) - actualCR, err := r.Persister().HandleConsentRequest(s.t2, hcr) + actualCR, err := r.Persister().HandleConsentRequest(s.t2, f, hcr) require.Error(t, err) require.Nil(t, actualCR) actual, err := r.Persister().FindGrantedAndRememberedConsentRequests(s.t1, c1.LegacyClientID, f.Subject) require.Error(t, err) require.Equal(t, 0, len(actual)) - actualCR, err = r.Persister().HandleConsentRequest(s.t1, hcr) + actualCR, err = r.Persister().HandleConsentRequest(s.t1, f, hcr) require.NoError(t, err) require.NotNil(t, actualCR) + require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) actual, err = r.Persister().FindGrantedAndRememberedConsentRequests(s.t1, c1.LegacyClientID, f.Subject) require.NoError(t, err) require.Equal(t, 1, len(actual)) @@ -1497,21 +1503,51 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithBackChannelLogo t2f2.LoginVerifier = "t2f2-login-verifier" t2f2.ConsentVerifier = "t2f2-consent-verifier" - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: t1f1.SessionID.String()})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: t1f1.SessionID.String()}) require.NoError(t, r.Persister().Connection(context.Background()).Create(t1f1)) require.NoError(t, r.Persister().Connection(context.Background()).Create(t2f1)) require.NoError(t, r.Persister().Connection(context.Background()).Create(t2f2)) - require.NoError(t, r.Persister().CreateConsentRequest(s.t1, &consent.OAuth2ConsentRequest{ID: t1f1.ID, LoginChallenge: sqlxx.NullString(t1f1.ID), Skip: false, Verifier: t1f1.ConsentVerifier.String(), CSRF: "csrf"})) - require.NoError(t, r.Persister().CreateConsentRequest(s.t2, &consent.OAuth2ConsentRequest{ID: t2f1.ID, LoginChallenge: sqlxx.NullString(t2f1.ID), Skip: false, Verifier: t2f1.ConsentVerifier.String(), CSRF: "csrf"})) - require.NoError(t, r.Persister().CreateConsentRequest(s.t2, &consent.OAuth2ConsentRequest{ID: t2f2.ID, LoginChallenge: sqlxx.NullString(t2f2.ID), Skip: false, Verifier: t2f2.ConsentVerifier.String(), CSRF: "csrf"})) + require.NoError(t, r.Persister().CreateConsentRequest(s.t1, t1f1, &flow.OAuth2ConsentRequest{ + ID: t1f1.ID, + LoginChallenge: sqlxx.NullString(t1f1.ID), + Skip: false, + Verifier: t1f1.ConsentVerifier.String(), + CSRF: "csrf", + })) + require.NoError(t, r.Persister().CreateConsentRequest(s.t2, t2f1, &flow.OAuth2ConsentRequest{ + ID: t2f1.ID, + LoginChallenge: sqlxx.NullString(t2f1.ID), + Skip: false, + Verifier: t2f1.ConsentVerifier.String(), + CSRF: "csrf", + })) + require.NoError(t, r.Persister().CreateConsentRequest(s.t2, t2f2, &flow.OAuth2ConsentRequest{ + ID: t2f2.ID, + LoginChallenge: sqlxx.NullString(t2f2.ID), + Skip: false, + Verifier: t2f2.ConsentVerifier.String(), + CSRF: "csrf", + })) - _, err := r.Persister().HandleConsentRequest(s.t1, &consent.AcceptOAuth2ConsentRequest{ID: t1f1.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true}) + _, err := r.Persister().HandleConsentRequest(s.t1, t1f1, &flow.AcceptOAuth2ConsentRequest{ + ID: t1f1.ID, + HandledAt: sqlxx.NullTime(time.Now()), + Remember: true, + }) require.NoError(t, err) - _, err = r.Persister().HandleConsentRequest(s.t2, &consent.AcceptOAuth2ConsentRequest{ID: t2f1.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true}) + _, err = r.Persister().HandleConsentRequest(s.t2, t2f1, &flow.AcceptOAuth2ConsentRequest{ + ID: t2f1.ID, + HandledAt: sqlxx.NullTime(time.Now()), + Remember: true, + }) require.NoError(t, err) - _, err = r.Persister().HandleConsentRequest(s.t2, &consent.AcceptOAuth2ConsentRequest{ID: t2f2.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true}) + _, err = r.Persister().HandleConsentRequest(s.t2, t2f2, &flow.AcceptOAuth2ConsentRequest{ + ID: t2f2.ID, + HandledAt: sqlxx.NullTime(time.Now()), + Remember: true, + }) require.NoError(t, err) cs, err := r.Persister().ListUserAuthenticatedClientsWithBackChannelLogout(s.t1, "sub", t1f1.SessionID.String()) @@ -1551,21 +1587,51 @@ func (s *PersisterTestSuite) TestListUserAuthenticatedClientsWithFrontChannelLog t2f2.LoginVerifier = "t2f2-login-verifier" t2f2.ConsentVerifier = "t2f2-consent-verifier" - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: t1f1.SessionID.String()})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: t1f1.SessionID.String()}) require.NoError(t, r.Persister().Connection(context.Background()).Create(t1f1)) require.NoError(t, r.Persister().Connection(context.Background()).Create(t2f1)) require.NoError(t, r.Persister().Connection(context.Background()).Create(t2f2)) - require.NoError(t, r.Persister().CreateConsentRequest(s.t1, &consent.OAuth2ConsentRequest{ID: t1f1.ID, LoginChallenge: sqlxx.NullString(t1f1.ID), Skip: false, Verifier: t1f1.ConsentVerifier.String(), CSRF: "csrf"})) - require.NoError(t, r.Persister().CreateConsentRequest(s.t2, &consent.OAuth2ConsentRequest{ID: t2f1.ID, LoginChallenge: sqlxx.NullString(t2f1.ID), Skip: false, Verifier: t2f1.ConsentVerifier.String(), CSRF: "csrf"})) - require.NoError(t, r.Persister().CreateConsentRequest(s.t2, &consent.OAuth2ConsentRequest{ID: t2f2.ID, LoginChallenge: sqlxx.NullString(t2f2.ID), Skip: false, Verifier: t2f2.ConsentVerifier.String(), CSRF: "csrf"})) + require.NoError(t, r.Persister().CreateConsentRequest(s.t1, t1f1, &flow.OAuth2ConsentRequest{ + ID: t1f1.ID, + LoginChallenge: sqlxx.NullString(t1f1.ID), + Skip: false, + Verifier: t1f1.ConsentVerifier.String(), + CSRF: "csrf", + })) + require.NoError(t, r.Persister().CreateConsentRequest(s.t2, t2f1, &flow.OAuth2ConsentRequest{ + ID: t2f1.ID, + LoginChallenge: sqlxx.NullString(t2f1.ID), + Skip: false, + Verifier: t2f1.ConsentVerifier.String(), + CSRF: "csrf", + })) + require.NoError(t, r.Persister().CreateConsentRequest(s.t2, t2f2, &flow.OAuth2ConsentRequest{ + ID: t2f2.ID, + LoginChallenge: sqlxx.NullString(t2f2.ID), + Skip: false, + Verifier: t2f2.ConsentVerifier.String(), + CSRF: "csrf", + })) - _, err := r.Persister().HandleConsentRequest(s.t1, &consent.AcceptOAuth2ConsentRequest{ID: t1f1.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true}) + _, err := r.Persister().HandleConsentRequest(s.t1, t1f1, &flow.AcceptOAuth2ConsentRequest{ + ID: t1f1.ID, + HandledAt: sqlxx.NullTime(time.Now()), + Remember: true, + }) require.NoError(t, err) - _, err = r.Persister().HandleConsentRequest(s.t2, &consent.AcceptOAuth2ConsentRequest{ID: t2f1.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true}) + _, err = r.Persister().HandleConsentRequest(s.t2, t2f1, &flow.AcceptOAuth2ConsentRequest{ + ID: t2f1.ID, + HandledAt: sqlxx.NullTime(time.Now()), + Remember: true, + }) require.NoError(t, err) - _, err = r.Persister().HandleConsentRequest(s.t2, &consent.AcceptOAuth2ConsentRequest{ID: t2f2.ID, HandledAt: sqlxx.NullTime(time.Now()), Remember: true}) + _, err = r.Persister().HandleConsentRequest(s.t2, t2f2, &flow.AcceptOAuth2ConsentRequest{ + ID: t2f2.ID, + HandledAt: sqlxx.NullTime(time.Now()), + Remember: true, + }) require.NoError(t, err) cs, err := r.Persister().ListUserAuthenticatedClientsWithFrontChannelLogout(s.t1, "sub", t1f1.SessionID.String()) @@ -1639,7 +1705,7 @@ func (s *PersisterTestSuite) TestRejectLogoutRequest() { require.NoError(t, r.ConsentManager().RejectLogoutRequest(s.t1, lr.ID)) actual, err = r.ConsentManager().GetLogoutRequest(s.t1, lr.ID) require.Error(t, err) - require.Equal(t, &consent.LogoutRequest{}, actual) + require.Equal(t, &flow.LogoutRequest{}, actual) }) } } @@ -1729,7 +1795,7 @@ func (s *PersisterTestSuite) TestRevokeSubjectClientConsentSession() { client := &client.Client{LegacyClientID: "client-id"} f := newFlow(s.t1NID, client.LegacyClientID, "sub", sqlxx.NullString(sessionID)) f.RequestedAt = time.Now().Add(-24 * time.Hour) - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) require.NoError(t, r.Persister().CreateClient(s.t1, client)) require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) @@ -1900,7 +1966,7 @@ func (s *PersisterTestSuite) TestVerifyAndInvalidateConsentRequest() { t.Run(k, func(t *testing.T) { sub := uuid.Must(uuid.NewV4()).String() sessionID := uuid.Must(uuid.NewV4()).String() - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) client := &client.Client{LegacyClientID: "client-id"} require.NoError(t, r.Persister().CreateClient(s.t1, client)) f := newFlow(s.t1NID, client.LegacyClientID, sub, sqlxx.NullString(sessionID)) @@ -1909,24 +1975,22 @@ func (s *PersisterTestSuite) TestVerifyAndInvalidateConsentRequest() { f.ConsentRemember = false crf := 86400 f.ConsentRememberFor = &crf - f.ConsentError = &consent.RequestDeniedError{} + f.ConsentError = &flow.RequestDeniedError{} f.SessionAccessToken = map[string]interface{}{} f.SessionIDToken = map[string]interface{}{} f.ConsentWasHandled = false f.State = flow.FlowStateConsentUnused - require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) - actual := &flow.Flow{} - _, err := r.ConsentManager().VerifyAndInvalidateConsentRequest(s.t2, f.ConsentVerifier.String()) + consentVerifier := x.Must(f.ToConsentVerifier(s.t1, r)) + + _, err := r.ConsentManager().VerifyAndInvalidateConsentRequest(s.t2, f, consentVerifier) require.Error(t, err) - require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, f.ID)) - require.Equal(t, flow.FlowStateConsentUnused, actual.State) - require.Equal(t, false, actual.ConsentWasHandled) - _, err = r.ConsentManager().VerifyAndInvalidateConsentRequest(s.t1, f.ConsentVerifier.String()) + require.Equal(t, flow.FlowStateConsentUnused, f.State) + require.Equal(t, false, f.ConsentWasHandled) + _, err = r.ConsentManager().VerifyAndInvalidateConsentRequest(s.t1, f, consentVerifier) require.NoError(t, err) - require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, f.ID)) - require.Equal(t, flow.FlowStateConsentUsed, actual.State) - require.Equal(t, true, actual.ConsentWasHandled) + require.Equal(t, flow.FlowStateConsentUsed, f.State) + require.Equal(t, true, f.ConsentWasHandled) }) } } @@ -1937,24 +2001,21 @@ func (s *PersisterTestSuite) TestVerifyAndInvalidateLoginRequest() { t.Run(k, func(t *testing.T) { sub := uuid.Must(uuid.NewV4()).String() sessionID := uuid.Must(uuid.NewV4()).String() - require.NoError(t, r.Persister().CreateLoginSession(s.t1, &consent.LoginSession{ID: sessionID})) + persistLoginSession(s.t1, t, r.Persister(), &flow.LoginSession{ID: sessionID}) client := &client.Client{LegacyClientID: "client-id"} require.NoError(t, r.Persister().CreateClient(s.t1, client)) f := newFlow(s.t1NID, client.LegacyClientID, sub, sqlxx.NullString(sessionID)) f.State = flow.FlowStateLoginUnused - require.NoError(t, r.Persister().Connection(context.Background()).Create(f)) - actual := &flow.Flow{} - _, err := r.ConsentManager().VerifyAndInvalidateLoginRequest(s.t2, f.LoginVerifier) + loginVerifier := x.Must(f.ToLoginVerifier(s.t1, r)) + _, err := r.ConsentManager().VerifyAndInvalidateLoginRequest(s.t2, f, loginVerifier) require.Error(t, err) - require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, f.ID)) - require.Equal(t, flow.FlowStateLoginUnused, actual.State) - require.Equal(t, false, actual.LoginWasUsed) - _, err = r.ConsentManager().VerifyAndInvalidateLoginRequest(s.t1, f.LoginVerifier) + require.Equal(t, flow.FlowStateLoginUnused, f.State) + require.Equal(t, false, f.LoginWasUsed) + _, err = r.ConsentManager().VerifyAndInvalidateLoginRequest(s.t1, f, loginVerifier) require.NoError(t, err) - require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, f.ID)) - require.Equal(t, flow.FlowStateLoginUsed, actual.State) - require.Equal(t, true, actual.LoginWasUsed) + require.Equal(t, flow.FlowStateLoginUsed, f.State) + require.Equal(t, true, f.LoginWasUsed) }) } } @@ -1974,8 +2035,8 @@ func (s *PersisterTestSuite) TestVerifyAndInvalidateLogoutRequest() { lrInvalidated, err := r.ConsentManager().VerifyAndInvalidateLogoutRequest(s.t2, lr.Verifier) require.Error(t, err) - require.Equal(t, &consent.LogoutRequest{}, lrInvalidated) - actual := &consent.LogoutRequest{} + require.Nil(t, lrInvalidated) + actual := &flow.LogoutRequest{} require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, lr.ID)) require.Equal(t, expected, actual) @@ -2026,9 +2087,9 @@ func newFlow(nid uuid.UUID, clientID string, subject string, sessionID sqlxx.Nul ID: uuid.Must(uuid.NewV4()).String(), ClientID: clientID, Subject: subject, - ConsentError: &consent.RequestDeniedError{}, + ConsentError: &flow.RequestDeniedError{}, State: flow.FlowStateConsentUnused, - LoginError: &consent.RequestDeniedError{}, + LoginError: &flow.RequestDeniedError{}, Context: sqlxx.JSONRawMessage{}, AMR: sqlxx.StringSliceJSONFormat{}, ConsentChallengeID: sqlxx.NullString("not-null"), @@ -2050,8 +2111,8 @@ func newGrant(keySet string, keyID string) trust.Grant { } } -func newLogoutRequest() *consent.LogoutRequest { - return &consent.LogoutRequest{ +func newLogoutRequest() *flow.LogoutRequest { + return &flow.LogoutRequest{ ID: uuid.Must(uuid.NewV4()).String(), } } @@ -2065,15 +2126,11 @@ func newKey(ksID string, use string) jose.JSONWebKey { } func newKeySet(id string, use string) *jose.JSONWebKeySet { - ks, err := jwk.GenerateJWK(context.Background(), jose.RS256, id, use) - if err != nil { - panic(err) - } - return ks + return x.Must(jwk.GenerateJWK(context.Background(), jose.RS256, id, use)) } -func newLoginSession() *consent.LoginSession { - return &consent.LoginSession{ +func newLoginSession() *flow.LoginSession { + return &flow.LoginSession{ ID: uuid.Must(uuid.NewV4()).String(), AuthenticatedAt: sqlxx.NullTime(time.Time{}), Subject: uuid.Must(uuid.NewV4()).String(), @@ -2084,3 +2141,9 @@ func newLoginSession() *consent.LoginSession { func requireKeySetEqual(t *testing.T, expected *jose.JSONWebKeySet, actual *jose.JSONWebKeySet) { assertx.EqualAsJSON(t, expected, actual) } + +func persistLoginSession(ctx context.Context, t *testing.T, p persistence.Persister, session *flow.LoginSession) { + t.Helper() + require.NoError(t, p.CreateLoginSession(ctx, session)) + require.NoError(t, p.Connection(ctx).Create(session)) +} diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index e2ae3887516..e5217139424 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -13,21 +13,18 @@ import ( "strings" "time" - "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" - - "github.com/ory/x/errorsx" - - "github.com/ory/fosite/storage" - "github.com/pkg/errors" "github.com/tidwall/gjson" "github.com/ory/fosite" + "github.com/ory/fosite/storage" + "github.com/ory/hydra/v2/oauth2" + "github.com/ory/hydra/v2/x/events" + "github.com/ory/x/errorsx" + "github.com/ory/x/otelx" "github.com/ory/x/sqlcon" "github.com/ory/x/stringsx" - - "github.com/ory/hydra/v2/oauth2" ) var _ oauth2.AssertionJWTReader = &Persister{} @@ -80,7 +77,7 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin } if p.config.EncryptSessionData(ctx) { - ciphertext, err := p.r.KeyCipher().Encrypt(ctx, session) + ciphertext, err := p.r.KeyCipher().Encrypt(ctx, session, nil) if err != nil { return nil, errorsx.WithStack(err) } @@ -115,14 +112,14 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin }, nil } -func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (*fosite.Request, error) { +func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (_ *fosite.Request, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.toRequest") - defer span.End() + defer otelx.End(span, &err) sess := r.Session if !gjson.ValidBytes(sess) { var err error - sess, err = p.r.KeyCipher().Decrypt(ctx, string(sess)) + sess, err = p.r.KeyCipher().Decrypt(ctx, string(sess), nil) if err != nil { return nil, errorsx.WithStack(err) } @@ -173,9 +170,9 @@ func (p *Persister) hashSignature(_ context.Context, signature string, table tab return signature } -func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) error { +func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid") - defer span.End() + defer otelx.End(span, &err) j, err := p.GetClientAssertionJWT(ctx, jti) if errors.Is(err, sqlcon.ErrNoRows) { @@ -192,9 +189,9 @@ func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) err return nil } -func (p *Persister) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) error { +func (p *Persister) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.SetClientAssertionJWT") - defer span.End() + defer otelx.End(span, &err) // delete expired; this cleanup spares us the need for a background worker if err := p.QueryWithNetwork(ctx).Where("expires_at < CURRENT_TIMESTAMP").Delete(&oauth2.BlacklistedJTI{}); err != nil { @@ -212,31 +209,31 @@ func (p *Persister) SetClientAssertionJWT(ctx context.Context, jti string, exp t return nil } -func (p *Persister) GetClientAssertionJWT(ctx context.Context, j string) (*oauth2.BlacklistedJTI, error) { +func (p *Persister) GetClientAssertionJWT(ctx context.Context, j string) (_ *oauth2.BlacklistedJTI, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetClientAssertionJWT") - defer span.End() + defer otelx.End(span, &err) jti := oauth2.NewBlacklistedJTI(j, time.Time{}) return jti, sqlcon.HandleError(p.QueryWithNetwork(ctx).Find(jti, jti.ID)) } -func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.BlacklistedJTI) error { +func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.BlacklistedJTI) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.SetClientAssertionJWTRaw") - defer span.End() + defer otelx.End(span, &err) return sqlcon.HandleError(p.CreateWithNetwork(ctx, jti)) } -func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) error { +func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createSession") - defer span.End() + defer otelx.End(span, &err) req, err := p.sqlSchemaFromRequest(ctx, signature, requester, table) if err != nil { return err } - if err := sqlcon.HandleError(p.CreateWithNetwork(ctx, req)); errors.Is(err, sqlcon.ErrConcurrentUpdate) { + if err = sqlcon.HandleError(p.CreateWithNetwork(ctx, req)); errors.Is(err, sqlcon.ErrConcurrentUpdate) { return errors.Wrap(fosite.ErrSerializationFailure, err.Error()) } else if err != nil { return err @@ -244,44 +241,39 @@ func (p *Persister) createSession(ctx context.Context, signature string, request return nil } -func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (fosite.Requester, error) { +func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findSessionBySignature") - defer span.End() + defer otelx.End(span, &err) r := OAuth2RequestSQL{Table: table} - var fr fosite.Requester - - return fr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - // We look for the signature as well as the hash of the signature here. - // This is because we now always store the hash of the signature in the database, - // regardless of the type of the signature. In previous versions, we only stored - // the hash of the signature for JWT tokens. - // - // This code will be removed in a future version. - err := p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r) - if errors.Is(err, sql.ErrNoRows) { - return errorsx.WithStack(fosite.ErrNotFound) - } else if err != nil { - return sqlcon.HandleError(err) - } else if !r.Active { - fr, err = r.toRequest(ctx, session, p) - if err != nil { - return err - } else if table == sqlTableCode { - return errorsx.WithStack(fosite.ErrInvalidatedAuthorizeCode) - } - - return errorsx.WithStack(fosite.ErrInactiveToken) + + // We look for the signature as well as the hash of the signature here. + // This is because we now always store the hash of the signature in the database, + // regardless of the type of the signature. In previous versions, we only stored + // the hash of the signature for JWT tokens. + // + // This code will be removed in a future version. + err = p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r) + if errors.Is(err, sql.ErrNoRows) { + return nil, errorsx.WithStack(fosite.ErrNotFound) + } else if err != nil { + return nil, sqlcon.HandleError(err) + } else if !r.Active { + fr, err := r.toRequest(ctx, session, p) + if err != nil { + return nil, err + } else if table == sqlTableCode { + return fr, errorsx.WithStack(fosite.ErrInvalidatedAuthorizeCode) } + return fr, errorsx.WithStack(fosite.ErrInactiveToken) + } - fr, err = r.toRequest(ctx, session, p) - return err - }) + return r.toRequest(ctx, session, p) } -func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) error { +func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionBySignature") - defer span.End() + defer otelx.End(span, &err) signature = p.hashSignature(ctx, signature, table) @@ -291,7 +283,7 @@ func (p *Persister) deleteSessionBySignature(ctx context.Context, signature stri // the hash of the signature for JWT tokens. // // This code will be removed in a future version. - err := sqlcon.HandleError( + err = sqlcon.HandleError( p.QueryWithNetwork(ctx). Where("signature IN (?, ?)", signature, SignatureHash(signature)). Delete(&OAuth2RequestSQL{Table: table})) @@ -306,9 +298,9 @@ func (p *Persister) deleteSessionBySignature(ctx context.Context, signature stri return nil } -func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, table tableName) error { +func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, table tableName) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionByRequestID") - defer span.End() + defer otelx.End(span, &err) /* #nosec G201 table is static */ if err := p.QueryWithNetwork(ctx). @@ -326,9 +318,9 @@ func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, tab return nil } -func (p *Persister) deactivateSessionByRequestID(ctx context.Context, id string, table tableName) error { +func (p *Persister) deactivateSessionByRequestID(ctx context.Context, id string, table tableName) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deactivateSessionByRequestID") - defer span.End() + defer otelx.End(span, &err) /* #nosec G201 table is static */ return sqlcon.HandleError( @@ -342,23 +334,22 @@ func (p *Persister) deactivateSessionByRequestID(ctx context.Context, id string, ) } -func (p *Persister) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateAuthorizeCodeSession") - defer span.End() - - return p.createSession(ctx, signature, requester, sqlTableCode) +func (p *Persister) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error { + return otelx.WithSpan(ctx, "persistence.sql.CreateAuthorizeCodeSession", func(ctx context.Context) error { + return p.createSession(ctx, signature, requester, sqlTableCode) + }) } func (p *Persister) GetAuthorizeCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAuthorizeCodeSession") - defer span.End() + defer otelx.End(span, &err) return p.findSessionBySignature(ctx, signature, session, sqlTableCode) } func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InvalidateAuthorizeCodeSession") - defer span.End() + defer otelx.End(span, &err) /* #nosec G201 table is static */ return sqlcon.HandleError( @@ -372,67 +363,100 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur ) } -func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { - return p.createSession(ctx, signature, requester, sqlTableAccess) +func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) error { + events.Trace(ctx, events.AccessTokenIssued, events.WithRequest(requester)) + return otelx.WithSpan(ctx, "persistence.sql.CreateAccessTokenSession", func(ctx context.Context) error { + return p.createSession(ctx, signature, requester, sqlTableAccess) + }) } func (p *Persister) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAccessTokenSession") + defer otelx.End(span, &err) return p.findSessionBySignature(ctx, signature, session, sqlTableAccess) } -func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) { - return p.deleteSessionBySignature(ctx, signature, sqlTableAccess) +func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) error { + return otelx.WithSpan(ctx, "persistence.sql.DeleteAccessTokenSession", func(ctx context.Context) error { + return p.deleteSessionBySignature(ctx, signature, sqlTableAccess) + }) } -func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { - return p.createSession(ctx, signature, requester, sqlTableRefresh) +func (p *Persister) CreateRefreshTokenSession(ctx context.Context, signature string, requester fosite.Requester) error { + events.Trace(ctx, events.RefreshTokenIssued, events.WithRequest(requester)) + return otelx.WithSpan(ctx, "persistence.sql.CreateRefreshTokenSession", func(ctx context.Context) error { + return p.createSession(ctx, signature, requester, sqlTableRefresh) + }) } func (p *Persister) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetRefreshTokenSession") + defer otelx.End(span, &err) return p.findSessionBySignature(ctx, signature, session, sqlTableRefresh) } -func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) { - return p.deleteSessionBySignature(ctx, signature, sqlTableRefresh) +func (p *Persister) DeleteRefreshTokenSession(ctx context.Context, signature string) error { + return otelx.WithSpan(ctx, "persistence.sql.DeleteRefreshTokenSession", func(ctx context.Context) error { + return p.deleteSessionBySignature(ctx, signature, sqlTableRefresh) + }) } func (p *Persister) CreateOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) error { - return p.createSession(ctx, signature, requester, sqlTableOpenID) + events.Trace(ctx, events.IdentityTokenIssued, events.WithRequest(requester)) + return otelx.WithSpan(ctx, "persistence.sql.CreateOpenIDConnectSession", func(ctx context.Context) error { + return p.createSession(ctx, signature, requester, sqlTableOpenID) + }) } -func (p *Persister) GetOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (fosite.Requester, error) { +func (p *Persister) GetOpenIDConnectSession(ctx context.Context, signature string, requester fosite.Requester) (_ fosite.Requester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetOpenIDConnectSession") + defer otelx.End(span, &err) return p.findSessionBySignature(ctx, signature, requester.GetSession(), sqlTableOpenID) } func (p *Persister) DeleteOpenIDConnectSession(ctx context.Context, signature string) error { - return p.deleteSessionBySignature(ctx, signature, sqlTableOpenID) + return otelx.WithSpan(ctx, "persistence.sql.DeleteOpenIDConnectSession", func(ctx context.Context) error { + return p.deleteSessionBySignature(ctx, signature, sqlTableOpenID) + }) } -func (p *Persister) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { +func (p *Persister) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (_ fosite.Requester, err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPKCERequestSession") + defer otelx.End(span, &err) return p.findSessionBySignature(ctx, signature, session, sqlTablePKCE) } func (p *Persister) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error { - return p.createSession(ctx, signature, requester, sqlTablePKCE) + return otelx.WithSpan(ctx, "persistence.sql.CreatePKCERequestSession", func(ctx context.Context) error { + return p.createSession(ctx, signature, requester, sqlTablePKCE) + }) } func (p *Persister) DeletePKCERequestSession(ctx context.Context, signature string) error { - return p.deleteSessionBySignature(ctx, signature, sqlTablePKCE) + return otelx.WithSpan(ctx, "persistence.sql.DeletePKCERequestSession", func(ctx context.Context) error { + return p.deleteSessionBySignature(ctx, signature, sqlTablePKCE) + }) } func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) error { - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + return otelx.WithSpan(ctx, "persistence.sql.RevokeRefreshToken", func(ctx context.Context) error { + return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + }) } func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, _ string) error { - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + return otelx.WithSpan(ctx, "persistence.sql.RevokeRefreshTokenMaybeGracePeriod", func(ctx context.Context) error { + return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) + }) } func (p *Persister) RevokeAccessToken(ctx context.Context, id string) error { - return p.deleteSessionByRequestID(ctx, id, sqlTableAccess) + return otelx.WithSpan(ctx, "persistence.sql.RevokeAccessToken", func(ctx context.Context) error { + return p.deleteSessionByRequestID(ctx, id, sqlTableAccess) + }) } -func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) error { +func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int, table tableName, lifespan time.Duration) (err error) { /* #nosec G201 table is static */ // The value of notAfter should be the minimum between input parameter and token max expire based on its configured age requestMaxExpire := time.Now().Add(-lifespan) @@ -440,8 +464,6 @@ func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, notAfter = requestMaxExpire } - var err error - totalDeletedCount := 0 for deletedRecords := batchSize; totalDeletedCount < limit && deletedRecords == batchSize; { d := batchSize @@ -469,16 +491,22 @@ func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time, } func (p *Persister) FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error { - return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.config.GetAccessTokenLifespan(ctx)) + return otelx.WithSpan(ctx, "persistence.sql.FlushInactiveAccessTokens", func(ctx context.Context) error { + return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableAccess, p.config.GetAccessTokenLifespan(ctx)) + }) } func (p *Persister) FlushInactiveRefreshTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error { - return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.config.GetRefreshTokenLifespan(ctx)) + return otelx.WithSpan(ctx, "persistence.sql.FlushInactiveRefreshTokens", func(ctx context.Context) error { + return p.flushInactiveTokens(ctx, notAfter, limit, batchSize, sqlTableRefresh, p.config.GetRefreshTokenLifespan(ctx)) + }) } func (p *Persister) DeleteAccessTokens(ctx context.Context, clientID string) error { - /* #nosec G201 table is static */ - return sqlcon.HandleError( - p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), - ) + return otelx.WithSpan(ctx, "persistence.sql.DeleteAccessTokens", func(ctx context.Context) error { + /* #nosec G201 table is static */ + return sqlcon.HandleError( + p.QueryWithNetwork(ctx).Where("client_id=?", clientID).Delete(&OAuth2RequestSQL{Table: sqlTableAccess}), + ) + }) } diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index 475d32b88a8..ad71b374909 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -12,9 +12,9 @@ import ( "github.com/gobuffalo/pop/v6" "github.com/gofrs/uuid" - "github.com/instana/testify/assert" - "github.com/instana/testify/require" "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ory/hydra/v2/client" "github.com/ory/hydra/v2/consent" @@ -52,12 +52,12 @@ func testRegistry(t *testing.T, ctx context.Context, k string, t1 driver.Registr parallel = false } - t.Run("package=consent/manager="+k, consent.ManagerTests(t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), "t1", parallel)) - t.Run("package=consent/manager="+k, consent.ManagerTests(t2.ConsentManager(), t2.ClientManager(), t2.OAuth2Storage(), "t2", parallel)) + t.Run("package=consent/manager="+k, consent.ManagerTests(t1, t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), "t1", parallel)) + t.Run("package=consent/manager="+k, consent.ManagerTests(t2, t2.ConsentManager(), t2.ClientManager(), t2.OAuth2Storage(), "t2", parallel)) t.Run("parallel-boundary", func(t *testing.T) { - t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(t1.Config(), t1.ConsentManager(), t1.ClientManager(), t1.OAuth2Storage(), "t1", parallel)) - t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(t2.Config(), t2.ConsentManager(), t2.ClientManager(), t2.OAuth2Storage(), "t2", parallel)) + t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(t1, "t1", parallel)) + t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(t2, "t2", parallel)) }) t.Run("package=jwk/manager="+k, func(t *testing.T) { @@ -186,7 +186,7 @@ func TestManagers(t *testing.T) { ) } t.Run("package=consent/manager="+k+"/case=nid", - consent.TestHelperNID(t1.ClientManager(), t1.ConsentManager(), t2.ConsentManager()), + consent.TestHelperNID(t1, t1.ConsentManager(), t2.ConsentManager()), ) } } diff --git a/quickstart.yml b/quickstart.yml index d6f853838b0..0cc88da1eb8 100644 --- a/quickstart.yml +++ b/quickstart.yml @@ -12,7 +12,7 @@ version: "3.7" services: hydra: - image: oryd/hydra:v2.1.2 + image: oryd/hydra:v2.2.0-rc.2 ports: - "4444:4444" # Public port - "4445:4445" # Admin port @@ -34,7 +34,7 @@ services: networks: - intranet hydra-migrate: - image: oryd/hydra:v2.1.2 + image: oryd/hydra:v2.2.0-rc.2 environment: - DSN=sqlite:///var/lib/sqlite/db.sqlite?_fk=true command: migrate -c /etc/config/hydra/hydra.yml sql -e --yes @@ -52,7 +52,7 @@ services: consent: environment: - HYDRA_ADMIN_URL=http://hydra:4445 - image: oryd/hydra-login-consent-node:v2.1.2 + image: oryd/hydra-login-consent-node:v2.2.0-rc.2 ports: - "3000:3000" restart: unless-stopped diff --git a/test/conformance/hydra/Dockerfile b/test/conformance/hydra/Dockerfile index df86aefa45b..58eb6d8155a 100644 --- a/test/conformance/hydra/Dockerfile +++ b/test/conformance/hydra/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.19-buster AS builder +FROM golang:1.20-buster AS builder RUN apt-get update && \ apt-get install --no-install-recommends -y \ diff --git a/x/doc_swagger.go b/x/doc_swagger.go index af7944ca09d..5c8fb350e8b 100644 --- a/x/doc_swagger.go +++ b/x/doc_swagger.go @@ -7,11 +7,15 @@ package x // typically 201. // // swagger:response emptyResponse +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type emptyResponse struct{} // Error // // swagger:model errorOAuth2 +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type errorOAuth2 struct { // Error Name string `json:"error"` @@ -40,6 +44,8 @@ type errorOAuth2 struct { // Default Error Response // // swagger:response errorOAuth2Default +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type errorOAuth2Default struct { // in: body Body errorOAuth2 @@ -48,6 +54,8 @@ type errorOAuth2Default struct { // Bad Request Error Response // // swagger:response errorOAuth2BadRequest +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type errorOAuth2BadRequest struct { // in: body Body errorOAuth2 @@ -56,6 +64,8 @@ type errorOAuth2BadRequest struct { // Not Found Error Response // // swagger:response errorOAuth2NotFound +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type errorOAuth2NotFound struct { // in: body Body errorOAuth2 diff --git a/x/errors.go b/x/errors.go index 229884a5d51..f90802bf50a 100644 --- a/x/errors.go +++ b/x/errors.go @@ -31,3 +31,10 @@ func LogError(r *http.Request, err error, logger *logrusx.Logger) { logger.WithRequest(r). WithError(err).Errorln("An error occurred") } + +func Must[T any](t T, err error) T { + if err != nil { + panic(err) + } + return t +} diff --git a/x/events/events.go b/x/events/events.go new file mode 100644 index 00000000000..4ea5207472b --- /dev/null +++ b/x/events/events.go @@ -0,0 +1,86 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package events + +import ( + "context" + + otelattr "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "github.com/ory/fosite" + "github.com/ory/x/otelx/semconv" +) + +const ( + // LoginAccepted will be emitted when the login UI accepts a login request. + LoginAccepted semconv.Event = "OAuth2LoginAccepted" + + // LoginRejected will be emitted when the login UI rejects a login request. + LoginRejected semconv.Event = "OAuth2LoginRejected" + + // ConsentAccepted will be emitted when the consent UI accepts a consent request. + ConsentAccepted semconv.Event = "OAuth2ConsentAccepted" + + // ConsentRejected will be emitted when the consent UI rejects a consent request. + ConsentRejected semconv.Event = "OAuth2ConsentRejected" + + // ConsentRevoked will be emitted when the user revokes a consent request. + ConsentRevoked semconv.Event = "OAuth2ConsentRevoked" + + // AccessTokenIssued will be emitted by requests to POST /oauth2/token in case the request was successful. + AccessTokenIssued semconv.Event = "OAuth2AccessTokenIssued" //nolint:gosec + + // TokenExchangeError will be emitted by requests to POST /oauth2/token in case the request was unsuccessful. + TokenExchangeError semconv.Event = "OAuth2TokenExchangeError" //nolint:gosec + + // AccessTokenInspected will be emitted by requests to POST /admin/oauth2/introspect. + AccessTokenInspected semconv.Event = "OAuth2AccessTokenInspected" //nolint:gosec + + // AccessTokenRevoked will be emitted by requests to POST /oauth2/revoke. + AccessTokenRevoked semconv.Event = "OAuth2AccessTokenRevoked" //nolint:gosec + + // RefreshTokenIssued will be emitted when a refresh token is issued. + RefreshTokenIssued semconv.Event = "OAuth2RefreshTokenIssued" //nolint:gosec + + // IdentityTokenIssued will be emitted when a refresh token is issued. + IdentityTokenIssued semconv.Event = "OIDCIdentityTokenIssued" //nolint:gosec +) + +const ( + attributeKeyOAuth2ClientID = "OAuth2ClientID" + attributeKeyOAuth2Subject = "OAuth2Subject" +) + +// WithClientID emits the client ID as part of the event. +func WithClientID(clientID string) trace.EventOption { + return trace.WithAttributes(otelattr.String(attributeKeyOAuth2ClientID, clientID)) +} + +// WithSubject emits the subject as part of the event. +func WithSubject(subject string) trace.EventOption { + return trace.WithAttributes(otelattr.String(attributeKeyOAuth2Subject, subject)) +} + +// WithRequest emits the subject and client ID from the fosite request as part of the event. +func WithRequest(request fosite.Requester) trace.EventOption { + var attributes []otelattr.KeyValue + if client := request.GetClient(); client != nil { + attributes = append(attributes, otelattr.String(attributeKeyOAuth2ClientID, client.GetID())) + } + if session := request.GetSession(); session != nil { + attributes = append(attributes, otelattr.String(attributeKeyOAuth2Subject, session.GetSubject())) + } + + return trace.WithAttributes(attributes...) +} + +// Trace emits an event with the given attributes. +func Trace(ctx context.Context, event semconv.Event, opts ...trace.EventOption) { + allOpts := append([]trace.EventOption{trace.WithAttributes(semconv.AttributesFromContext(ctx)...)}, opts...) + trace.SpanFromContext(ctx).AddEvent( + string(event), + allOpts..., + ) +} diff --git a/x/oauth2cors/cors.go b/x/oauth2cors/cors.go index c01f39aeadd..050da40b0a2 100644 --- a/x/oauth2cors/cors.go +++ b/x/oauth2cors/cors.go @@ -4,8 +4,6 @@ package oauth2cors import ( - "context" - "fmt" "net/http" "strings" @@ -21,113 +19,128 @@ import ( ) func Middleware( - ctx context.Context, reg interface { x.RegistryLogger oauth2.Registry client.Registry }) func(h http.Handler) http.Handler { - opts, enabled := reg.Config().CORS(ctx, config.PublicInterface) - if !enabled { - return func(h http.Handler) http.Handler { - return h - } - } - - var alwaysAllow = len(opts.AllowedOrigins) == 0 - var patterns []glob.Glob - for _, o := range opts.AllowedOrigins { - if o == "*" { - alwaysAllow = true - } - // if the protocol (http or https) is specified, but the url is wildcard, use special ** glob, which ignore the '.' separator. - // This way g := glob.Compile("http://**") g.Match("http://google.com") returns true. - if splittedO := strings.Split(o, "://"); len(splittedO) != 1 && splittedO[1] == "*" { - o = fmt.Sprintf("%s://**", splittedO[0]) - } - g, err := glob.Compile(strings.ToLower(o), '.') - if err != nil { - reg.Logger().WithError(err).Fatalf("Unable to parse cors origin: %s", o) - } - - patterns = append(patterns, g) - } - - options := cors.Options{ - AllowedOrigins: opts.AllowedOrigins, - AllowedMethods: opts.AllowedMethods, - AllowedHeaders: opts.AllowedHeaders, - ExposedHeaders: opts.ExposedHeaders, - MaxAge: opts.MaxAge, - AllowCredentials: opts.AllowCredentials, - OptionsPassthrough: opts.OptionsPassthrough, - Debug: opts.Debug, - AllowOriginRequestFunc: func(r *http.Request, origin string) bool { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - if alwaysAllow { - return true - } - origin = strings.ToLower(origin) - for _, p := range patterns { - if p.Match(origin) { - return true - } + opts, enabled := reg.Config().CORS(ctx, config.PublicInterface) + if !enabled { + reg.Logger().Debug("not enhancing CORS per client, as CORS is disabled") + h.ServeHTTP(w, r) + return } - // pre-flight requests do not contain credentials (cookies, HTTP authorization) - // so we return true in all cases here. - if r.Method == http.MethodOptions { - return true - } - - var clientID string - - // if the client uses client_secret_post auth it will provide its client ID in form data - clientID = r.PostFormValue("client_id") - - // if the client uses client_secret_basic auth the client ID will be the username component - if clientID == "" { - clientID, _, _ = r.BasicAuth() - } - - // otherwise, this may be a bearer auth request, in which case we can introspect the token - if clientID == "" { - token := fosite.AccessTokenFromRequest(r) - if token == "" { - return false + alwaysAllow := len(opts.AllowedOrigins) == 0 + patterns := make([]glob.Glob, 0, len(opts.AllowedOrigins)) + for _, o := range opts.AllowedOrigins { + if o == "*" { + alwaysAllow = true + break } - - session := oauth2.NewSessionWithCustomClaims("", reg.Config().AllowedTopLevelClaims(ctx)) - _, ar, err := reg.OAuth2Provider().IntrospectToken(ctx, token, fosite.AccessToken, session) + // if the protocol (http or https) is specified, but the url is wildcard, use special ** glob, which ignore the '.' separator. + // This way g := glob.Compile("http://**") g.Match("http://google.com") returns true. + if scheme, rest, found := strings.Cut(o, "://"); found && rest == "*" { + o = scheme + "://**" + } + g, err := glob.Compile(strings.ToLower(o), '.') if err != nil { - return false + reg.Logger().WithError(err).WithField("pattern", o).Error("Unable to parse CORS origin, ignoring it") + continue } - clientID = ar.GetClient().GetID() + patterns = append(patterns, g) } - cl, err := reg.ClientManager().GetConcreteClient(ctx, clientID) - if err != nil { - return false - } + options := cors.Options{ + AllowedOrigins: opts.AllowedOrigins, + AllowedMethods: opts.AllowedMethods, + AllowedHeaders: opts.AllowedHeaders, + ExposedHeaders: opts.ExposedHeaders, + MaxAge: opts.MaxAge, + AllowCredentials: opts.AllowCredentials, + OptionsPassthrough: opts.OptionsPassthrough, + Debug: opts.Debug, + AllowOriginRequestFunc: func(r *http.Request, origin string) bool { + ctx := r.Context() + if alwaysAllow { + return true + } + + origin = strings.ToLower(origin) + for _, p := range patterns { + if p.Match(origin) { + return true + } + } + + // pre-flight requests do not contain credentials (cookies, HTTP authorization) + // so we return true in all cases here. + if r.Method == http.MethodOptions { + return true + } + + var clientID string + + // if the client uses client_secret_post auth it will provide its client ID in form data + clientID = r.PostFormValue("client_id") + + // if the client uses client_secret_basic auth the client ID will be the username component + if clientID == "" { + clientID, _, _ = r.BasicAuth() + } + + // otherwise, this may be a bearer auth request, in which case we can introspect the token + if clientID == "" { + token := fosite.AccessTokenFromRequest(r) + if token == "" { + return false + } + + session := oauth2.NewSessionWithCustomClaims("", reg.Config().AllowedTopLevelClaims(ctx)) + _, ar, err := reg.OAuth2Provider().IntrospectToken(ctx, token, fosite.AccessToken, session) + if err != nil { + return false + } + + clientID = ar.GetClient().GetID() + } + + cl, err := reg.ClientManager().GetConcreteClient(ctx, clientID) + if err != nil { + return false + } + + for _, o := range cl.AllowedCORSOrigins { + if o == "*" { + return true + } + + // if the protocol (http or https) is specified, but the url is wildcard, use special ** glob, which ignore the '.' separator. + // This way g := glob.Compile("http://**") g.Match("http://google.com") returns true. + if scheme, rest, found := strings.Cut(o, "://"); found && rest == "*" { + o = scheme + "://**" + } + + g, err := glob.Compile(strings.ToLower(o), '.') + if err != nil { + return false + } + if g.Match(origin) { + return true + } + } - for _, o := range cl.AllowedCORSOrigins { - if o == "*" { - return true - } - g, err := glob.Compile(strings.ToLower(o), '.') - if err != nil { return false - } - if g.Match(origin) { - return true - } + }, } - return false - }, + reg.Logger().Debug("enhancing CORS per client") + cors.New(options).Handler(h).ServeHTTP(w, r) + }) } - - return cors.New(options).Handler } diff --git a/x/oauth2cors/cors_test.go b/x/oauth2cors/cors_test.go index d9c063dec50..62b57d29ff1 100644 --- a/x/oauth2cors/cors_test.go +++ b/x/oauth2cors/cors_test.go @@ -15,7 +15,6 @@ import ( "time" "github.com/ory/hydra/v2/driver" - "github.com/ory/hydra/v2/x/oauth2cors" "github.com/ory/x/contextx" "github.com/ory/hydra/v2/x" @@ -30,8 +29,10 @@ import ( ) func TestOAuth2AwareCORSMiddleware(t *testing.T) { + ctx := context.Background() r := internal.NewRegistryMemory(t, internal.NewConfigurationWithDefaults(), &contextx.Default{}) - token, signature, _ := r.OAuth2HMACStrategy().GenerateAccessToken(context.Background(), nil) + token, signature, _ := r.OAuth2HMACStrategy().GenerateAccessToken(ctx, nil) + for k, tc := range []struct { prep func(*testing.T, driver.Registry) d string @@ -52,8 +53,8 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should reject when basic auth but client does not exist and cors enabled", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo", "bar"))}}, @@ -62,11 +63,11 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should reject when post auth client exists but origin not allowed", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-2", Secret: "bar", AllowedCORSOrigins: []string{"http://not-foobar.com"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-2", Secret: "bar", AllowedCORSOrigins: []string{"http://not-foobar.com"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foobar.com"}, "Content-Type": {"application/x-www-form-urlencoded"}}, @@ -77,11 +78,11 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should accept when post auth client exists and origin allowed", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-3", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-3", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foobar.com"}, "Content-Type": {"application/x-www-form-urlencoded"}}, @@ -92,11 +93,11 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should reject when basic auth client exists but origin not allowed", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-2", Secret: "bar", AllowedCORSOrigins: []string{"http://not-foobar.com"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-2", Secret: "bar", AllowedCORSOrigins: []string{"http://not-foobar.com"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-2", "bar"))}}, @@ -105,10 +106,10 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should accept when basic auth client exists and origin allowed", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-3", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-3", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-3", "bar"))}}, @@ -117,11 +118,11 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should accept when basic auth client exists and origin allowed", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{}) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-3", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-3", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-3", "bar"))}}, @@ -130,11 +131,24 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should accept when basic auth client exists and origin (with partial wildcard) is allowed per client", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{}) + + // Ignore unique violations + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-4", Secret: "bar", AllowedCORSOrigins: []string{"http://*.foobar.com"}}) + }, + code: http.StatusNotImplemented, + header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-4", "bar"))}}, + expectHeader: http.Header{"Access-Control-Allow-Credentials": []string{"true"}, "Access-Control-Allow-Origin": []string{"http://foo.foobar.com"}, "Access-Control-Expose-Headers": []string{"Cache-Control, Expires, Last-Modified, Pragma, Content-Length, Content-Language, Content-Type"}, "Vary": []string{"Origin"}}, + }, + { + d: "should accept when basic auth client exists and wildcard origin is allowed per client", + prep: func(t *testing.T, r driver.Registry) { + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{}) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-4", Secret: "bar", AllowedCORSOrigins: []string{"http://*.foobar.com"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-4", Secret: "bar", AllowedCORSOrigins: []string{"http://*"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-4", "bar"))}}, @@ -143,11 +157,11 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should accept when basic auth client exists and origin (with full wildcard) is allowed globally", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"*"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"*"}) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-5", Secret: "bar", AllowedCORSOrigins: []string{"http://barbar.com"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-5", Secret: "bar", AllowedCORSOrigins: []string{"http://barbar.com"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"*"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-5", "bar"))}}, @@ -156,11 +170,11 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should accept when basic auth client exists and origin (with partial wildcard) is allowed globally", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://*.foobar.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://*.foobar.com"}) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-6", Secret: "bar", AllowedCORSOrigins: []string{"http://barbar.com"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-6", Secret: "bar", AllowedCORSOrigins: []string{"http://barbar.com"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-6", "bar"))}}, @@ -169,11 +183,11 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should accept when basic auth client exists and origin (with full wildcard) allowed per client", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-7", Secret: "bar", AllowedCORSOrigins: []string{"*"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-7", Secret: "bar", AllowedCORSOrigins: []string{"*"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-7", "bar"))}}, @@ -182,8 +196,8 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should succeed on pre-flight request when token introspection fails", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Bearer 1234"}}, @@ -193,8 +207,8 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should fail when token introspection fails", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Bearer 1234"}}, @@ -203,8 +217,8 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should work when token introspection returns a session", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"}) sess := oauth2.NewSession("foo-9") sess.SetExpiresAt(fosite.AccessToken, time.Now().Add(time.Hour)) ar := fosite.NewAccessRequest(sess) @@ -212,8 +226,8 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { ar.Client = cl // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), cl) - _ = r.OAuth2Storage().CreateAccessTokenSession(context.Background(), signature, ar) + _ = r.ClientManager().CreateClient(ctx, cl) + _ = r.OAuth2Storage().CreateAccessTokenSession(ctx, signature, ar) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Bearer " + token}}, @@ -222,12 +236,12 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should accept any allowed specified origin protocol", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-11", Secret: "bar", AllowedCORSOrigins: []string{"*"}}) - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://*", "https://*"}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-11", Secret: "bar", AllowedCORSOrigins: []string{"*"}}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://*", "https://*"}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-11", "bar"))}}, @@ -236,11 +250,11 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should accept client origin when basic auth client exists and origin is set at the client as well as the server", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://**.example.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://**.example.com"}) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-12", Secret: "bar", AllowedCORSOrigins: []string{"http://myapp.example.biz"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-12", Secret: "bar", AllowedCORSOrigins: []string{"http://myapp.example.biz"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://myapp.example.biz"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-12", "bar"))}}, @@ -249,11 +263,11 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { { d: "should accept server origin when basic auth client exists and origin is set at the client as well as the server", prep: func(t *testing.T, r driver.Registry) { - r.Config().MustSet(context.Background(), "serve.public.cors.enabled", true) - r.Config().MustSet(context.Background(), "serve.public.cors.allowed_origins", []string{"http://**.example.com"}) + r.Config().MustSet(ctx, "serve.public.cors.enabled", true) + r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://**.example.com"}) // Ignore unique violations - _ = r.ClientManager().CreateClient(context.Background(), &client.Client{LegacyClientID: "foo-13", Secret: "bar", AllowedCORSOrigins: []string{"http://myapp.example.biz"}}) + _ = r.ClientManager().CreateClient(ctx, &client.Client{LegacyClientID: "foo-13", Secret: "bar", AllowedCORSOrigins: []string{"http://myapp.example.biz"}}) }, code: http.StatusNotImplemented, header: http.Header{"Origin": {"http://client-app.example.com"}, "Authorization": {fmt.Sprintf("Basic %s", x.BasicAuth("foo-13", "bar"))}}, @@ -278,7 +292,7 @@ func TestOAuth2AwareCORSMiddleware(t *testing.T) { } res := httptest.NewRecorder() - oauth2cors.Middleware(context.Background(), r)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.OAuth2AwareMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotImplemented) })).ServeHTTP(res, req) require.NoError(t, err) diff --git a/x/sqlx.go b/x/sqlx.go index 7ca0e5a727d..0b90b923665 100644 --- a/x/sqlx.go +++ b/x/sqlx.go @@ -71,6 +71,8 @@ func (ns *Duration) UnmarshalJSON(data []byte) error { } // swagger:model NullDuration +// +//lint:ignore U1000 Used to generate Swagger and OpenAPI definitions type swaggerNullDuration string // NullDuration represents a nullable JSON and SQL compatible time.Duration.