diff --git a/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java b/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java
index 708c0e9afe..49f84c5159 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java
@@ -112,6 +112,14 @@ public static void clearEventHandler() throws RmmException {
     }
   }
 
+  // Reads the static sra reference while holding the Rmm class lock; callers then
+  // null/isOpen-check and use the returned reference without holding that lock.
+  private static SparkResourceAdaptor getSra() {
+    synchronized (Rmm.class) {
+      return sra;
+    }
+  }
+
   /**
    * Get the id of the current thread as used by RmmSpark.
    */
@@ -126,11 +134,10 @@ public static long getCurrentThreadId() {
    * @param taskId the task ID this thread is working on.
    */
   public static void startDedicatedTaskThread(long threadId, long taskId, Thread thread) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        ThreadStateRegistry.addThread(threadId, thread);
-        sra.startDedicatedTaskThread(threadId, taskId);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      ThreadStateRegistry.addThread(threadId, thread);
+      local.startDedicatedTaskThread(threadId, taskId);
     }
   }
 
@@ -150,11 +157,10 @@ public static void currentThreadIsDedicatedToTask(long taskId) {
    * @param taskIds the IDs of tasks that this is starting work on.
    */
   public static void shuffleThreadWorkingTasks(long threadId, Thread thread, long[] taskIds) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        ThreadStateRegistry.addThread(threadId, thread);
-        sra.poolThreadWorkingOnTasks(true, threadId, taskIds);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      ThreadStateRegistry.addThread(threadId, thread);
+      local.poolThreadWorkingOnTasks(true, threadId, taskIds);
     }
   }
 
@@ -167,10 +173,9 @@ public static void shuffleThreadWorkingOnTasks(long[] taskIds) {
   }
 
   public static boolean isThreadWorkingOnTaskAsPoolThread() {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        return sra.isThreadWorkingOnTaskAsPoolThread(getCurrentThreadId());
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      return local.isThreadWorkingOnTaskAsPoolThread(getCurrentThreadId());
     }
     return false;
   }
@@ -184,11 +189,10 @@ public static void poolThreadWorkingOnTask(long taskId) {
     long threadId = getCurrentThreadId();
     Thread thread = Thread.currentThread();
     long[] taskIds = new long[]{taskId};
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        ThreadStateRegistry.addThread(threadId, thread);
-        sra.poolThreadWorkingOnTasks(false, threadId, taskIds);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      ThreadStateRegistry.addThread(threadId, thread);
+      local.poolThreadWorkingOnTasks(false, threadId, taskIds);
     }
   }
 
@@ -199,10 +203,9 @@ public static void poolThreadWorkingOnTask(long taskId) {
    * @param taskIds the IDs of the tasks that are done.
    */
   public static void poolThreadFinishedForTasks(long threadId, long[] taskIds) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.poolThreadFinishedForTasks(threadId, taskIds);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.poolThreadFinishedForTasks(threadId, taskIds);
     }
   }
 
@@ -246,10 +249,9 @@ public static void poolThreadFinishedForTask(long taskId) {
    * @param threadId the id of the thread, not the java ID.
    */
   public static void startRetryBlock(long threadId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.startRetryBlock(threadId);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.startRetryBlock(threadId);
     }
   }
 
@@ -265,10 +267,9 @@ public static void currentThreadStartRetryBlock() {
    * @param threadId the id of the thread, not the java ID.
    */
   public static void endRetryBlock(long threadId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.endRetryBlock(threadId);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.endRetryBlock(threadId);
     }
   }
 
@@ -280,10 +281,9 @@ public static void currentThreadEndRetryBlock() {
   }
 
   private static void checkAndBreakDeadlocks() {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.checkAndBreakDeadlocks();
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.checkAndBreakDeadlocks();
     }
   }
 
@@ -293,10 +293,9 @@ private static void checkAndBreakDeadlocks() {
    * (not java thread id).
    */
   public static void removeDedicatedThreadAssociation(long threadId, long taskId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.removeThreadAssociation(threadId, taskId);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.removeThreadAssociation(threadId, taskId);
     }
   }
 
@@ -314,10 +313,9 @@ public static void removeCurrentDedicatedThreadAssociation(long taskId) {
    * @param threadId the id of the thread to clean up
    */
   public static void removeAllThreadAssociation(long threadId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.removeThreadAssociation(threadId, -1);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.removeThreadAssociation(threadId, -1);
     }
   }
 
@@ -336,10 +334,9 @@ public static void removeAllCurrentThreadAssociation() {
    * @param taskId the ID of the task that has completed.
    */
   public static void taskDone(long taskId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.taskDone(taskId);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.taskDone(taskId);
     }
   }
 
