Skip to content

Commit

Permalink
Fix to retrun oidc claims based on the requested scopes
Browse files Browse the repository at this point in the history
  • Loading branch information
kayathiri4 committed Sep 6, 2024
1 parent b0ab272 commit 82f451e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ private Map<String, Object> getUserClaimsInOIDCDialect(OAuthTokenReqMessageConte
userClaimsInOIDCDialect = retrieveClaimsForLocalUser(requestMsgCtx);
} else {
// Get claim map from the cached attributes
userClaimsInOIDCDialect = getOIDCClaimMapFromUserAttributes(userAttributes);
userClaimsInOIDCDialect = getOIDCClaimsFromUserAttributes(userAttributes, requestMsgCtx);
}

Object hasNonOIDCClaimsProperty = requestMsgCtx.getProperty(OIDCConstants.HAS_NON_OIDC_CLAIMS);
Expand Down Expand Up @@ -594,6 +594,27 @@ private Map<String, Object> getOIDCClaimMapFromUserAttributes(Map<ClaimMapping,
return claims;
}

/**
* Get oidc claims mapping.
*
* @param userAttributes User attributes.
* @param requestMsgCtx Request Context.
* @return User attributes Map.
*/
private Map<String, Object> getOIDCClaimsFromUserAttributes(Map<ClaimMapping, String> userAttributes,
OAuthTokenReqMessageContext requestMsgCtx)
throws IdentityOAuth2Exception {

String spTenantDomain = getServiceProviderTenantDomain(requestMsgCtx);
Map<String, String> claims = new HashMap<>();
if (isNotEmpty(userAttributes)) {
for (Map.Entry<ClaimMapping, String> entry : userAttributes.entrySet()) {
claims.put(entry.getKey().getRemoteClaim().getClaimUri(), entry.getValue().toString());
}
}
return OIDCClaimUtil.getMergedUserClaimsInOIDCDialect(spTenantDomain, claims);
}

private Map<String, Object> getUserClaimsInOIDCDialect(String spTenantDomain, String clientId,
AuthenticatedUser authenticatedUser)
throws IdentityApplicationManagementException, IdentityException, UserStoreException,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,30 @@ public static Map<String, Object> filterUserClaimsBasedOnConsent(Map<String, Obj
}
}

/**
* Get oidc claims mapping.
*
* @param spTenantDomain Tenant domain.
* @param claims User claims
* @return user attributes.
* @throws IdentityOAuth2Exception If an exception occurred while getting user claims.
*/
public static Map<String, Object> getMergedUserClaimsInOIDCDialect(String spTenantDomain,
Map<String, String> claims)
throws IdentityOAuth2Exception {

Map<String, Object> oidcClaims = new HashMap<>();
try {
oidcClaims = OIDCClaimUtil.getUserClaimsInOIDCDialect(spTenantDomain, claims);
// Merge the initial claims into oidcClaims, while prioritizing the initial claims map.
oidcClaims.putAll(claims);
return oidcClaims;
} catch (ClaimMetadataException e) {
throw new IdentityOAuth2Exception("Error occurred while mapping claims for user: " +
" from userstore.", e);
}
}

public static Map<String, Object> filterUserClaimsBasedOnConsent(Map<String, Object> userClaims,
AuthenticatedUser authenticatedUser,
String clientId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,13 +484,15 @@ public void testHandleCustomClaimsWithoutClaimsInUserAttributes() throws Excepti
MockedStatic<JDBCPersistenceManager> jdbcPersistenceManager =
mockStatic(JDBCPersistenceManager.class);
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration = mockStatic(
OAuthServerConfiguration.class)) {
OAuthServerConfiguration.class);
MockedStatic<ClaimMetadataHandler> claimMetadataHandler = mockStatic(ClaimMetadataHandler.class);) {
// Create a token request with User Attributes.
JWTClaimsSet.Builder jwtClaimsSetBuilder = new JWTClaimsSet.Builder();
Map<ClaimMapping, String> userAttributes = new HashMap<>();
userAttributes.put(SAML2BearerGrantHandlerTest.buildClaimMapping(COUNTRY), TestConstants.CLAIM_VALUE1);
userAttributes.put(SAML2BearerGrantHandlerTest.buildClaimMapping(EMAIL), TestConstants.CLAIM_VALUE2);
OAuthTokenReqMessageContext requestMsgCtx = getTokenReqMessageContextForFederatedUser(userAttributes);
getUserClaimsMap(claimMetadataHandler);

// Mock to return all the scopes when the consent is asked for.
UserRealm userRealm = getUserRealmWithUserClaims(USER_CLAIMS_MAP);
Expand All @@ -516,7 +518,8 @@ public void testHandleCustomClaimsWithoutClaimsInRefreshFlow() throws Exception
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration = mockStatic(
OAuthServerConfiguration.class);
MockedStatic<AuthorizationGrantCache> authorizationGrantCache =
mockStatic(AuthorizationGrantCache.class);) {
mockStatic(AuthorizationGrantCache.class);
MockedStatic<ClaimMetadataHandler> claimMetadataHandler = mockStatic(ClaimMetadataHandler.class);) {
JWTClaimsSet.Builder jwtClaimsSetBuilder = new JWTClaimsSet.Builder();
OAuthTokenReqMessageContext requestMsgCtx = getTokenReqMessageContextForFederatedUser(null);
// Add the relevant oidc claims to scope resource.
Expand All @@ -535,6 +538,7 @@ public void testHandleCustomClaimsWithoutClaimsInRefreshFlow() throws Exception
new AuthorizationGrantCacheEntry(userAttributes);
authorizationGrantCacheEntry.setSubjectClaim(requestMsgCtx.getAuthorizedUser().getUserName());
mockAuthorizationGrantCache(authorizationGrantCacheEntry, authorizationGrantCache);
getUserClaimsMap(claimMetadataHandler);

RefreshTokenValidationDataDO refreshTokenValidationDataDO =
Mockito.mock(RefreshTokenValidationDataDO.class);
Expand Down Expand Up @@ -1169,7 +1173,8 @@ public void testHandleClaimsForOAuthTokenReqMessageContextWithAuthorizationCode(
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration = mockStatic(
OAuthServerConfiguration.class);
MockedStatic<AuthorizationGrantCache> authorizationGrantCache =
mockStatic(AuthorizationGrantCache.class);) {
mockStatic(AuthorizationGrantCache.class);
MockedStatic<ClaimMetadataHandler> claimMetadataHandler = mockStatic(ClaimMetadataHandler.class);) {
JWTClaimsSet.Builder jwtClaimsSetBuilder = new JWTClaimsSet.Builder();
Map<ClaimMapping, String> userAttributes = new HashMap<>();
userAttributes.put(SAML2BearerGrantHandlerTest.buildClaimMapping(COUNTRY), TestConstants.CLAIM_VALUE1);
Expand All @@ -1179,6 +1184,7 @@ public void testHandleClaimsForOAuthTokenReqMessageContextWithAuthorizationCode(

AuthorizationGrantCacheEntry authorizationGrantCacheEntry = mock(AuthorizationGrantCacheEntry.class);
mockAuthorizationGrantCache(authorizationGrantCacheEntry, authorizationGrantCache);
getUserClaimsMap(claimMetadataHandler);

UserRealm userRealm = getUserRealmWithUserClaims(USER_CLAIMS_MAP);
mockUserRealm(requestMsgCtx.getAuthorizedUser().toString(), userRealm, identityTenantUtil);
Expand Down

0 comments on commit 82f451e

Please sign in to comment.