Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable PKCE in OIDC federated login flow #181

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ private OIDCAuthenticatorConstants() {
public static final String AMPERSAND_SIGN = "&";
public static final String EQUAL_SIGN = "=";

public static final String PKCE_CODE_VERIFIER = "PKCE_CODE_VERIFIER";
public static final String IS_PKCE_ENABLED = "IsPKCEEnabled";

/**
* This class holds the constants related to authenticator configuration parameters.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -138,6 +141,7 @@ public class OpenIDConnectAuthenticator extends AbstractApplicationAuthenticator

private static final Log LOG = LogFactory.getLog(OpenIDConnectAuthenticator.class);
private static final String OIDC_DIALECT = "http://wso2.org/oidc/claim";
private static final String PKCE_CODE_CHALLENGE_METHOD = "S256";

private static final String DYNAMIC_PARAMETER_LOOKUP_REGEX = "\\$\\{(\\w+)\\}";
private static final String IS_API_BASED = "IS_API_BASED";
Expand All @@ -150,6 +154,11 @@ public class OpenIDConnectAuthenticator extends AbstractApplicationAuthenticator
private static final String[] NON_USER_ATTRIBUTES = new String[]{"at_hash", "iss", "iat", "exp", "aud", "azp"};
private static final String AUTHENTICATOR_MESSAGE = "authenticatorMessage";

private static final String IS_PKCE_ENABLED_NAME = "isPKCEEnabled";
private static final String IS_PKCE_ENABLED_DISPLAY_NAME = "Enable PKCE";
private static final String IS_PKCE_ENABLED_DESCRIPTION = "Specifies that PKCE should be used for client authentication";
private static final String TYPE_BOOLEAN = "boolean";

@Override
public AuthenticatorFlowStatus process(HttpServletRequest request, HttpServletResponse response,
AuthenticationContext context)
Expand Down Expand Up @@ -514,6 +523,8 @@ protected String prepareLoginPage(HttpServletRequest request, AuthenticationCont
context.setProperty(OIDCAuthenticatorConstants.AUTHENTICATOR_NAME + STATE_PARAM_SUFFIX, state);
String nonce = UUID.randomUUID().toString();
context.setProperty(OIDC_FEDERATION_NONCE, nonce);
boolean isPKCEEnabled = Boolean.parseBoolean(
authenticatorProperties.get(OIDCAuthenticatorConstants.IS_PKCE_ENABLED));

OAuthClientRequest authzRequest;

Expand Down Expand Up @@ -585,6 +596,15 @@ protected String prepareLoginPage(HttpServletRequest request, AuthenticationCont
loginPage = loginPage + "&fidp=" + domain;
}

// If PKCE is enabled, add code_challenge and code_challenge_method to the request.
if (isPKCEEnabled) {
String codeVerifier = generateCodeVerifier();
context.setProperty(OIDCAuthenticatorConstants.PKCE_CODE_VERIFIER, codeVerifier);
String codeChallenge = generateCodeChallenge(codeVerifier);
loginPage += "&code_challenge=" + codeChallenge + "&code_challenge_method="
+ PKCE_CODE_CHALLENGE_METHOD;
}

if (StringUtils.isNotBlank(queryString)) {
if (!queryString.startsWith("&")) {
loginPage = loginPage + "&" + queryString;
Expand Down Expand Up @@ -1467,6 +1487,9 @@ protected OAuthClientRequest getAccessTokenRequest(AuthenticationContext context
String clientId = authenticatorProperties.get(OIDCAuthenticatorConstants.CLIENT_ID);
String clientSecret = authenticatorProperties.get(OIDCAuthenticatorConstants.CLIENT_SECRET);
String tokenEndPoint = getTokenEndpoint(authenticatorProperties);
boolean isPKCEEnabled = Boolean.parseBoolean(
authenticatorProperties.get(OIDCAuthenticatorConstants.IS_PKCE_ENABLED));
String codeVerifier = (String) context.getProperty(OIDCAuthenticatorConstants.PKCE_CODE_VERIFIER);

String callbackUrl = getCallbackUrlFromInitialRequestParamMap(context);
if (StringUtils.isBlank(callbackUrl)) {
Expand All @@ -1489,9 +1512,20 @@ protected OAuthClientRequest getAccessTokenRequest(AuthenticationContext context
"authentication scheme.");
}

accessTokenRequest = OAuthClientRequest.tokenLocation(tokenEndPoint).setGrantType(GrantType
.AUTHORIZATION_CODE).setRedirectURI(callbackUrl).setCode(authzResponse.getCode())
.buildBodyMessage();
OAuthClientRequest.TokenRequestBuilder tokenRequestBuilder = OAuthClientRequest
.tokenLocation(tokenEndPoint)
.setGrantType(GrantType.AUTHORIZATION_CODE)
.setRedirectURI(callbackUrl)
.setCode(authzResponse.getCode());

if (isPKCEEnabled) {
if (StringUtils.isEmpty(codeVerifier)) {
throw new AuthenticationFailedException("PKCE is enabled, but the code verifier is not found.");
}
tokenRequestBuilder.setParameter("code_verifier", codeVerifier);
}

accessTokenRequest = tokenRequestBuilder.buildBodyMessage();
String base64EncodedCredential = new String(Base64.encodeBase64((clientId + ":" +
clientSecret).getBytes()));
accessTokenRequest.addHeader(OAuth.HeaderType.AUTHORIZATION, "Basic " + base64EncodedCredential);
Expand All @@ -1501,10 +1535,22 @@ protected OAuthClientRequest getAccessTokenRequest(AuthenticationContext context
LOG.debug("Authenticating to token endpoint: " + tokenEndPoint + " including client credentials "
+ "in request body.");
}
accessTokenRequest = OAuthClientRequest.tokenLocation(tokenEndPoint).setGrantType(GrantType
.AUTHORIZATION_CODE).setClientId(clientId).setClientSecret(clientSecret).setRedirectURI
(callbackUrl).setCode(authzResponse.getCode()).buildBodyMessage();
OAuthClientRequest.TokenRequestBuilder tokenRequestBuilder = OAuthClientRequest
.tokenLocation(tokenEndPoint)
.setGrantType(GrantType.AUTHORIZATION_CODE)
.setClientId(clientId)
.setClientSecret(clientSecret)
.setRedirectURI(callbackUrl)
.setCode(authzResponse.getCode());
if (isPKCEEnabled) {
if (StringUtils.isEmpty(codeVerifier)) {
throw new AuthenticationFailedException("PKCE is enabled, but the code verifier is not found.");
}
tokenRequestBuilder.setParameter("code_verifier", codeVerifier);
}
accessTokenRequest = tokenRequestBuilder.buildBodyMessage();
}
context.removeProperty(OIDCAuthenticatorConstants.PKCE_CODE_VERIFIER);
// set 'Origin' header to access token request.
if (accessTokenRequest != null) {
// fetch the 'Hostname' configured in carbon.xml
Expand All @@ -1522,7 +1568,6 @@ protected OAuthClientRequest getAccessTokenRequest(AuthenticationContext context
} catch (URLBuilderException e) {
throw new RuntimeException("Error occurred while building URL in tenant qualified mode.", e);
}

return accessTokenRequest;
}

Expand Down Expand Up @@ -1692,6 +1737,15 @@ public List<Property> getConfigurationProperties() {
enableBasicAuth.setDisplayOrder(10);
configProperties.add(enableBasicAuth);

Property enablePKCE = new Property();
enablePKCE.setName(IS_PKCE_ENABLED_NAME);
enablePKCE.setDisplayName(IS_PKCE_ENABLED_DISPLAY_NAME);
enablePKCE.setRequired(false);
enablePKCE.setDescription(IS_PKCE_ENABLED_DESCRIPTION);
enablePKCE.setType(TYPE_BOOLEAN);
enablePKCE.setDisplayOrder(10);
configProperties.add(enablePKCE);

return configProperties;
}

Expand Down Expand Up @@ -2029,6 +2083,39 @@ private String getCallbackUrlFromInitialRequestParamMap(AuthenticationContext co
return null;
}

/**
* Generate code verifier for PKCE
*
* @return code verifier
*/
private String generateCodeVerifier() {

SecureRandom secureRandom = new SecureRandom();
byte[] codeVerifier = new byte[32];
secureRandom.nextBytes(codeVerifier);
return java.util.Base64.getUrlEncoder().withoutPadding().encodeToString(codeVerifier);
}

/**
* Generate code challenge for PKCE
*
* @param codeVerifier code verifier
* @return code challenge
* @throws AuthenticationFailedException
*/
private String generateCodeChallenge(String codeVerifier) throws AuthenticationFailedException {

try {
byte[] bytes = codeVerifier.getBytes("US-ASCII");
MessageDigest messageDigest = MessageDigest.getInstance("SHA-256");
messageDigest.update(bytes, 0, bytes.length);
byte[] digest = messageDigest.digest();
return java.util.Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
} catch (UnsupportedEncodingException | NoSuchAlgorithmException e) {
throw new AuthenticationFailedException("Error while generating code challenge", e);
}
}

private AuthenticatorFlowStatus processLogout(HttpServletRequest request, HttpServletResponse response,
AuthenticationContext context) throws LogoutFailedException {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.oltu.oauth2.common.exception.OAuthSystemException;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.core.classloader.annotations.SuppressStaticInitializationFor;
import org.powermock.modules.testng.PowerMockTestCase;
Expand Down Expand Up @@ -497,6 +498,7 @@ public void testPassProcessAuthenticationResponse() throws Exception {
when(mockOAuthClient.accessToken(Matchers.<OAuthClientRequest>anyObject()))
.thenReturn(mockOAuthJSONAccessTokenResponse);
when(mockOAuthJSONAccessTokenResponse.getParam(anyString())).thenReturn(idToken);
authenticatorProperties.put(OIDCAuthenticatorConstants.IS_PKCE_ENABLED, "false");
openIDConnectAuthenticator.processAuthenticationResponse(mockServletRequest,
mockServletResponse, mockAuthenticationContext);

Expand All @@ -521,6 +523,7 @@ public void testPassProcessAuthenticationResponseWithNonce() throws Exception {
when(mockAuthenticationContext.getExternalIdP()).thenReturn(externalIdPConfig);
when(externalIdPConfig.getIdentityProvider()).thenReturn(identityProvider);
when(identityProvider.getIdpProperties()).thenReturn(identityProviderProperties);
authenticatorProperties.put(OIDCAuthenticatorConstants.IS_PKCE_ENABLED, "false");
when(openIDConnectAuthenticatorDataHolder.getClaimMetadataManagementService()).thenReturn
(claimMetadataManagementService);
when(mockAuthenticationContext.getExternalIdP()).thenReturn(externalIdPConfig);
Expand All @@ -539,6 +542,29 @@ public void testPassProcessAuthenticationResponseWithNonce() throws Exception {
"Invalid Id token in the authentication context.");
}

// /**
// * Test whether the token request contains the code verifier when PKCE is enabled.
// *
// * @throws URLBuilderException
// * @throws AuthenticationFailedException
// */
// @Test()
// public void testGetAccessTokenRequestWithPKCE() throws URLBuilderException, AuthenticationFailedException {
// mockAuthenticationRequestContext(mockAuthenticationContext);
// authenticatorProperties.put(OIDCAuthenticatorConstants.IS_PKCE_ENABLED, "true");
// when(mockAuthenticationContext.getProperty(OIDCAuthenticatorConstants.PKCE_CODE_VERIFIER))
// .thenReturn("sample_code_verifier");
// when(mockOAuthzResponse.getCode()).thenReturn("abc");
// mockStatic(ServiceURLBuilder.class);
// ServiceURLBuilder serviceURLBuilder = mock(ServiceURLBuilder.class);
// when(ServiceURLBuilder.create()).thenReturn(serviceURLBuilder);
// when(serviceURLBuilder.build()).thenReturn(serviceURL);
// when(serviceURL.getAbsolutePublicURL()).thenReturn("http://localhost:9443");
// OAuthClientRequest request = openIDConnectAuthenticator
// .getAccessTokenRequest(mockAuthenticationContext, mockOAuthzResponse);
// assertTrue(request.getBody().contains("code_verifier=sample_code_verifier"));
// }

@Test
public void testPassProcessAuthenticationResponseWithoutAccessToken() throws Exception {

Expand All @@ -558,6 +584,7 @@ public void testPassProcessAuthenticationWithBlankCallBack() throws Exception {

setupTest();
authenticatorProperties.put("callbackUrl", " ");
authenticatorProperties.put(OIDCAuthenticatorConstants.IS_PKCE_ENABLED, "false");
mockStatic(IdentityUtil.class);
when(IdentityUtil.getServerURL(FrameworkConstants.COMMONAUTH, true, true))
.thenReturn("http:/localhost:9443/oauth2/callback");
Expand Down Expand Up @@ -618,6 +645,7 @@ public void testPassProcessAuthenticationWithParamValue() throws Exception {
setupTest();
when(LoggerUtils.isDiagnosticLogsEnabled()).thenReturn(true);
authenticatorProperties.put("callbackUrl", "http://localhost:8080/playground2/oauth2client");
authenticatorProperties.put(OIDCAuthenticatorConstants.IS_PKCE_ENABLED, "false");
Map<String, String> paramMap = new HashMap<>();
paramMap.put("redirect_uri", "http:/localhost:9443/oauth2/redirect");
when(mockAuthenticationContext.getProperty(OIDC_PARAM_MAP_STRING)).thenReturn(paramMap);
Expand Down
Loading