@@ -348,10 +345,9 @@ public static void taskDone(long taskId) {
    * @param threadId the ID of the thread that is about to submit the work.
    */
   public static void submittingToPool(long threadId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.submittingToPool(threadId);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.submittingToPool(threadId);
     }
   }
 
@@ -369,10 +365,9 @@ public static void submittingToPool() {
    * @param threadId the ID of the thread that is about to wait.
    */
   public static void waitingOnPool(long threadId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.waitingOnPool(threadId);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.waitingOnPool(threadId);
     }
   }
 
@@ -390,10 +385,9 @@ public static void waitingOnPool() {
    * @param threadId the ID of the thread that is done.
    */
   public static void doneWaitingOnPool(long threadId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.doneWaitingOnPool(threadId);
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.doneWaitingOnPool(threadId);
     }
   }
 
@@ -451,12 +445,11 @@ public static void forceRetryOOM(long threadId) {
    * @param skipCount how many matching allocations to skip
    */
   public static void forceRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.forceRetryOOM(threadId, numOOMs, oomMode, skipCount);
-      } else {
-        throw new IllegalStateException("RMM has not been configured for OOM injection");
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.forceRetryOOM(threadId, numOOMs, oomMode, skipCount);
+    } else {
+      throw new IllegalStateException("RMM has not been configured for OOM injection");
     }
   }
 
@@ -482,12 +475,11 @@ public static void forceSplitAndRetryOOM(long threadId) {
    * @param skipCount how many matching allocations to skip
    */
   public static void forceSplitAndRetryOOM(long threadId, int numOOMs, int oomMode, int skipCount) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.forceSplitAndRetryOOM(threadId, numOOMs, oomMode, skipCount);
-      } else {
-        throw new IllegalStateException("RMM has not been configured for OOM injection");
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.forceSplitAndRetryOOM(threadId, numOOMs, oomMode, skipCount);
+    } else {
+      throw new IllegalStateException("RMM has not been configured for OOM injection");
     }
   }
 
@@ -511,23 +503,21 @@ public static void forceCudfException(long threadId) {
    * @param numTimes the number of times the CudfException should be thrown
    */
   public static void forceCudfException(long threadId, int numTimes) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        sra.forceCudfException(threadId, numTimes);
-      } else {
-        throw new IllegalStateException("RMM has not been configured for OOM injection");
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      local.forceCudfException(threadId, numTimes);
+    } else {
+      throw new IllegalStateException("RMM has not been configured for OOM injection");
     }
   }
 
   public static RmmSparkThreadState getStateOf(long threadId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        return sra.getStateOf(threadId);
-      } else {
-        // sra is not set so the thread is by definition unknown to it.
-        return RmmSparkThreadState.UNKNOWN;
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      return local.getStateOf(threadId);
+    } else {
+      // sra is not set so the thread is by definition unknown to it.
+      return RmmSparkThreadState.UNKNOWN;
     }
   }
 
@@ -537,13 +527,12 @@ public static RmmSparkThreadState getStateOf(long threadId) {
    * @return the number of times it was thrown or 0 if in the UNKNOWN state.
    */
   public static int getAndResetNumRetryThrow(long taskId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        return sra.getAndResetNumRetryThrow(taskId);
-      } else {
-        // sra is not set so the value is by definition 0
-        return 0;
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      return local.getAndResetNumRetryThrow(taskId);
+    } else {
+      // sra is not set so the value is by definition 0
+      return 0;
     }
   }
 
@@ -553,13 +542,12 @@ public static int getAndResetNumRetryThrow(long taskId) {
    * @return the number of times it was thrown or 0 if in the UNKNOWN state.
    */
   public static int getAndResetNumSplitRetryThrow(long taskId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        return sra.getAndResetNumSplitRetryThrow(taskId);
-      } else {
-        // sra is not set so the value is by definition 0
-        return 0;
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      return local.getAndResetNumSplitRetryThrow(taskId);
+    } else {
+      // sra is not set so the value is by definition 0
+      return 0;
     }
   }
 
@@ -569,13 +557,12 @@ public static int getAndResetNumSplitRetryThrow(long taskId) {
    * @return the time the task was blocked or 0 if in the UNKNOWN state.
    */
   public static long getAndResetBlockTimeNs(long taskId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        return sra.getAndResetBlockTime(taskId);
-      } else {
-        // sra is not set so the value is by definition 0
-        return 0;
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      return local.getAndResetBlockTime(taskId);
+    } else {
+      // sra is not set so the value is by definition 0
+      return 0;
     }
   }
 
