Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ public class JwtClientAuthentication {

// no signature check with invalid algorithms
private static final Set<Algorithm> NOT_SUPPORTED_ALGORITHMS = Set.of(Algorithm.NONE, JWSAlgorithm.HS256, JWSAlgorithm.HS384, JWSAlgorithm.HS512);
private static final Set<String> JWT_REQUIRED_CLAIMS = Set.of(ClaimConstants.ISS, ClaimConstants.SUB, ClaimConstants.AUD,
private static final Set<String> JWT_RFC7523_CLAIMS = Set.of(ClaimConstants.ISS, ClaimConstants.SUB, ClaimConstants.AUD,
ClaimConstants.EXPIRY_IN_SECONDS);
private static final Set<String> JWT_OIDC_CLAIMS = Set.of(ClaimConstants.ISS, ClaimConstants.SUB, ClaimConstants.AUD,
ClaimConstants.EXPIRY_IN_SECONDS, ClaimConstants.JTI);

private final KeyInfoService keyInfoService;
Expand Down Expand Up @@ -160,12 +162,12 @@ public boolean validateClientJwt(Map<String, String[]> requestParameters, Client
// Validate token according to private_key_jwt with OIDC
return clientId.equals(validateClientJWToken(clientJWT, oidcMetadataFetcher == null ? new JWKSet() :
JWKSet.parse(oidcMetadataFetcher.fetchWebKeySet(clientJwtConfiguration).getKeySetMap()),
clientId, clientId, keyInfoService.getTokenEndpointUrl()).getSubject());
JWT_OIDC_CLAIMS, clientId, clientId, keyInfoService.getTokenEndpointUrl()).getSubject());
} else {
// Check if we found trust for private_key_jwt with RFC 7523. We allow client_id (from request) != sub (client_assertion)
ClientJwtCredential jwtFederation = getClientJwtFederation(clientJwtConfiguration, clientClaims);
if (jwtFederation != null) {
return validateFederatedClientWT(clientJWT, clientClaims, jwtFederation);
return validateFederatedClientJWT(clientJWT, clientClaims, jwtFederation);
}
throw new BadCredentialsException("Wrong client_assertion");
}
Expand Down Expand Up @@ -217,11 +219,12 @@ public static String getClientIdOidcAssertion(String clientAssertion) {
}
}

private boolean validateFederatedClientWT(JWT jwtAssertion, JWTClaimsSet clientClaims, ClientJwtCredential jwtFederation) throws OidcMetadataFetchingException, ParseException {
// Validate federated client with RFC 7523
private boolean validateFederatedClientJWT(JWT jwtAssertion, JWTClaimsSet clientClaims, ClientJwtCredential jwtFederation) throws OidcMetadataFetchingException, ParseException {
try {
JWKSet jwkSet = retrieveJwkSet(clientClaims);
String expectedAud = Optional.ofNullable(jwtFederation.getAudience()).orElse(keyInfoService.getTokenEndpointUrl());
return validateClientJWToken(jwtAssertion, jwkSet, jwtFederation.getSubject(), jwtFederation.getIssuer(), expectedAud) != null;
return validateClientJWToken(jwtAssertion, jwkSet, JWT_RFC7523_CLAIMS, jwtFederation.getSubject(), jwtFederation.getIssuer(), expectedAud) != null;
} catch (MalformedURLException | IllegalArgumentException | URISyntaxException e) {
return false;
}
Expand Down Expand Up @@ -270,7 +273,7 @@ private JWKSet retrieveJwkSet(JWTClaimsSet clientClaims) throws MalformedURLExce
}
}

private JWTClaimsSet validateClientJWToken(JWT jwtAssertion, JWKSet jwkSet, String expectedSub, String expectIss, String expectedAud) {
private JWTClaimsSet validateClientJWToken(JWT jwtAssertion, JWKSet jwkSet, Set<String> requiredClaims, String expectedSub, String expectIss, String expectedAud) {
if (Optional.ofNullable(jwkSet).orElse(new JWKSet()).isEmpty()) {
throw new BadCredentialsException("Bad empty jwk_set");
}
Expand All @@ -284,7 +287,7 @@ private JWTClaimsSet validateClientJWToken(JWT jwtAssertion, JWKSet jwkSet, Stri
jwtProcessor.setJWSKeySelector(keySelector);

JWTClaimsSet.Builder claimSetBuilder = new JWTClaimsSet.Builder().issuer(expectIss).subject(expectedSub);
jwtProcessor.setJWTClaimsSetVerifier(new DefaultJWTClaimsVerifier<>(expectedAud, claimSetBuilder.build(), JWT_REQUIRED_CLAIMS));
jwtProcessor.setJWTClaimsSetVerifier(new DefaultJWTClaimsVerifier<>(expectedAud, claimSetBuilder.build(), requiredClaims));

try {
return jwtProcessor.process(jwtAssertion, null);
Expand Down
Loading