Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com).
* Copyright (c) 2025-2026, WSO2 LLC. (http://www.wso2.com).
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
Expand Down Expand Up @@ -36,6 +36,7 @@
import java.text.ParseException;
import java.util.Base64;
import java.util.Date;
import java.util.Map;

/**
* JWT token validator for Push notification scenarios.
Expand Down Expand Up @@ -81,6 +82,21 @@ public static JWTClaimsSet getValidatedClaimSet(String jwt, String publicKey) th
}
}

/**
* Validate the JWT token and return the claim values.
*
* @param jwt JWT token to be validated
* @param publicKey Public key used for signing the JWT
* @return Map of claim values
* @throws PushTokenValidationException Error when validating the JWT token
*/
public static Map<String, Object> getValidatedClaims(String jwt, String publicKey)
throws PushTokenValidationException {

JWTClaimsSet claimsSet = getValidatedClaimSet(jwt, publicKey);
return claimsSet.getClaims();
}

/**
* Validate the legitimacy of JWT token.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com).
* Copyright (c) 2025-2026, WSO2 LLC. (http://www.wso2.com).
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
Expand Down Expand Up @@ -31,6 +31,8 @@

import java.text.ParseException;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -141,6 +143,89 @@ public void testGetValidatedClaimSetWithExpiredToken() throws Exception {
}
}

@Test
public void testGetValidatedClaimsWithValidToken() throws Exception {

try (MockedStatic<SignedJWT> mockedStatic = Mockito.mockStatic(SignedJWT.class)) {
mockedStatic.when(() -> SignedJWT.parse(validJwt)).thenReturn(mockSignedJWT);

Map<String, Object> expectedClaims = new HashMap<>();
expectedClaims.put("chg", "e0c3d04c-750b-4301-8f76-e07ebf02e53a");
expectedClaims.put("td", "carbon.super");

when(mockSignedJWT.getJWTClaimsSet()).thenReturn(mockClaimsSet);
when(mockClaimsSet.getExpirationTime()).thenReturn(new Date(System.currentTimeMillis() + 3000000));
when(mockClaimsSet.getNotBeforeTime()).thenReturn(new Date(System.currentTimeMillis() - 3000000));
when(mockSignedJWT.verify(any())).thenReturn(true);
when(mockClaimsSet.getClaims()).thenReturn(expectedClaims);

Map<String, Object> claims = PushChallengeValidator.getValidatedClaims(validJwt, publicKey);
assertNotNull(claims);
}
}

@Test(expectedExceptions = PushTokenValidationException.class)
public void testGetValidatedClaimsWithBlankToken() throws Exception {

PushChallengeValidator.getValidatedClaims("", publicKey);
}

@Test(expectedExceptions = PushTokenValidationException.class)
public void testGetValidatedClaimsWithInvalidJwtToken() throws Exception {

PushChallengeValidator.getValidatedClaims(invalidJwt, publicKey);
}

@Test(expectedExceptions = PushTokenValidationException.class)
public void testGetValidatedClaimsWithExpiredToken() throws Exception {

try (MockedStatic<SignedJWT> mockedStatic = Mockito.mockStatic(SignedJWT.class)) {
mockedStatic.when(() -> SignedJWT.parse(validJwt)).thenReturn(mockSignedJWT);
when(mockSignedJWT.getJWTClaimsSet()).thenReturn(mockClaimsSet);
when(mockClaimsSet.getExpirationTime()).thenReturn(new Date(System.currentTimeMillis() - 3000000));
when(mockSignedJWT.verify(any())).thenReturn(true);
PushChallengeValidator.getValidatedClaims(validJwt, publicKey);
}
}

@Test(expectedExceptions = PushTokenValidationException.class)
public void testGetValidatedClaimsWithInvalidSignature() throws Exception {

try (MockedStatic<SignedJWT> mockedStatic = Mockito.mockStatic(SignedJWT.class)) {
mockedStatic.when(() -> SignedJWT.parse(validJwt)).thenReturn(mockSignedJWT);
when(mockSignedJWT.getJWTClaimsSet()).thenReturn(null);
PushChallengeValidator.getValidatedClaims(validJwt, invalidPublicKey);
}
}

@Test
public void testGetValidatedClaimsReturnsCorrectClaimValues() throws Exception {

try (MockedStatic<SignedJWT> mockedStatic = Mockito.mockStatic(SignedJWT.class)) {
mockedStatic.when(() -> SignedJWT.parse(validJwt)).thenReturn(mockSignedJWT);

Map<String, Object> expectedClaims = new HashMap<>();
expectedClaims.put("td", "carbon.super");
expectedClaims.put("pid", "f5ae6a0d-390a-4eea-a380-8bcc86e4a148");
expectedClaims.put("chg", "e0c3d04c-750b-4301-8f76-e07ebf02e53a");
expectedClaims.put("res", "APPROVED");

when(mockSignedJWT.getJWTClaimsSet()).thenReturn(mockClaimsSet);
when(mockClaimsSet.getExpirationTime()).thenReturn(new Date(System.currentTimeMillis() + 3000000));
when(mockClaimsSet.getNotBeforeTime()).thenReturn(new Date(System.currentTimeMillis() - 3000000));
when(mockSignedJWT.verify(any())).thenReturn(true);
when(mockClaimsSet.getClaims()).thenReturn(expectedClaims);

Map<String, Object> claims = PushChallengeValidator.getValidatedClaims(validJwt, publicKey);
assertNotNull(claims);
Assert.assertEquals(claims.get("td"), "carbon.super");
Assert.assertEquals(claims.get("pid"), "f5ae6a0d-390a-4eea-a380-8bcc86e4a148");
Assert.assertEquals(claims.get("chg"), "e0c3d04c-750b-4301-8f76-e07ebf02e53a");
Assert.assertEquals(claims.get("res"), "APPROVED");
Assert.assertEquals(claims.size(), 4);
}
}

@Test
public void testValidateChallengeWithEmptyClaimsSet() {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com).
* Copyright (c) 2025-2026, WSO2 LLC. (http://www.wso2.com).
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
Expand Down Expand Up @@ -92,6 +92,15 @@ Device registerDevice(RegistrationRequest registrationRequest, String tenantDoma
*/
void editDevice(String deviceId, String path, String value) throws PushDeviceHandlerException;

/**
* Edit the device from mobile.
*
* @param deviceId Device ID.
* @param token Token.
* @throws PushDeviceHandlerException Push Device Handler Exception.
*/
default void editDeviceMobile(String deviceId, String token) throws PushDeviceHandlerException{}

/**
* Get registration discovery data.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public class PushDeviceHandlerConstants {
public static final String DEVICE_REGISTRATION_CONTEXT_VALIDITY_PERIOD =
"PushAuthenticator.DeviceRegistrationContext.ValidityPeriod";
public static final int DEFAULT_DEVICE_REGISTRATION_CONTEXT_VALIDITY_PERIOD = 180;
public static final String DEVICE_NAME = "name";
public static final String DEVICE_TOKEN = "deviceToken";

/**
* Private constructor to prevent initialization of the class.
Expand All @@ -56,7 +58,7 @@ public static class SQLQueries {
public static final String GET_PUBLIC_KEY_BY_ID = "SELECT PUBLIC_KEY FROM IDN_PUSH_DEVICE_STORE " +
"WHERE ID = :ID;";
public static final String UNREGISTER_DEVICE = "DELETE FROM IDN_PUSH_DEVICE_STORE WHERE ID = :ID;";
public static final String EDIT_DEVICE = "UPDATE IDN_PUSH_DEVICE_STORE SET DEVICE_NAME = :DEVICE_NAME; " +
public static final String EDIT_DEVICE = "UPDATE IDN_PUSH_DEVICE_STORE SET DEVICE_NAME = :DEVICE_NAME;, " +
"DEVICE_TOKEN = :DEVICE_TOKEN; WHERE ID = :ID;";
}

Expand Down Expand Up @@ -136,6 +138,10 @@ public enum ErrorMessages {
ERROR_CODE_FAILED_TO_RESOLVE_PUSH_PROVIDER(
"PDH-15013",
"Failed to resolve the correct push provider for the request."
),
ERROR_CODE_DEVICE_EDIT_FAILED(
"PDH-150014",
"Error occurred while updating the device for the device ID: %s."
);

private final String code;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@
import java.util.UUID;

import static org.wso2.carbon.identity.notification.push.device.handler.constant.PushDeviceHandlerConstants.DEFAULT_PUSH_PROVIDER;
import static org.wso2.carbon.identity.notification.push.device.handler.constant.PushDeviceHandlerConstants.DEVICE_NAME;
import static org.wso2.carbon.identity.notification.push.device.handler.constant.PushDeviceHandlerConstants.DEVICE_TOKEN;
import static org.wso2.carbon.identity.notification.push.device.handler.constant.PushDeviceHandlerConstants.ErrorMessages.ERROR_CODE_DEVICE_ALREADY_REGISTERED;
import static org.wso2.carbon.identity.notification.push.device.handler.constant.PushDeviceHandlerConstants.ErrorMessages.ERROR_CODE_DEVICE_EDIT_FAILED;
import static org.wso2.carbon.identity.notification.push.device.handler.constant.PushDeviceHandlerConstants.ErrorMessages.ERROR_CODE_DEVICE_NOT_FOUND;
import static org.wso2.carbon.identity.notification.push.device.handler.constant.PushDeviceHandlerConstants.ErrorMessages.ERROR_CODE_DEVICE_NOT_FOUND_FOR_USER_ID;
import static org.wso2.carbon.identity.notification.push.device.handler.constant.PushDeviceHandlerConstants.ErrorMessages.ERROR_CODE_DEVICE_REGISTRATION_FAILED;
Expand Down Expand Up @@ -135,6 +138,11 @@ public Device registerDevice(RegistrationRequest registrationRequest, String ten
device = handleDeviceRegistration(registrationRequest, context);
if (context.isRegistered()) {
deviceRegistrationContextManager.clearContext(registrationRequest.getDeviceId(), tenantDomain);
AUDIT_LOGGER.printAuditLog(
DeviceHandlerAuditLogger.Operation.REGISTER_DEVICE,
deviceId,
device.getUserId()
);
} else {
throw new PushDeviceHandlerClientException(ERROR_CODE_DEVICE_REGISTRATION_FAILED.getCode(),
String.format(ERROR_CODE_DEVICE_REGISTRATION_FAILED.getMessage(), deviceId));
Expand Down Expand Up @@ -245,6 +253,62 @@ public void editDevice(String deviceId, String path, String value) throws PushDe

Device device = getDevice(deviceId);
handleEditDevice(device, path, value);
AUDIT_LOGGER.printAuditLog(
DeviceHandlerAuditLogger.Operation.UPDATE_DEVICE,
deviceId,
device.getUserId()
);
}

@Override
public void editDeviceMobile(String deviceId, String token) throws PushDeviceHandlerException {

Optional<Device> deviceOptional = deviceDAO.getDevice(deviceId);
if (!deviceOptional.isPresent()) {
throw new PushDeviceHandlerClientException(ERROR_CODE_DEVICE_NOT_FOUND.getCode(),
String.format(ERROR_CODE_DEVICE_NOT_FOUND.getMessage(), deviceId));
}
Device device = deviceOptional.get();

Map<String, Object> claims;
try {
claims = PushChallengeValidator.getValidatedClaims(token, device.getPublicKey());
} catch (PushTokenValidationException e) {
throw new PushDeviceHandlerClientException(ERROR_CODE_TOKEN_CLAIM_VERIFICATION_FAILED.getCode(),
String.format(ERROR_CODE_TOKEN_CLAIM_VERIFICATION_FAILED.getMessage(), deviceId), e);
}

validateDeviceEditClaims(claims, deviceId);

String deviceToken = null;
String deviceName = null;

if (claims.containsKey(DEVICE_TOKEN)) {
deviceToken = (String) claims.get(DEVICE_TOKEN);
}

if (claims.containsKey(DEVICE_NAME)) {
deviceName = (String) claims.get(DEVICE_NAME);
}

if (deviceToken == null && deviceName == null) {
throw new PushDeviceHandlerClientException(ERROR_CODE_DEVICE_EDIT_FAILED.getCode(),
String.format(ERROR_CODE_DEVICE_EDIT_FAILED.getMessage(), deviceId));
}
if (deviceToken != null) {
device.setDeviceToken(deviceToken);
}
if (deviceName != null) {
device.setDeviceName(deviceName);
}

handleUpdateDeviceForProvider(device);
deviceDAO.editDevice(device.getDeviceId(), device);
AUDIT_LOGGER.printAuditLog(
DeviceHandlerAuditLogger.Operation.UPDATE_DEVICE,
deviceId,
device.getUserId()
);
}

@Override
Expand Down Expand Up @@ -673,4 +737,25 @@ public static PushSenderData buildPushSenderData(PushSenderDTO pushSenderDTO) {
pushSenderData.setProviderId(pushSenderDTO.getProviderId());
return pushSenderData;
}

/**
* Validate the claims in the edit device token.
*
* @param claims Claims to validate.
* @param deviceId Device ID for error messages.
* @throws PushDeviceHandlerClientException If the claims are invalid.
*/
private static void validateDeviceEditClaims(Map<String, Object> claims, String deviceId)
throws PushDeviceHandlerClientException {

if (claims.containsKey(DEVICE_TOKEN) && !(claims.get(DEVICE_TOKEN) instanceof String)) {
throw new PushDeviceHandlerClientException(ERROR_CODE_DEVICE_EDIT_FAILED.getCode(),
String.format(ERROR_CODE_DEVICE_EDIT_FAILED.getMessage(), deviceId));
}

if (claims.containsKey(DEVICE_NAME) && !(claims.get(DEVICE_NAME) instanceof String)) {
throw new PushDeviceHandlerClientException(ERROR_CODE_DEVICE_EDIT_FAILED.getCode(),
String.format(ERROR_CODE_DEVICE_EDIT_FAILED.getMessage(), deviceId));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public class DeviceHandlerAuditLogger {
*/
public void printAuditLog(Operation operation, String deviceId, String userId) {

JSONObject data = createAuditLogEntry(userId);
JSONObject data = createAuditLogEntry(operation, userId);
buildAuditLog(operation, deviceId, data);
}

Expand All @@ -72,11 +72,21 @@ private void buildAuditLog(Operation operation, String targetId, JSONObject data
*
* @return Audit log data.
*/
private JSONObject createAuditLogEntry(String userId) {
private JSONObject createAuditLogEntry(Operation operation, String userId) {

JSONObject data = new JSONObject();
data.put(LogConstants.END_USER_ID, userId != null ? userId : JSONObject.NULL);
data.put(LogConstants.UNREGISTERED_AT, System.currentTimeMillis());
switch (operation) {
case REGISTER_DEVICE:
data.put(LogConstants.REGISTERED_AT, System.currentTimeMillis());
break;
case UPDATE_DEVICE:
data.put(LogConstants.UPDATED_AT, System.currentTimeMillis());
break;
case UNREGISTER_DEVICE:
data.put(LogConstants.UNREGISTERED_AT, System.currentTimeMillis());
break;
}

return data;
}
Expand Down Expand Up @@ -128,6 +138,8 @@ private String getInitiatorId() {
*/
public enum Operation {

REGISTER_DEVICE("Register-Push-Auth-Device"),
UPDATE_DEVICE("Update-Push-Auth-Device"),
UNREGISTER_DEVICE("Unregister-Push-Auth-Device");

private final String logAction;
Expand All @@ -150,6 +162,8 @@ private static class LogConstants {

public static final String TARGET_TYPE_FIELD = "Push-Auth-Device";
public static final String END_USER_ID = "UserId";
public static final String REGISTERED_AT = "RegisteredAt";
public static final String UPDATED_AT = "UpdatedAt";
public static final String UNREGISTERED_AT = "UnregisteredAt";
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com).
* Copyright (c) 2025-2026, WSO2 LLC. (http://www.wso2.com).
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
Expand Down Expand Up @@ -42,7 +42,7 @@ public class DeviceDAOImplTest {
"DEVICE_MODEL, DEVICE_TOKEN, DEVICE_HANDLE, PROVIDER, PUBLIC_KEY, TENANT_ID) VALUES " +
"( ? , ? , ? , ? , ? , ? , ? , ? , ? )";
public static final String UNREGISTER_DEVICE_TEST = "DELETE FROM IDN_PUSH_DEVICE_STORE WHERE ID = ? ";
public static final String EDIT_DEVICE_TEST = "UPDATE IDN_PUSH_DEVICE_STORE SET DEVICE_NAME = ? " +
public static final String EDIT_DEVICE_TEST = "UPDATE IDN_PUSH_DEVICE_STORE SET DEVICE_NAME = ? , " +
"DEVICE_TOKEN = ? WHERE ID = ? ";
private DeviceDAOImpl deviceDAO;

Expand Down
Loading
Loading