@@ -585,24 +572,22 @@ public static long getAndResetBlockTimeNs(long taskId) {
    * @return the time the task did computation that was lost.
    */
   public static long getAndResetComputeTimeLostToRetryNs(long taskId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        return sra.getAndResetComputeTimeLostToRetry(taskId);
-      } else {
-        // sra is not set so the value is by definition 0
-        return 0;
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      return local.getAndResetComputeTimeLostToRetry(taskId);
+    } else {
+      // sra is not set so the value is by definition 0
+      return 0;
     }
   }
 
   public static long getTotalBlockedOrLostTime(long taskId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        return sra.getTotalBlockedOrLostTime(taskId);
-      } else {
-        // sra is not set so the value is by definition 0
-        return 0;
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      return local.getTotalBlockedOrLostTime(taskId);
+    } else {
+      // sra is not set so the value is by definition 0
+      return 0;
     }
   }
 
@@ -612,24 +597,22 @@ public static long getTotalBlockedOrLostTime(long taskId) {
    * @return the max device memory footprint.
    */
   public static long getAndResetGpuMaxMemoryAllocated(long taskId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        return sra.getAndResetGpuMaxMemoryAllocated(taskId);
-      } else {
-        // sra is not set so the value is by definition 0
-        return 0;
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      return local.getAndResetGpuMaxMemoryAllocated(taskId);
+    } else {
+      // sra is not set so the value is by definition 0
+      return 0;
     }
   }
 
   public static long getMaxGpuTaskMemory(long taskId) {
-    synchronized (Rmm.class) {
-      if (sra != null && sra.isOpen()) {
-        return sra.getMaxGpuTaskMemory(taskId);
-      } else {
-        // sra is not set so the value is by definition 0
-        return 0;
-      }
+    SparkResourceAdaptor local = getSra();
+    if (local != null && local.isOpen()) {
+      return local.getMaxGpuTaskMemory(taskId);
+    } else {
+      // sra is not set so the value is by definition 0
+      return 0;
     }
   }
 
@@ -643,10 +626,7 @@ public static long getMaxGpuTaskMemory(long taskId) {
    * back into the post allocations calls.
    */
   public static boolean preCpuAlloc(long amount, boolean blocking) {
-    SparkResourceAdaptor local;
-    synchronized (Rmm.class) {
-      local = sra;
-    }
+    SparkResourceAdaptor local = getSra();
     if (local != null && local.isOpen()) {
       return local.preCpuAlloc(amount, blocking);
     } else {
@@ -663,10 +643,7 @@ public static boolean preCpuAlloc(long amount, boolean blocking) {
    */
   public static void postCpuAllocSuccess(long ptr, long amount, boolean blocking,
       boolean wasRecursive) {
-    SparkResourceAdaptor local;
-    synchronized (Rmm.class) {
-      local = sra;
-    }
+    SparkResourceAdaptor local = getSra();
     if (local != null && local.isOpen()) {
       local.postCpuAllocSuccess(ptr, amount, blocking, wasRecursive);
     }
@@ -681,10 +658,7 @@ public static void postCpuAllocSuccess(long ptr, long amount, boolean blocking,
    * thinks that a retry would not help.
    */
   public static boolean postCpuAllocFailed(boolean wasOom, boolean blocking, boolean wasRecursive) {
-    SparkResourceAdaptor local;
-    synchronized (Rmm.class) {
-      local = sra;
-    }
+    SparkResourceAdaptor local = getSra();
     if (local != null && local.isOpen()) {
       return local.postCpuAllocFailed(wasOom, blocking, wasRecursive);
     } else {
@@ -698,30 +672,21 @@ public static boolean postCpuAllocFailed(boolean wasOom, boolean blocking, boole
    * @param amount the amount that was made available.
    */
   public static void cpuDeallocate(long ptr, long amount) {
-    SparkResourceAdaptor local;
-    synchronized (Rmm.class) {
-      local = sra;
-    }
+    SparkResourceAdaptor local = getSra();
     if (local != null && local.isOpen()) {
       local.cpuDeallocate(ptr, amount);
     }
   }
 
   public static void spillRangeStart() {
-    SparkResourceAdaptor local;
-    synchronized (Rmm.class) {
-      local = sra;
-    }
+    SparkResourceAdaptor local = getSra();
     if (local != null && local.isOpen()) {
       local.spillRangeStart();
     }
   }
 
   public static void spillRangeDone() {
-    SparkResourceAdaptor local;
-    synchronized (Rmm.class) {
-      local = sra;
-    }
+    SparkResourceAdaptor local = getSra();
     if (local != null && local.isOpen()) {
       local.spillRangeDone();
     }