Skip to content

Commit

Permalink
Add audience validation for api and api product
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirishikesan committed Apr 2, 2024
1 parent 2bc429e commit 8bbc5a3
Show file tree
Hide file tree
Showing 22 changed files with 364 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ public void setSoapToRestSequences(List<SOAPToRestSequence> soapToRestSequences)

private String audience;

private Set<String> audiences;

public String getAudience() {
return audience;
}
Expand All @@ -236,6 +238,23 @@ public void setAudience(String audience) {
this.audience = audience;
}

/**
* To get the audiences for jwt validation
*
* @return audiences of the API
*/
public Set<String> getAudiences() {
return audiences;
}

/**
* To set the audiences for jwt validation
*
*/
public void setAudiences(Set<String> audiences) {
this.audiences = audiences;
}

public void setEnvironmentList(Set<String> environmentList) {
this.environmentList = environmentList;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ public class APIProduct implements Serializable {
* Used to set the workflow status in lifecycle state change workflow
*/
private String workflowStatus = null;

/**
* Used to set the audiences values in jwt audience validation
*/
private Set<String> audiences;
private Boolean isDefaultVersion = true;
private boolean isPublishedDefaultVersion = false;
public APIProduct(){}
Expand Down Expand Up @@ -222,6 +227,24 @@ public void setType(String type) {
this.type = type;
}
}

/**
* To get the audiences for jwt validation
*
* @return audiences of the API
*/
public Set<String> getAudiences() {
return audiences;
}

/**
* To set the audiences for jwt validation
*
*/
public void setAudiences(Set<String> audiences) {
this.audiences = audiences;
}

public String getBusinessOwner() {
return businessOwner;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,14 @@
import org.wso2.carbon.metrics.manager.Timer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
Expand All @@ -94,6 +97,7 @@ public class APIAuthenticationHandler extends AbstractHandler implements Managed
private SynapseEnvironment synapseEnvironment;

private String authorizationHeader;
private Set<String> audiences = new HashSet<>();;
private String apiKeyHeader;
private String apiSecurity;
private String apiLevelPolicy;
Expand Down Expand Up @@ -202,6 +206,26 @@ public void setAuthorizationHeader(String authorizationHeader) {
this.authorizationHeader = authorizationHeader;
}

/**
* To set the audiences for api.
*
* @param audiences the audiences of the API request.
*/
public void setAudiences(String audiences) {
if (!StringUtils.isEmpty(audiences)) {
this.audiences = new HashSet<>(Arrays.asList(audiences.split(",")));
}
}

/**
* To get the audiences of an api.
*
* @return API level audiences for JWT validation.
*/
public Set<String> getAudiences() {
return audiences;
}

/**
* To get the Api Key Header.
*
Expand Down Expand Up @@ -343,7 +367,7 @@ protected void initializeAuthenticators() {
}
if (isOAuthProtected) {
Authenticator authenticator = new OAuthAuthenticator(authorizationHeader, isOAuthBasicAuthMandatory,
removeOAuthHeadersFromOutMessage);
removeOAuthHeadersFromOutMessage, this.getAudiences());
authenticator.init(synapseEnvironment);
authenticators.add(authenticator);
}
Expand Down Expand Up @@ -689,6 +713,7 @@ private void handleAuthFailure(MessageContext messageContext, APISecurityExcepti
status = HttpStatus.SC_INTERNAL_SERVER_ERROR;
} else if (e.getErrorCode() == APISecurityConstants.API_AUTH_INCORRECT_API_RESOURCE ||
e.getErrorCode() == APISecurityConstants.API_AUTH_FORBIDDEN ||
e.getErrorCode() == APISecurityConstants.API_OAUTH_INVALID_AUDIENCES ||
e.getErrorCode() == APISecurityConstants.INVALID_SCOPE) {
status = HttpStatus.SC_FORBIDDEN;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public class APISecurityConstants {
public static final int API_AUTH_MISSING_OPEN_API_DEF = 900911;
public static final String API_AUTH_MISSING_OPEN_API_DEF_ERROR_MESSAGE = "Internal Server Error";

public static final int API_OAUTH_INVALID_AUDIENCES = 900912;
public static final String API_OAUTH_INVALID_AUDIENCES_MESSAGE = "The access token does not allow you to access the requested resource";

public static final int OAUTH_TEMPORARY_SERVER_ERROR = 900424;
public static final String OAUTH_TEMPORARY_SERVER_ERROR_MESSAGE = "Temporary Server Error";

Expand Down Expand Up @@ -112,6 +115,9 @@ public static final String getAuthenticationFailureMessage(int errorCode) {
case API_AUTH_INCORRECT_ACCESS_TOKEN_TYPE:
errorMessage = API_AUTH_INCORRECT_ACCESS_TOKEN_TYPE_MESSAGE;
break;
case API_OAUTH_INVALID_AUDIENCES:
errorMessage = API_OAUTH_INVALID_AUDIENCES_MESSAGE;
break;
case API_BLOCKED:
errorMessage = API_BLOCKED_MESSAGE;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.cache.Cache;
Expand All @@ -81,6 +82,7 @@ public class JWTValidator {
private APIKeyValidator apiKeyValidator;
private boolean jwtGenerationEnabled;
private AbstractAPIMgtGatewayJWTGenerator apiMgtGatewayJWTGenerator;
private Set<String> audiences;
ExtendedJWTConfigurationDto jwtConfigurationDto;
JWTValidationService jwtValidationService;
private static volatile long ttl = -1L;
Expand Down Expand Up @@ -131,6 +133,12 @@ protected JWTValidator(String apiLevelPolicy, boolean isGatewayTokenCacheEnabled
this.jwtValidationService = jwtValidationService;
}

public JWTValidator(APIKeyValidator apiKeyValidator, String tenantDomain, Set<String> audiences)
throws APIManagementException {
this(apiKeyValidator, tenantDomain);
this.setAudiences(audiences);
}

/**
* Authenticates the given request with a JWT token to see if an API consumer is allowed to access
* a particular API or not.
Expand Down Expand Up @@ -273,6 +281,7 @@ public AuthenticationContext authenticate(SignedJWTInfo signedJWTInfo, MessageCo
}
// Validate scopes
validateScopes(apiContext, apiVersion, matchingResource, httpMethod, jwtValidationInfo, signedJWTInfo);
validateAudiences(signedJWTInfo);
synCtx.setProperty(APIMgtGatewayConstants.SCOPES, jwtValidationInfo.getScopes().toString());
synCtx.setProperty(APIMgtGatewayConstants.JWT_CLAIMS, jwtValidationInfo.getClaims());
if (apiKeyValidationInfoDTO.isAuthorized()) {
Expand Down Expand Up @@ -346,6 +355,35 @@ private long getTtl() {
}
}

public Set<String> getAudiences() {

return audiences;
}

public void setAudiences(Set<String> audiences) {

this.audiences = audiences;
}

private boolean validateAudiences(SignedJWTInfo signedJWTInfo) throws APISecurityException {
if (this.getAudiences() == null || this.getAudiences().isEmpty() ||
this.getAudiences().contains(APIConstants.ALL_AUDIENCES)) {
return true;
}
List<String> jwtAudienceClaim = signedJWTInfo.getJwtClaimsSet().getAudience();
if (jwtAudienceClaim == null) {
throw new APISecurityException(APISecurityConstants.API_OAUTH_INVALID_AUDIENCES,
APISecurityConstants.API_OAUTH_INVALID_AUDIENCES_MESSAGE);
}
for (String aud : this.getAudiences()) {
if (jwtAudienceClaim.contains(aud)) {
return true;
}
}
throw new APISecurityException(APISecurityConstants.API_OAUTH_INVALID_AUDIENCES,
APISecurityConstants.API_OAUTH_INVALID_AUDIENCES_MESSAGE);
}

private String generateAndRetrieveJWTToken(String tokenSignature, JWTInfoDto jwtInfoDto)
throws APISecurityException {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.apache.axiom.om.OMElement;
import org.apache.axis2.AxisFault;
import org.apache.axis2.Constants;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
Expand All @@ -27,7 +29,9 @@
import org.apache.synapse.core.SynapseEnvironment;
import org.apache.synapse.core.axis2.Axis2MessageContext;
import org.apache.synapse.rest.RESTConstants;
import org.wso2.carbon.apimgt.api.APIConsumer;
import org.wso2.carbon.apimgt.api.APIManagementException;
import org.wso2.carbon.apimgt.api.model.ApiTypeWrapper;
import org.wso2.carbon.apimgt.common.gateway.constants.GraphQLConstants;
import org.wso2.carbon.apimgt.common.gateway.dto.JWTConfigurationDto;
import org.wso2.carbon.apimgt.gateway.APIMgtGatewayConstants;
Expand Down Expand Up @@ -56,6 +60,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.cache.Cache;

/**
Expand Down Expand Up @@ -83,11 +88,18 @@ public class OAuthAuthenticator implements Authenticator {
private boolean removeDefaultAPIHeaderFromOutMessage = true;
private String requestOrigin;
private boolean isMandatory;
private Set<String> audiences;
private ThreadLocal<String> remainingAuthHeader = new ThreadLocal<String>();

public OAuthAuthenticator() {
}

public OAuthAuthenticator(String authorizationHeader, boolean isMandatory, boolean removeOAuthHeader,
Set<String> audiences) {
this(authorizationHeader, isMandatory, removeOAuthHeader);
this.setAudiences(audiences);
}

public OAuthAuthenticator(String authorizationHeader, boolean isMandatory, boolean removeOAuthHeader) {
this.securityHeader = authorizationHeader;
this.removeOAuthHeadersFromOutMessage = removeOAuthHeader;
Expand Down Expand Up @@ -118,7 +130,7 @@ public AuthenticationResponse authenticate(MessageContext synCtx) throws APIMana
}

if (jwtValidator == null) {
this.jwtValidator = new JWTValidator(this.keyValidator, tenantDomain);
this.jwtValidator = new JWTValidator(this.keyValidator, tenantDomain, this.getAudiences());
}

config = getApiManagerConfiguration();
Expand Down Expand Up @@ -471,6 +483,14 @@ private void setSecurityContextHeader(String securityContextHeader) {
this.securityContextHeader = securityContextHeader;
}

public void setAudiences(Set<String> audiences) {
this.audiences = audiences;
}

public Set<String> getAudiences() {
return audiences;
}

private boolean isRemoveOAuthHeadersFromOutMessage() {
String value = config.getFirstProperty(APIConstants.REMOVE_OAUTH_HEADERS_FROM_MESSAGE);
if (value != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,8 @@ private Permissions() {
public static final String JWT_AUDIENCES = "JWTAudiences";
public static final String JWT_AUDIENCE = "JWTAudience";
public static final String AUDIENCE = "Audience";
public static final String AUDIENCES = "Audiences";
public static final String ALL_AUDIENCES = "all";
public static final String BASEPATH = "Basepath";
public static final String URN_CHOREO = "urn:choreo:";
public static final String BASE_PATH = "http.base.path";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public final class APIConstants {
public static final String API_OVERVIEW_VISIBLE_TENANTS = "overview_visibleTenants";
public static final String API_OVERVIEW_ENVIRONMENTS = "overview_environments";
public static final String API_OVERVIEW_AUDIENCE = "overview_audience";
public static final String API_OVERVIEW_AUDIENCES = "overview_audiences";
public static final String API_PROVIDER = "Provider";
public static final String API_NAME = "Name";
public static final String API_VERSION_LABEL = "Version";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.wso2.carbon.apimgt.persistence;

import com.google.gson.Gson;
import org.apache.axiom.om.OMAttribute;
import org.apache.axiom.om.OMElement;
import org.apache.axiom.om.util.AXIOMUtil;
Expand Down Expand Up @@ -1003,6 +1004,10 @@ private PublisherAPISearchResult searchPaginatedPublisherAPIs(Registry userRegis
apiInfo.setThumbnail(artifact.getAttribute(APIConstants.API_OVERVIEW_THUMBNAIL_URL));
apiInfo.setVersion(artifact.getAttribute(APIConstants.API_OVERVIEW_VERSION));
apiInfo.setAudience(artifact.getAttribute(APIConstants.API_OVERVIEW_AUDIENCE));
String audiences = artifact.getAttribute(APIConstants.API_OVERVIEW_AUDIENCES);
if (StringUtils.isNotEmpty(audiences)) {
apiInfo.setAudiences(new Gson().fromJson(audiences, Set.class));
}
apiInfo.setCreatedTime(String.valueOf(apiResource.getCreatedTime().getTime()));
apiInfo.setUpdatedTime(apiResource.getLastModified());
apiInfo.setUpdatedBy(apiResource.getLastUpdaterUserName());
Expand Down Expand Up @@ -3379,6 +3384,10 @@ public PublisherAPIProductSearchResult searchAPIProductsForPublisher(Organizatio
info.setVersion(artifact.getAttribute(APIConstants.API_OVERVIEW_VERSION));
info.setApiSecurity(artifact.getAttribute(APIConstants.API_OVERVIEW_API_SECURITY));
info.setThumbnail(artifact.getAttribute(APIConstants.API_OVERVIEW_THUMBNAIL_URL));
String audiences = artifact.getAttribute(APIConstants.API_OVERVIEW_AUDIENCES);
if (StringUtils.isNotEmpty(audiences)) {
info.setAudiences(new Gson().fromJson(audiences, Set.class));
}
info.setBusinessOwner(artifact.getAttribute(APIConstants.API_OVERVIEW_BUSS_OWNER));
info.setBusinessOwnerEmail(artifact.getAttribute(APIConstants.API_OVERVIEW_BUSS_OWNER_EMAIL));
info.setTechnicalOwner(artifact.getAttribute(APIConstants.API_OVERVIEW_TEC_OWNER));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public class PublisherAPI extends PublisherAPIInfo {

private String versionTimestamp;
private String audience;
private Set<String> audiences;
private String apiExternalProductionEndpoint;
private String apiExternalSandboxEndpoint;
private String redirectURL;
Expand All @@ -109,6 +110,14 @@ public void setAudience(String audience) {
this.audience = audience;
}

public Set<String> getAudiences() {
return audiences;
}

public void setAudiences(Set<String> audiences) {
this.audiences = audiences;
}

public List<SOAPToRestSequence> getSoapToRestSequences() {
return soapToRestSequences;
}
Expand Down
Loading

0 comments on commit 8bbc5a3

Please sign in to comment.