diff --git a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPOpenIdAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPOpenIdAuthenticator.java index 726dc8ee8a..0089b95376 100644 --- a/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPOpenIdAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPOpenIdAuthenticator.java @@ -19,17 +19,14 @@ import java.text.ParseException; import java.util.Map; import java.util.Optional; -import java.util.concurrent.TimeUnit; -import org.apache.http.StatusLine; -import org.apache.http.client.methods.HttpGet; +import org.apache.http.HttpEntity; import org.apache.http.client.config.RequestConfig; -import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.impl.client.HttpClients; -import org.apache.http.conn.HttpClientConnectionManager; -import org.apache.http.HttpEntity; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -46,8 +43,8 @@ import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; -import static org.apache.http.entity.ContentType.APPLICATION_JSON; import static org.apache.http.HttpHeaders.AUTHORIZATION; +import static org.apache.http.entity.ContentType.APPLICATION_JSON; import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.APPLICATION_JWT; import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.CLIENT_ID; import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.ISSUER_ID_URL; @@ -140,9 +137,9 @@ public AuthCredentials extractCredentials0(SecurityRequest request, ThreadContex HttpGet httpGet = new HttpGet(this.userInfoEndpoint); RequestConfig requestConfig = RequestConfig.custom() - .setConnectionRequestTimeout(requestTimeoutMs) - .setConnectTimeout(requestTimeoutMs) - .build(); + .setConnectionRequestTimeout(requestTimeoutMs) + .setConnectTimeout(requestTimeoutMs) + .build(); httpGet.setConfig(requestConfig); httpGet.addHeader(AUTHORIZATION, request.getHeaders().get(AUTHORIZATION).get(0)); @@ -150,10 +147,9 @@ public AuthCredentials extractCredentials0(SecurityRequest request, ThreadContex // HTTPGet should internally verify the appropriate TLS cert. try (CloseableHttpResponse response = httpClient.execute(httpGet)) { - StatusLine statusLine = response.getStatusLine(); - if (statusLine.getStatusCode() < 200 || statusLine.getStatusCode() >= 300) { + if (response.getStatusLine().getStatusCode() < 200 || response.getStatusLine().getStatusCode() >= 300) { throw new AuthenticatorUnavailableException( - "Error while getting " + this.userInfoEndpoint + ": " + statusLine + "Error while getting " + this.userInfoEndpoint + ": Invalid status code " + response.getStatusLine().getStatusCode() ); } @@ -166,16 +162,16 @@ public AuthCredentials extractCredentials0(SecurityRequest request, ThreadContex String contentType = httpEntity.getContentType().getValue(); if (!contentType.contains(APPLICATION_JSON.getMimeType()) && !contentType.contains(APPLICATION_JWT)) { throw new AuthenticatorUnavailableException( - "Error while getting " + this.userInfoEndpoint + ": Invalid content type in response" + "Error while getting " + this.userInfoEndpoint + ": Invalid content type in response" ); } String userinfoContent; try ( - // got this from ChatGpt & Amazon Q - InputStream inputStream = httpEntity.getContent(); - InputStreamReader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8) + // got this from ChatGpt & Amazon Q + InputStream inputStream = httpEntity.getContent(); + InputStreamReader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8) ) { StringBuilder content = new StringBuilder(); char[] buffer = new char[8192]; @@ -186,7 +182,7 @@ public AuthCredentials extractCredentials0(SecurityRequest request, ThreadContex userinfoContent = content.toString(); } catch (IOException e) { throw new AuthenticatorUnavailableException( - "Error while getting " + this.userInfoEndpoint + ": Unable to read response content" + "Error while getting " + this.userInfoEndpoint + ": Unable to read response content" ); } @@ -203,7 +199,7 @@ public AuthCredentials extractCredentials0(SecurityRequest request, ThreadContex String missing = validateResponseClaims(claims, id, isSigned); if (!missing.isBlank()) { throw new AuthenticatorUnavailableException( - "Error while getting " + this.userInfoEndpoint + ": Missing or invalid required claims in response: " + missing + "Error while getting " + this.userInfoEndpoint + ": Missing or invalid required claims in response: " + missing ); } @@ -243,8 +239,8 @@ private String validateResponseClaims(JWTClaimsSet claims, String id, boolean is missing = missing.concat("iss"); } if (claims.getAudience() == null - || claims.getAudience().toString().isBlank() - || !claims.getAudience().contains(settings.get(CLIENT_ID))) { + || claims.getAudience().toString().isBlank() + || !claims.getAudience().contains(settings.get(CLIENT_ID))) { missing = missing.concat("aud"); } } @@ -269,15 +265,15 @@ protected KeyProvider initKeyProvider(Settings settings, Path configPath) throws if (jwksUri != null && !jwksUri.isBlank()) { keySetRetriever = new KeySetRetriever( - getSSLConfig(settings, configPath), - settings.getAsBoolean("cache_jwks_endpoint", false), - jwksUri + getSSLConfig(settings, configPath), + settings.getAsBoolean("cache_jwks_endpoint", false), + jwksUri ); } else { keySetRetriever = new KeySetRetriever( - settings.get("openid_connect_url"), - getSSLConfig(settings, configPath), - settings.getAsBoolean("cache_jwks_endpoint", false) + settings.get("openid_connect_url"), + getSSLConfig(settings, configPath), + settings.getAsBoolean("cache_jwks_endpoint", false) ); } diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPOpenIdAuthenticatorTests.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPOpenIdAuthenticatorTests.java index 29369efb9d..9d66a99aea 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPOpenIdAuthenticatorTests.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/HTTPOpenIdAuthenticatorTests.java @@ -1,2 +1,693 @@ -package com.amazon.dlic.auth.http.jwt.keybyoidc;public class HTTPOpenIdAuthenticatorTests { +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +package com.amazon.dlic.auth.http.jwt.keybyoidc; + +import java.util.HashMap; +import java.util.List; + +import com.google.common.collect.ImmutableMap; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.opensearch.OpenSearchSecurityException; +import org.opensearch.common.settings.Settings; +import org.opensearch.security.user.AuthCredentials; +import org.opensearch.security.util.FakeRestRequest; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.APPLICATION_JWT; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.CLIENT_ID; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.ISSUER_ID_URL; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.MCCOY_SUBJECT; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.OIDC_TEST_AUD; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.OIDC_TEST_ISS; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.ROLES_CLAIM; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.spy; + +public class HTTPOpenIdAuthenticatorTests { + + protected static MockIpdServer mockIdpServer; + + @BeforeClass + public static void setUp() throws Exception { + mockIdpServer = new MockIpdServer(TestJwk.Jwks.ALL); + } + + @AfterClass + public static void tearDown() { + if (mockIdpServer != null) { + try { + mockIdpServer.close(); + } catch (Exception e) {} + } + } + + @Test + public void basicTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getAttributes().get("attr.jwt.aud"), is(List.of(TestJwts.TEST_AUDIENCE).toString())); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(4)); + } + + @Test + public void jwksUriTest() throws Exception { + Settings settings = Settings.builder() + .put("jwks_uri", mockIdpServer.getJwksUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getAttributes().get("attr.jwt.aud"), is(List.of(TestJwts.TEST_AUDIENCE).toString())); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(4)); + } + + @Test + public void jwksMissingRequiredIssuerInClaimTest() throws Exception { + Settings settings = Settings.builder() + .put("jwks_uri", mockIdpServer.getJwksUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_NO_ISSUER_OCT_1), new HashMap<>()) + .asSecurityRequest(), + null + ); + + Assert.assertNull(creds); + } + + @Test + public void jwksNotMatchingRequiredIssuerInClaimTest() throws Exception { + Settings settings = Settings.builder().put("jwks_uri", mockIdpServer.getJwksUri()).put("required_issuer", "Wrong Issuer").build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()).asSecurityRequest(), + null + ); + + Assert.assertNull(creds); + } + + @Test + public void jwksMatchAtLeastOneRequiredAudienceInClaimTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE + ",another_audience") + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getAttributes().get("attr.jwt.aud"), is(List.of(TestJwts.TEST_AUDIENCE).toString())); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(4)); + } + + @Test + public void jwksMissingRequiredAudienceInClaimTest() throws Exception { + Settings settings = Settings.builder() + .put("jwks_uri", mockIdpServer.getJwksUri()) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_NO_AUDIENCE_OCT_1), new HashMap<>()) + .asSecurityRequest(), + null + ); + + Assert.assertNull(creds); + } + + @Test + public void jwksNotMatchingRequiredAudienceInClaimTest() throws Exception { + Settings settings = Settings.builder() + .put("jwks_uri", mockIdpServer.getJwksUri()) + .put("required_audience", "Wrong Audience") + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_2), new HashMap<>()).asSecurityRequest(), + null + ); + + Assert.assertNull(creds); + } + + @Test + public void jwksUriMissingTest() { + var exception = Assert.assertThrows(Exception.class, () -> { + HTTPOpenIdAuthenticator jwtAuth = new HTTPOpenIdAuthenticator(Settings.builder().build(), null); + jwtAuth.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap<>()).asSecurityRequest(), + null + ); + }); + + assertThat(exception.getMessage(), is("Authentication backend failed")); + assertThat(exception.getClass(), is(OpenSearchSecurityException.class)); + } + + @Test + public void testEscapeKid() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1_INVALID_KID), + new HashMap() + ).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getAttributes().get("attr.jwt.aud"), is(List.of(TestJwts.TEST_AUDIENCE).toString())); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(4)); + } + + @Test + public void bearerTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()) + .asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getAttributes().get("attr.jwt.aud"), is(List.of(TestJwts.TEST_AUDIENCE).toString())); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(4)); + } + + @Test + public void testRoles() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("roles_key", ROLES_CLAIM) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1), new HashMap()) + .asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getBackendRoles(), is(TestJwts.TEST_ROLES)); + } + + @Test + public void testExp() throws Exception { + Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_EXPIRED_SIGNED_OCT_1), new HashMap<>()) + .asSecurityRequest(), + null + ); + + Assert.assertNull(creds); + } + + @Test + public void testExpInSkew() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("jwt_clock_skew_tolerance_seconds", "10") + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + long expiringDate = System.currentTimeMillis() / 1000 - 5; + long notBeforeDate = System.currentTimeMillis() / 1000 - 25; + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), + new HashMap<>() + ).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + } + + @Test + public void testNbf() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("jwt_clock_skew_tolerance_seconds", "0") + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + long expiringDate = 20 + System.currentTimeMillis() / 1000; + long notBeforeDate = 5 + System.currentTimeMillis() / 1000; + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), + new HashMap<>() + ).asSecurityRequest(), + null + ); + + Assert.assertNull(creds); + } + + @Test + public void testNbfInSkew() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("jwt_clock_skew_tolerance_seconds", "10") + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + long expiringDate = 20 + System.currentTimeMillis() / 1000; + long notBeforeDate = 5 + System.currentTimeMillis() / 1000; + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.createMcCoySignedOct1(notBeforeDate, expiringDate)), + new HashMap<>() + ).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + } + + @Test + public void testRS256() throws Exception { + + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap<>()).asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getAttributes().get("attr.jwt.aud"), is(List.of(TestJwts.TEST_AUDIENCE).toString())); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(4)); + } + + @Test + public void testBadSignature() throws Exception { + + Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_X), new HashMap<>()).asSecurityRequest(), + null + ); + + Assert.assertNull(creds); + } + + @Test + public void testPeculiarJsonEscaping() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.PeculiarEscaping.MC_COY_SIGNED_RSA_1), new HashMap<>()) + .asSecurityRequest(), + null + ); + + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getAttributes().get("attr.jwt.aud"), is(List.of(TestJwts.TEST_AUDIENCE).toString())); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(4)); + } + + @Test + public void userinfoEndpointReturnsJwtWithAllRequirementsTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("userinfo_endpoint", mockIdpServer.getUserinfoSignedUri()) + .put(CLIENT_ID, OIDC_TEST_AUD) + .put(ISSUER_ID_URL, OIDC_TEST_ISS) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = spy(new HTTPOpenIdAuthenticator(settings, null)); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT), + new HashMap<>() + ).asSecurityRequest(), + null + ); + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(4)); + } + + @Test + public void userinfoEndpointReturnsJwtWithRequiredAudIssFailsTest() throws Exception { // Setting a required issuer or audience + // alongside userinfo endpoint settings causes + // failures in signed response cases + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("userinfo_endpoint", mockIdpServer.getUserinfoSignedUri()) + .put(CLIENT_ID, OIDC_TEST_AUD) + .put(ISSUER_ID_URL, OIDC_TEST_ISS) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = null; + String message = ""; + try { + creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT), + new HashMap<>() + ).asSecurityRequest(), + null + ); + } catch (RuntimeException e) { + message = e.getMessage(); + } + Assert.assertNull(creds); + assertTrue(message.contains("JWT audience rejected")); + } + + @Test + public void userinfoEndpointReturnsJwtWithMatchingRequiredAudIssPassesTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("userinfo_endpoint", mockIdpServer.getUserinfoSignedUri()) + .put(CLIENT_ID, OIDC_TEST_AUD) + .put(ISSUER_ID_URL, OIDC_TEST_ISS) + .put("required_issuer", OIDC_TEST_ISS) + .put("required_audience", OIDC_TEST_AUD) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1_OIDC, "Content-Type", APPLICATION_JWT), + new HashMap<>() + ).asSecurityRequest(), + null + ); + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(4)); + } + + @Test + public void userinfoEndpointReturnsJwtMissingIssuerTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("userinfo_endpoint", mockIdpServer.getUserinfoSignedUri()) + .put(CLIENT_ID, OIDC_TEST_AUD) + .put(ISSUER_ID_URL, "http://www.differentexample.com") + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = null; + String message = ""; + try { + creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT), + new HashMap<>() + ).asSecurityRequest(), + null + ); + } catch (AuthenticatorUnavailableException e) { + message = e.getMessage(); + } + Assert.assertNull(creds); + assertTrue(message.contains("Missing or invalid required claims in response: iss")); + } + + @Test + public void userinfoEndpointReturnsJwtMissingAudienceTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("userinfo_endpoint", mockIdpServer.getUserinfoSignedUri()) + .put(CLIENT_ID, "aDifferentTestClient") + .put(ISSUER_ID_URL, "http://www.example.com") + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = null; + String message = ""; + try { + creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT), + new HashMap<>() + ).asSecurityRequest(), + null + ); + } catch (AuthenticatorUnavailableException e) { + message = e.getMessage(); + } + Assert.assertNull(creds); + assertTrue(message.contains("Missing or invalid required claims in response: aud")); + } + + @Test + public void userinfoEndpointReturnsJwtMismatchedSubTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("userinfo_endpoint", mockIdpServer.getUserinfoSignedUri()) + .put(CLIENT_ID, "testClient") + .put(ISSUER_ID_URL, "http://www.example.com") + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + AuthCredentials creds = null; + String message = ""; + try { + creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.STEPHEN_RSA_1, "Content-Type", APPLICATION_JWT), + new HashMap<>() + ).asSecurityRequest(), + null + ); + } catch (AuthenticatorUnavailableException e) { + message = e.getMessage(); + } + Assert.assertNull(creds); + assertTrue(message.contains("Missing or invalid required claims in response: sub")); + } + + @Test + public void userinfoEndpointReturnsJsonWithAllRequirementsTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("userinfo_endpoint", mockIdpServer.getUserinfoUri()) + .put(CLIENT_ID, "testClient") + .put(ISSUER_ID_URL, "http://www.example.com") + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = spy(new HTTPOpenIdAuthenticator(settings, null)); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT), + new HashMap<>() + ).asSecurityRequest(), + null + ); + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(2)); + } + + @Test + public void userinfoEndpointReturnsJsonMismatchedSubTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("userinfo_endpoint", mockIdpServer.getUserinfoUri()) + .put(CLIENT_ID, "testClient") + .put(ISSUER_ID_URL, "http://www.example.com") + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + AuthCredentials creds = null; + String message = ""; + try { + creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.STEPHEN_RSA_1, "Content-Type", APPLICATION_JWT), + new HashMap<>() + ).asSecurityRequest(), + null + ); + } catch (AuthenticatorUnavailableException e) { + message = e.getMessage(); + } + Assert.assertNull(creds); + assertTrue(message.contains("Missing or invalid required claims in response: sub")); + } + + @Test + public void userinfoEndpointReturnsResponseNot2xxTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("userinfo_endpoint", mockIdpServer.getUserinfoUri()) + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE + ",another_audience") + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + AuthCredentials creds = null; + String message = ""; + try { + creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT), + new HashMap<>() + ).asSecurityRequest(), + null + ); + } catch (AuthenticatorUnavailableException e) { + message = e.getMessage(); + } + Assert.assertNull(creds); + assertTrue(message.contains("Error while getting")); + } + + @Test + public void userinfoEndpointReturnsJsonWithRequiredAudIssPassesTest() throws Exception { + Settings settings = Settings.builder() + .put("openid_connect_url", mockIdpServer.getDiscoverUri()) + .put("userinfo_endpoint", mockIdpServer.getUserinfoUri()) + .put(CLIENT_ID, "testClient") + .put(ISSUER_ID_URL, "http://www.example.com") + .put("required_issuer", TestJwts.TEST_ISSUER) + .put("required_audience", TestJwts.TEST_AUDIENCE) + .build(); + + HTTPOpenIdAuthenticator openIdAuthenticator = new HTTPOpenIdAuthenticator(settings, null); + + AuthCredentials creds = openIdAuthenticator.extractCredentials( + new FakeRestRequest( + ImmutableMap.of("Authorization", "Bearer " + TestJwts.MC_COY_SIGNED_OCT_1, "Content-Type", APPLICATION_JWT), + new HashMap<>() + ).asSecurityRequest(), + null + ); + Assert.assertNotNull(creds); + assertThat(creds.getUsername(), is(MCCOY_SUBJECT)); + assertThat(creds.getBackendRoles().size(), is(0)); + assertThat(creds.getAttributes().size(), is(2)); + } } diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/MockIpdServer.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/MockIpdServer.java index 6b4fb9ef81..25a919eebb 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/MockIpdServer.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/MockIpdServer.java @@ -29,6 +29,7 @@ import javax.net.ssl.SSLSocket; import javax.net.ssl.TrustManagerFactory; +import org.apache.http.Header; import org.apache.http.HttpConnectionFactory; import org.apache.http.HttpException; import org.apache.http.HttpRequest; @@ -51,9 +52,18 @@ import org.opensearch.security.test.helper.network.SocketUtils; import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jwt.JWTClaimsSet; + +import static org.apache.http.entity.ContentType.APPLICATION_JSON; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.OpenIdConstants.APPLICATION_JWT; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.MCCOY_SUBJECT; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.TEST_ROLES_STRING; +import static com.amazon.dlic.auth.http.jwt.keybyoidc.TestJwts.createSigned; class MockIpdServer implements Closeable { final static String CTX_DISCOVER = "/discover"; + final static String CTX_USERINFO_SIGNED = "/api/oauth/userinfo/signed"; + final static String CTX_USERINFO = "/api/oauth/userinfo"; final static String CTX_KEYS = "/api/oauth/keys"; private final HttpServer httpServer; @@ -91,8 +101,21 @@ public void handle(HttpRequest request, HttpResponse response, HttpContext conte handleKeysRequest(request, response, context); } - }); + }) + .registerHandler(CTX_USERINFO, new HttpRequestHandler() { + + @Override + public void handle(HttpRequest request, HttpResponse response, HttpContext context) throws HttpException, IOException { + handleUserinfoRequest(request, response, context); + } + }) + .registerHandler(CTX_USERINFO_SIGNED, new HttpRequestHandler() { + @Override + public void handle(HttpRequest request, HttpResponse response, HttpContext context) throws HttpException, IOException { + handleUserinfoRequestSigned(request, response, context); + } + }); if (ssl) { serverBootstrap = serverBootstrap.setSslContext(createSSLContext()).setSslSetupHandler(new SSLServerSetupHandler() { @@ -145,6 +168,14 @@ public String getDiscoverUri() { return uri + CTX_DISCOVER; } + public String getUserinfoUri() { + return uri + CTX_USERINFO; + } + + public String getUserinfoSignedUri() { + return uri + CTX_USERINFO_SIGNED; + } + public String getJwksUri() { return uri + CTX_KEYS; } @@ -164,6 +195,54 @@ protected void handleDiscoverRequest(HttpRequest request, HttpResponse response, ); } + protected void handleUserinfoRequestSigned(HttpRequest request, HttpResponse response, HttpContext context) throws HttpException, + IOException { + + Header headers = request.getFirstHeader("Authorization"); + String requestToken; + + String authHeaderValue = headers.getValue(); + if (authHeaderValue.startsWith("Bearer")) { + requestToken = authHeaderValue.substring(7).trim(); + } else { + response.setStatusCode(401); + return; + } + + response.setStatusCode(200); + response.setHeader("content-type", APPLICATION_JWT); + + // We have to manually form the response content since we don't want to need to pass settings info into the test class + JWTClaimsSet claims = new JWTClaimsSet.Builder().claim("sub", MCCOY_SUBJECT) + .claim("roles", TEST_ROLES_STRING) + .claim("iss", "http://www.example.com") + .claim("aud", "testClient") + .build(); + String content = createSigned(claims, TestJwk.OCT_1); + response.setEntity(new StringEntity(content)); + } + + protected void handleUserinfoRequest(HttpRequest request, HttpResponse response, HttpContext context) throws HttpException, + IOException { + Header headers = request.getFirstHeader("Authorization"); + String requestToken; + + String authHeaderValue = headers.getValue(); + if (authHeaderValue.startsWith("Bearer")) { + requestToken = authHeaderValue.substring(7).trim(); + } else { + response.setStatusCode(401); + return; + } + + response.setStatusCode(200); + response.setHeader("content-type", APPLICATION_JSON.getMimeType()); + + // We have to manually form the response content since we don't want to need to pass settings info into the test class + JWTClaimsSet claims = new JWTClaimsSet.Builder().claim("sub", MCCOY_SUBJECT).claim("roles", TEST_ROLES_STRING).build(); + response.setEntity(new StringEntity(claims.toString())); + } + protected void handleKeysRequest(HttpRequest request, HttpResponse response, HttpContext context) throws HttpException, IOException { response.setStatusCode(200); response.setEntity(new StringEntity(jwks.toString(false))); diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java index 6e3548926e..0d57ac6a01 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/SingleKeyHTTPJwtKeyByOpenIdConnectAuthenticatorTest.java @@ -33,7 +33,7 @@ public void basicTest() throws Exception { try { Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); - HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); + HTTPOpenIdAuthenticator jwtAuth = new HTTPOpenIdAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()) @@ -60,7 +60,7 @@ public void wrongSigTest() throws Exception { try { Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); - HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); + HTTPOpenIdAuthenticator jwtAuth = new HTTPOpenIdAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_X), new HashMap()) @@ -83,7 +83,7 @@ public void noAlgTest() throws Exception { try { Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); - HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); + HTTPOpenIdAuthenticator jwtAuth = new HTTPOpenIdAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.MC_COY_SIGNED_RSA_1), new HashMap()) @@ -109,7 +109,7 @@ public void mismatchedAlgTest() throws Exception { try { Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); - HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); + HTTPOpenIdAuthenticator jwtAuth = new HTTPOpenIdAuthenticator(settings, null); AuthCredentials creds = jwtAuth.extractCredentials( new FakeRestRequest(ImmutableMap.of("Authorization", TestJwts.NoKid.MC_COY_SIGNED_RSA_1), new HashMap()) @@ -132,7 +132,7 @@ public void keyExchangeTest() throws Exception { Settings settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); - HTTPJwtKeyByOpenIdConnectAuthenticator jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); + HTTPOpenIdAuthenticator jwtAuth = new HTTPOpenIdAuthenticator(settings, null); try { AuthCredentials creds = jwtAuth.extractCredentials( @@ -183,7 +183,7 @@ public void keyExchangeTest() throws Exception { mockIdpServer = new MockIpdServer(TestJwk.Jwks.RSA_2); settings = Settings.builder().put("openid_connect_url", mockIdpServer.getDiscoverUri()).build(); // port changed - jwtAuth = new HTTPJwtKeyByOpenIdConnectAuthenticator(settings, null); + jwtAuth = new HTTPOpenIdAuthenticator(settings, null); try { AuthCredentials creds = jwtAuth.extractCredentials( diff --git a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwts.java b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwts.java index acc6a0dba9..26e0360779 100644 --- a/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwts.java +++ b/src/test/java/com/amazon/dlic/auth/http/jwt/keybyoidc/TestJwts.java @@ -32,14 +32,24 @@ class TestJwts { static final Set TEST_ROLES = ImmutableSet.of("role1", "role2"); static final String TEST_ROLES_STRING = String.join(",", TEST_ROLES); + static final String OIDC_TEST_AUD = "testClient"; + + static final String OIDC_TEST_ISS = "http://www.example.com"; + static final String TEST_AUDIENCE = "TestAudience"; static final String MCCOY_SUBJECT = "Leonard McCoy"; + static final String STEPHEN_SUBJECT = "Stephen Crawford"; + static final String TEST_ISSUER = "TestIssuer"; + static final JWTClaimsSet STEPHEN = create(STEPHEN_SUBJECT, TEST_AUDIENCE, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); + static final JWTClaimsSet MC_COY = create(MCCOY_SUBJECT, TEST_AUDIENCE, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); + static final JWTClaimsSet MC_COY_OIDC = create(MCCOY_SUBJECT, OIDC_TEST_AUD, OIDC_TEST_ISS, ROLES_CLAIM, TEST_ROLES_STRING); + static final JWTClaimsSet MC_COY_2 = create(MCCOY_SUBJECT, TEST_AUDIENCE, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); static final JWTClaimsSet MC_COY_NO_AUDIENCE = create(MCCOY_SUBJECT, null, TEST_ISSUER, ROLES_CLAIM, TEST_ROLES_STRING); @@ -58,6 +68,8 @@ class TestJwts { static final String MC_COY_SIGNED_OCT_1 = createSigned(MC_COY, TestJwk.OCT_1); + static final String MC_COY_SIGNED_OCT_1_OIDC = createSigned(MC_COY_OIDC, TestJwk.OCT_1); + static final String MC_COY_SIGNED_OCT_2 = createSigned(MC_COY_2, TestJwk.OCT_2); static final String MC_COY_SIGNED_NO_AUDIENCE_OCT_1 = createSigned(MC_COY_NO_AUDIENCE, TestJwk.OCT_1); @@ -67,6 +79,8 @@ class TestJwts { static final String MC_COY_SIGNED_RSA_1 = createSigned(MC_COY, TestJwk.RSA_1); + static final String STEPHEN_RSA_1 = createSigned(STEPHEN, TestJwk.RSA_1); + static final String MC_COY_SIGNED_RSA_X = createSigned(MC_COY, TestJwk.RSA_X); static final String MC_COY_EXPIRED_SIGNED_OCT_1 = createSigned(MC_COY_EXPIRED, TestJwk.OCT_1);