diff --git a/test/jdk/jdk/internal/misc/TerminatingThreadLocal/TestTerminatingThreadLocal.java b/test/jdk/jdk/internal/misc/TerminatingThreadLocal/TestTerminatingThreadLocal.java index 8ef3e5cb180..dd22bbd7e01 100644 --- a/test/jdk/jdk/internal/misc/TerminatingThreadLocal/TestTerminatingThreadLocal.java +++ b/test/jdk/jdk/internal/misc/TerminatingThreadLocal/TestTerminatingThreadLocal.java @@ -23,47 +23,69 @@ import jdk.internal.misc.TerminatingThreadLocal; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; import java.util.Arrays; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Stream; + +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; +import static org.testng.Assert.*; /* * @test * @bug 8202788 8291897 * @summary TerminatingThreadLocal unit test - * @modules java.base/jdk.internal.misc + * @modules java.base/java.lang:+open java.base/jdk.internal.misc * @requires vm.continuations * @enablePreview - * @run main/othervm -Djdk.virtualThreadScheduler.parallelism=1 -Djdk.virtualThreadScheduler.maxPoolSize=2 TestTerminatingThreadLocal + * @run testng/othervm TestTerminatingThreadLocal */ public class TestTerminatingThreadLocal { - public static void main(String[] args) { - ttlTestSet(42, 112); - ttlTestSet(null, 112); - ttlTestSet(42, null); - - ttlTestVirtual(666, ThreadLocal::get, 666); - } - - static void ttlTestSet(T v0, T v1) { - ttlTestPlatform(v0, ttl -> { } ); - ttlTestPlatform(v0, ttl -> { ttl.get(); }, v0); - ttlTestPlatform(v0, ttl -> { ttl.get(); ttl.remove(); } ); - ttlTestPlatform(v0, ttl -> { ttl.get(); ttl.set(v1); }, v1); - ttlTestPlatform(v0, ttl -> { ttl.set(v1); }, v1); - ttlTestPlatform(v0, ttl -> { ttl.set(v1); ttl.remove(); } ); - ttlTestPlatform(v0, ttl -> { ttl.set(v1); ttl.remove(); ttl.get(); }, v0); - ttlTestPlatform(v0, ttl -> { ttl.get(); ttl.remove(); ttl.set(v1); }, v1); - } - - @SafeVarargs - static void ttlTestPlatform(T initialValue, + static Object[] testCase(T initialValue, + Consumer> ttlOps, + T... expectedTerminatedValues) { + return new Object[] {initialValue, ttlOps, Arrays.asList(expectedTerminatedValues)}; + } + + static Stream testCases(T v0, T v1) { + return Stream.of( + testCase(v0, ttl -> { } ), + testCase(v0, ttl -> { ttl.get(); }, v0), + testCase(v0, ttl -> { ttl.get(); ttl.remove(); } ), + testCase(v0, ttl -> { ttl.get(); ttl.set(v1); }, v1), + testCase(v0, ttl -> { ttl.set(v1); }, v1), + testCase(v0, ttl -> { ttl.set(v1); ttl.remove(); } ), + testCase(v0, ttl -> { ttl.set(v1); ttl.remove(); ttl.get(); }, v0), + testCase(v0, ttl -> { ttl.get(); ttl.remove(); ttl.set(v1); }, v1) + ); + } + + @DataProvider + public Object[][] testCases() { + return Stream.of( + testCases(42, 112), + testCases(null, new Object()), + testCases("abc", null) + ).flatMap(Function.identity()).toArray(Object[][]::new); + } + + /** + * Test TerminatingThreadLocal with a platform thread. + */ + @Test(dataProvider = "testCases") + public void ttlTestPlatform(T initialValue, Consumer> ttlOps, - T... expectedTerminatedValues) { + List expectedTerminatedValues) throws Exception { List terminatedValues = new CopyOnWriteArrayList<>(); TerminatingThreadLocal ttl = new TerminatingThreadLocal<>() { @@ -80,23 +102,20 @@ public class TestTerminatingThreadLocal { Thread thread = new Thread(() -> ttlOps.accept(ttl), "ttl-test-platform"); thread.start(); - try { - thread.join(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } + thread.join(); - if (!terminatedValues.equals(Arrays.asList(expectedTerminatedValues))) { - throw new AssertionError("Expected terminated values: " + - Arrays.toString(expectedTerminatedValues) + - " but got: " + terminatedValues); - } + assertEquals(terminatedValues, expectedTerminatedValues); } - @SafeVarargs - static void ttlTestVirtual(T initialValue, + /** + * Test TerminatingThreadLocal with a virtual thread. The thread local should be + * carrier thread local but accessible to the virtual thread. The threadTerminated + * method should be invoked when the carrier thread terminates. + */ + @Test(dataProvider = "testCases") + public void ttlTestVirtual(T initialValue, Consumer> ttlOps, - T... expectedTerminatedValues) { + List expectedTerminatedValues) throws Exception { List terminatedValues = new CopyOnWriteArrayList<>(); TerminatingThreadLocal ttl = new TerminatingThreadLocal<>() { @@ -111,77 +130,49 @@ public class TestTerminatingThreadLocal { } }; - var lock = new Lock(); + Thread carrier; - var blockerThread = Thread.startVirtualThread(() -> { - // force compensation in carrier thread pool which will spin another - // carrier thread so that we can later observe it being terminated... - synchronized (lock) { - while (!lock.unblock) { - try { - lock.wait(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } + // use a single worker thread pool for the cheduler + try (var pool = Executors.newSingleThreadExecutor()) { + + // capture carrier Thread + carrier = pool.submit(Thread::currentThread).get(); + + ThreadFactory factory = virtualThreadBuilder(pool) + .name("ttl-test-virtual-", 0) + .allowSetThreadLocals(false) + .factory(); + try (var executor = Executors.newThreadPerTaskExecutor(factory)) { + executor.submit(() -> ttlOps.accept(ttl)).get(); } - // keep thread running in a non-blocking-fashion which keeps - // it bound to carrier thread - while (!lock.unspin) { - Thread.onSpinWait(); - } - }); - Thread thread = Thread - .ofVirtual() - .allowSetThreadLocals(false) - .inheritInheritableThreadLocals(false) - .name("ttl-test-virtual") - .unstarted(() -> ttlOps.accept(ttl)); - thread.start(); - try { - thread.join(); - } catch (InterruptedException e) { - throw new RuntimeException(e); + assertTrue(terminatedValues.isEmpty(), + "Unexpected terminated values after virtual thread terminated"); } - if (!terminatedValues.isEmpty()) { - throw new AssertionError("Unexpected terminated values after virtual thread.join(): " + - terminatedValues); - } + // wait for carrier to terminate + carrier.join(); - // we now unblock the blocker thread but keep it running - synchronized (lock) { - lock.unblock = true; - lock.notify(); - } - - // carrier thread pool has a 30 second keep-alive time to terminate excessive carrier - // threads. Since blockerThread is still pinning one of them we hope for the other - // thread to be terminated... - try { - TimeUnit.SECONDS.sleep(31); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - - if (!terminatedValues.equals(Arrays.asList(expectedTerminatedValues))) { - throw new AssertionError("Expected terminated values: " + - Arrays.toString(expectedTerminatedValues) + - " but got: " + terminatedValues); - } - - // we now terminate the blocker thread - lock.unspin = true; - try { - blockerThread.join(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } + assertEquals(terminatedValues, expectedTerminatedValues); } - static class Lock { - boolean unblock; - volatile boolean unspin; + /** + * Returns a builder to create virtual threads that use the given scheduler. + */ + static Thread.Builder.OfVirtual virtualThreadBuilder(Executor scheduler) { + try { + Class clazz = Class.forName("java.lang.ThreadBuilders$VirtualThreadBuilder"); + Constructor ctor = clazz.getDeclaredConstructor(Executor.class); + ctor.setAccessible(true); + return (Thread.Builder.OfVirtual) ctor.newInstance(scheduler); + } catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof RuntimeException re) { + throw re; + } + throw new RuntimeException(e); + } catch (Exception e) { + throw new RuntimeException(e); + } } }