diff --git a/core/src/main/java/com/taobao/arthas/core/advisor/SpyImpl.java b/core/src/main/java/com/taobao/arthas/core/advisor/SpyImpl.java index cd69371f3f..ad295f9ba9 100644 --- a/core/src/main/java/com/taobao/arthas/core/advisor/SpyImpl.java +++ b/core/src/main/java/com/taobao/arthas/core/advisor/SpyImpl.java @@ -23,6 +23,10 @@ */ public class SpyImpl extends AbstractSpy { private static final Logger logger = LoggerFactory.getLogger(SpyImpl.class); + /** + * 防止 advice 回调链路里再次命中增强点时发生同线程递归分发。 + */ + private static final ThreadLocal DISPATCH_DEPTH = new ThreadLocal(); @Override public void atEnter(Class clazz, String methodInfo, Object target, Object[] args) { @@ -34,16 +38,20 @@ public void atEnter(Class clazz, String methodInfo, Object target, Object[] a // TODO listener 只用查一次,放到 thread local里保存起来就可以了! List listeners = AdviceListenerManager.queryAdviceListeners(classLoader, clazz.getName(), methodName, methodDesc); - if (listeners != null) { - for (AdviceListener adviceListener : listeners) { - try { - if (skipAdviceListener(adviceListener)) { - continue; + if (listeners != null && tryEnterDispatch()) { + try { + for (AdviceListener adviceListener : listeners) { + try { + if (skipAdviceListener(adviceListener)) { + continue; + } + adviceListener.before(clazz, methodName, methodDesc, target, args); + } catch (Throwable e) { + logger.error("class: {}, methodInfo: {}", clazz.getName(), methodInfo, e); } - adviceListener.before(clazz, methodName, methodDesc, target, args); - } catch (Throwable e) { - logger.error("class: {}, methodInfo: {}", clazz.getName(), methodInfo, e); } + } finally { + exitDispatch(); } } @@ -59,16 +67,20 @@ public void atExit(Class clazz, String methodInfo, Object target, Object[] ar List listeners = AdviceListenerManager.queryAdviceListeners(classLoader, clazz.getName(), methodName, methodDesc); - if (listeners != null) { - for (AdviceListener adviceListener : listeners) { - try { - if (skipAdviceListener(adviceListener)) { - continue; + if (listeners != null && tryEnterDispatch()) { + try { + for (AdviceListener adviceListener : listeners) { + try { + if (skipAdviceListener(adviceListener)) { + continue; + } + adviceListener.afterReturning(clazz, methodName, methodDesc, target, args, returnObject); + } catch (Throwable e) { + logger.error("class: {}, methodInfo: {}", clazz.getName(), methodInfo, e); } - adviceListener.afterReturning(clazz, methodName, methodDesc, target, args, returnObject); - } catch (Throwable e) { - logger.error("class: {}, methodInfo: {}", clazz.getName(), methodInfo, e); } + } finally { + exitDispatch(); } } } @@ -83,16 +95,20 @@ public void atExceptionExit(Class clazz, String methodInfo, Object target, Ob List listeners = AdviceListenerManager.queryAdviceListeners(classLoader, clazz.getName(), methodName, methodDesc); - if (listeners != null) { - for (AdviceListener adviceListener : listeners) { - try { - if (skipAdviceListener(adviceListener)) { - continue; + if (listeners != null && tryEnterDispatch()) { + try { + for (AdviceListener adviceListener : listeners) { + try { + if (skipAdviceListener(adviceListener)) { + continue; + } + adviceListener.afterThrowing(clazz, methodName, methodDesc, target, args, throwable); + } catch (Throwable e) { + logger.error("class: {}, methodInfo: {}", clazz.getName(), methodInfo, e); } - adviceListener.afterThrowing(clazz, methodName, methodDesc, target, args, throwable); - } catch (Throwable e) { - logger.error("class: {}, methodInfo: {}", clazz.getName(), methodInfo, e); } + } finally { + exitDispatch(); } } } @@ -107,18 +123,22 @@ public void atBeforeInvoke(Class clazz, String invokeInfo, Object target) { List listeners = AdviceListenerManager.queryTraceAdviceListeners(classLoader, clazz.getName(), owner, methodName, methodDesc); - - if (listeners != null) { - for (AdviceListener adviceListener : listeners) { - try { - if (skipAdviceListener(adviceListener)) { - continue; + int tracingLineNumber = Integer.parseInt(info[3]); + if (listeners != null && tryEnterDispatch()) { + try { + for (AdviceListener adviceListener : listeners) { + try { + if (skipAdviceListener(adviceListener)) { + continue; + } + InvokeTraceable listener = (InvokeTraceable) adviceListener; + listener.invokeBeforeTracing(classLoader, owner, methodName, methodDesc, tracingLineNumber); + } catch (Throwable e) { + logger.error("class: {}, invokeInfo: {}", clazz.getName(), invokeInfo, e); } - final InvokeTraceable listener = (InvokeTraceable) adviceListener; - listener.invokeBeforeTracing(classLoader, owner, methodName, methodDesc, Integer.parseInt(info[3])); - } catch (Throwable e) { - logger.error("class: {}, invokeInfo: {}", clazz.getName(), invokeInfo, e); } + } finally { + exitDispatch(); } } } @@ -132,18 +152,22 @@ public void atAfterInvoke(Class clazz, String invokeInfo, Object target) { String methodDesc = info[2]; List listeners = AdviceListenerManager.queryTraceAdviceListeners(classLoader, clazz.getName(), owner, methodName, methodDesc); - - if (listeners != null) { - for (AdviceListener adviceListener : listeners) { - try { - if (skipAdviceListener(adviceListener)) { - continue; + int tracingLineNumber = Integer.parseInt(info[3]); + if (listeners != null && tryEnterDispatch()) { + try { + for (AdviceListener adviceListener : listeners) { + try { + if (skipAdviceListener(adviceListener)) { + continue; + } + InvokeTraceable listener = (InvokeTraceable) adviceListener; + listener.invokeAfterTracing(classLoader, owner, methodName, methodDesc, tracingLineNumber); + } catch (Throwable e) { + logger.error("class: {}, invokeInfo: {}", clazz.getName(), invokeInfo, e); } - final InvokeTraceable listener = (InvokeTraceable) adviceListener; - listener.invokeAfterTracing(classLoader, owner, methodName, methodDesc, Integer.parseInt(info[3])); - } catch (Throwable e) { - logger.error("class: {}, invokeInfo: {}", clazz.getName(), invokeInfo, e); } + } finally { + exitDispatch(); } } @@ -159,22 +183,39 @@ public void atInvokeException(Class clazz, String invokeInfo, Object target, List listeners = AdviceListenerManager.queryTraceAdviceListeners(classLoader, clazz.getName(), owner, methodName, methodDesc); - - if (listeners != null) { - for (AdviceListener adviceListener : listeners) { - try { - if (skipAdviceListener(adviceListener)) { - continue; + int tracingLineNumber = Integer.parseInt(info[3]); + if (listeners != null && tryEnterDispatch()) { + try { + for (AdviceListener adviceListener : listeners) { + try { + if (skipAdviceListener(adviceListener)) { + continue; + } + InvokeTraceable listener = (InvokeTraceable) adviceListener; + listener.invokeThrowTracing(classLoader, owner, methodName, methodDesc, tracingLineNumber); + } catch (Throwable e) { + logger.error("class: {}, invokeInfo: {}", clazz.getName(), invokeInfo, e); } - final InvokeTraceable listener = (InvokeTraceable) adviceListener; - listener.invokeThrowTracing(classLoader, owner, methodName, methodDesc, Integer.parseInt(info[3])); - } catch (Throwable e) { - logger.error("class: {}, invokeInfo: {}", clazz.getName(), invokeInfo, e); } + } finally { + exitDispatch(); } } } + private static boolean tryEnterDispatch() { + Integer depth = DISPATCH_DEPTH.get(); + if (depth != null && depth.intValue() > 0) { + return false; + } + DISPATCH_DEPTH.set(Integer.valueOf(1)); + return true; + } + + private static void exitDispatch() { + DISPATCH_DEPTH.remove(); + } + private static boolean skipAdviceListener(AdviceListener adviceListener) { if (adviceListener instanceof ProcessAware) { ProcessAware processAware = (ProcessAware) adviceListener; diff --git a/core/src/test/java/com/taobao/arthas/core/advisor/SpyImplTest.java b/core/src/test/java/com/taobao/arthas/core/advisor/SpyImplTest.java index 5d6110cfb2..bdf672bd9b 100644 --- a/core/src/test/java/com/taobao/arthas/core/advisor/SpyImplTest.java +++ b/core/src/test/java/com/taobao/arthas/core/advisor/SpyImplTest.java @@ -1,16 +1,42 @@ package com.taobao.arthas.core.advisor; +import java.lang.instrument.Instrumentation; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + import org.assertj.core.api.Assertions; +import org.junit.BeforeClass; import org.junit.Test; +import com.taobao.arthas.core.bytecode.TestHelper; +import com.taobao.arthas.core.server.ArthasBootstrap; import com.taobao.arthas.core.util.StringUtils; +import net.bytebuddy.agent.ByteBuddyAgent; + /** * * @author hengyunabc 2021-07-14 * */ public class SpyImplTest { + private static final AtomicLong LISTENER_ID = new AtomicLong(1000L); + private static final AtomicLong METHOD_SEQUENCE = new AtomicLong(1L); + + @BeforeClass + public static void beforeClass() throws Throwable { + try { + ArthasBootstrap.getInstance(); + return; + } catch (IllegalStateException ignored) { + // 测试首次运行时初始化 Arthas 环境,供 AdviceListenerManager 使用。 + } + + Instrumentation instrumentation = ByteBuddyAgent.install(); + TestHelper.appendSpyJar(instrumentation); + ArthasBootstrap.getInstance(instrumentation, "ip=127.0.0.1"); + } @Test public void testSplitMethodInfo() throws Throwable { @@ -27,4 +53,143 @@ public void testSplitInvokeInfo() throws Throwable { .containsExactly("demo/MathGame", "primeFactors", "(I)Ljava/util/List;", "24"); } + + @Test + public void shouldSkipReentrantAfterReturningDispatch() throws Throwable { + final SpyImpl spy = new SpyImpl(); + final String methodName = nextMethodName("reentrantAfterReturning"); + final String methodDesc = "()Ljava/lang/String;"; + final String methodInfo = buildMethodInfo(methodName, methodDesc); + final AtomicBoolean reentered = new AtomicBoolean(false); + final AtomicInteger callbackCount = new AtomicInteger(0); + registerMethodListener(methodName, methodDesc, new TestAdviceListener() { + + @Override + public void afterReturning(Class clazz, String adviceMethodName, String adviceMethodDesc, Object target, + Object[] args, Object returnObject) throws Throwable { + callbackCount.incrementAndGet(); + if (reentered.compareAndSet(false, true)) { + spy.atExit(SpyImplTest.class, methodInfo, target, args, returnObject); + } + } + }); + + spy.atExit(SpyImplTest.class, methodInfo, null, new Object[0], "done"); + + Assertions.assertThat(reentered.get()).isTrue(); + Assertions.assertThat(callbackCount).hasValue(1); + } + + @Test + public void shouldSkipReentrantAfterThrowingDispatch() throws Throwable { + final SpyImpl spy = new SpyImpl(); + final String methodName = nextMethodName("reentrantAfterThrowing"); + final String methodDesc = "()V"; + final Throwable expected = new IllegalStateException("boom"); + final String methodInfo = buildMethodInfo(methodName, methodDesc); + final AtomicBoolean reentered = new AtomicBoolean(false); + final AtomicInteger callbackCount = new AtomicInteger(0); + registerMethodListener(methodName, methodDesc, new TestAdviceListener() { + + @Override + public void afterThrowing(Class clazz, String adviceMethodName, String adviceMethodDesc, Object target, + Object[] args, Throwable throwable) throws Throwable { + callbackCount.incrementAndGet(); + if (reentered.compareAndSet(false, true)) { + spy.atExceptionExit(SpyImplTest.class, methodInfo, target, args, throwable); + } + + Assertions.assertThat(throwable).isSameAs(expected); + } + }); + + spy.atExceptionExit(SpyImplTest.class, methodInfo, null, new Object[0], expected); + + Assertions.assertThat(reentered.get()).isTrue(); + Assertions.assertThat(callbackCount).hasValue(1); + } + + @Test + public void shouldSkipNestedDispatchForOtherListenersToo() throws Throwable { + final SpyImpl spy = new SpyImpl(); + final String methodName = nextMethodName("nestedDispatchAllListeners"); + final String methodDesc = "()I"; + final AtomicInteger firstListenerCount = new AtomicInteger(0); + final AtomicInteger secondListenerCount = new AtomicInteger(0); + final AtomicBoolean reentered = new AtomicBoolean(false); + final String methodInfo = buildMethodInfo(methodName, methodDesc); + + AdviceListenerManager.registerAdviceListener(SpyImplTest.class.getClassLoader(), SpyImplTest.class.getName(), + methodName, methodDesc, new TestAdviceListener() { + @Override + public void afterReturning(Class clazz, String adviceMethodName, String adviceMethodDesc, + Object target, Object[] args, Object returnObject) throws Throwable { + firstListenerCount.incrementAndGet(); + if (reentered.compareAndSet(false, true)) { + spy.atExit(SpyImplTest.class, methodInfo, target, args, returnObject); + } + } + }); + + AdviceListenerManager.registerAdviceListener(SpyImplTest.class.getClassLoader(), SpyImplTest.class.getName(), + methodName, methodDesc, new TestAdviceListener() { + @Override + public void afterReturning(Class clazz, String adviceMethodName, String adviceMethodDesc, + Object target, Object[] args, Object returnObject) throws Throwable { + secondListenerCount.incrementAndGet(); + } + }); + + spy.atExit(SpyImplTest.class, methodInfo, null, new Object[0], Integer.valueOf(1)); + + Assertions.assertThat(reentered.get()).isTrue(); + Assertions.assertThat(firstListenerCount).hasValue(1); + Assertions.assertThat(secondListenerCount).hasValue(1); + } + + private static String registerMethodListener(String methodName, String methodDesc, AdviceListener adviceListener) { + AdviceListenerManager.registerAdviceListener(SpyImplTest.class.getClassLoader(), SpyImplTest.class.getName(), + methodName, methodDesc, adviceListener); + return buildMethodInfo(methodName, methodDesc); + } + + private static String buildMethodInfo(String methodName, String methodDesc) { + return methodName + "|" + methodDesc; + } + + private static String nextMethodName(String prefix) { + return prefix + METHOD_SEQUENCE.getAndIncrement(); + } + + private abstract static class TestAdviceListener implements AdviceListener { + private final long id = LISTENER_ID.getAndIncrement(); + + @Override + public long id() { + return id; + } + + @Override + public void create() { + } + + @Override + public void destroy() { + } + + @Override + public void before(Class clazz, String methodName, String methodDesc, Object target, Object[] args) + throws Throwable { + } + + @Override + public void afterReturning(Class clazz, String methodName, String methodDesc, Object target, Object[] args, + Object returnObject) throws Throwable { + } + + @Override + public void afterThrowing(Class clazz, String methodName, String methodDesc, Object target, Object[] args, + Throwable throwable) throws Throwable { + } + } }