diff --git a/leaf-server/src/main/java/org/dreeam/leaf/async/AsyncDispatcher.java b/leaf-server/src/main/java/org/dreeam/leaf/async/AsyncDispatcher.java
new file mode 100644
index 000000000..1781fd232
--- /dev/null
+++ b/leaf-server/src/main/java/org/dreeam/leaf/async/AsyncDispatcher.java
@@ -0,0 +1,43 @@
+package org.dreeam.leaf.async;
+
+/// Holder for the shared [ThreadPool] instance named "Leaf Async Scheduler".
+///
+/// The pool is built once, in the static initializer, from two optional system properties:
+/// - `leaf.scheduler.threads`: number of worker threads (default: half the available processors, clamped to 1..4)
+/// - `leaf.scheduler.queue-size`: capacity requested for the task queue (default: 8192)
+///
+/// A property that is unset, not an integer, or less than 1 leaves the default in place.
+@org.jspecify.annotations.NullMarked
+public final class AsyncDispatcher {
+
+    public static final ThreadPool INSTANCE;
+
+    static {
+        final int numThreads = positiveIntProperty(
+            "leaf.scheduler.threads",
+            Math.clamp(Runtime.getRuntime().availableProcessors() / 2, 1, 4));
+        final int queue = positiveIntProperty("leaf.scheduler.queue-size", 8192);
+        INSTANCE = new ThreadPool(numThreads,
+            queue,
+            "Leaf Async Scheduler",
+            Thread.NORM_PRIORITY - 1);
+    }
+
+    private AsyncDispatcher() {
+    }
+
+    /// Reads system property `name` as an int, returning `fallback` when it is unset, malformed, or below 1.
+    private static int positiveIntProperty(final String name, final int fallback) {
+        final String raw = System.getProperty(name);
+        if (raw == null) {
+            return fallback;
+        }
+        try {
+            final int parsed = Integer.parseInt(raw);
+            return parsed >= 1 ? parsed : fallback;
+        } catch (NumberFormatException ignored) {
+            // a malformed value is treated like an unset one
+            return fallback;
+        }
+    }
+}
diff --git a/leaf-server/src/main/java/org/dreeam/leaf/async/ThreadPool.java b/leaf-server/src/main/java/org/dreeam/leaf/async/ThreadPool.java
new file mode 100644
index 000000000..6bff2e941
--- /dev/null
+++ b/leaf-server/src/main/java/org/dreeam/leaf/async/ThreadPool.java
@@ -0,0 +1,200 @@
+package org.dreeam.leaf.async;
+
+import net.minecraft.util.Util;
+import org.apache.logging.log4j.LogManager;
+import org.dreeam.leaf.util.queue.MpmcQueue;
+import org.jspecify.annotations.NullMarked;
+import org.jspecify.annotations.Nullable;
+import org.apache.logging.log4j.Logger;
+
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+import java.util.concurrent.FutureTask;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.LockSupport;
+
+/// Fixed-size thread pool fed by a bounded [MpmcQueue].
+///
+/// The constructor starts one dispatcher thread plus `numThreads` workers. A worker that finds the task
+/// queue empty registers itself in a second queue and parks; the dispatcher unparks a registered worker
+/// whenever the queued tasks outnumber the workers that are not registered as parked.
+///
+/// [#execute] never queues behind a full queue: if the pool is shut down or the queue rejects the task, the task runs on the caller.
+@NullMarked
+public final class ThreadPool implements Executor {
+    private static final Logger LOGGER = LogManager.getLogger("Leaf");
+    private static final long PARK_NANOS = 200_000L; // 0.2ms
+
+    private volatile boolean shutdown = false;
+    private final Thread[] threads; // [0] is the dispatcher, [1..] are the workers
+    private final MpmcQueue<Runnable> channel; // tasks waiting for a worker
+    private final MpmcQueue<Thread> parkChannel; // workers that registered themselves before parking
+
+    /// Creates the pool and starts all of its threads.
+    ///
+    /// @param numThreads number of workers; one dispatcher thread is started in addition
+    /// @param queue      capacity requested for the task queue
+    /// @param prefix     prefix of the thread names
+    /// @param priority   priority of the workers; the dispatcher runs one level higher, capped at [Thread#MAX_PRIORITY]
+    /// @throws IllegalArgumentException if `numThreads` is not positive
+    public ThreadPool(int numThreads, final int queue, final String prefix, final int priority) {
+        if (numThreads <= 0) {
+            throw new IllegalArgumentException("numThreads must be positive: " + numThreads);
+        }
+        numThreads = numThreads + 1;
+        this.threads = new Thread[numThreads];
+        this.channel = new MpmcQueue<>(queue);
+        this.parkChannel = new MpmcQueue<>(numThreads);
+        this.threads[0] = Thread.ofPlatform()
+            .uncaughtExceptionHandler(Util::onThreadException)
+            .daemon(false)
+            // priority + 1 is out of range when priority is already MAX_PRIORITY, so cap it
+            .priority(Math.min(priority + 1, Thread.MAX_PRIORITY))
+            .name(prefix + " Dispatcher")
+            .start(new Dispatcher(this));
+        for (int i = 1; i < numThreads; i++) {
+            threads[i] = Thread.ofPlatform()
+                .uncaughtExceptionHandler(Util::onThreadException)
+                .daemon(false)
+                .priority(priority)
+                .name(prefix + " Worker - " + i)
+                .start(new Worker(this));
+        }
+    }
+
+    /// Queues `task` for a worker; runs it on the calling thread when the pool is shut down or the queue rejects it.
+    @Override
+    public void execute(Runnable task) {
+        if (shutdown || !channel.send(task)) {
+            task.run();
+        }
+    }
+
+    public boolean isShutdown() {
+        return shutdown;
+    }
+
+    /// Wraps `task` in a [FutureTask] that yields `result` and hands it to [#execute].
+    public <V> FutureTask<V> submit(Runnable task, @Nullable V result) {
+        final FutureTask<V> t = new FutureTask<>(Executors.callable(task, result));
+        execute(t);
+        return t;
+    }
+
+    /// Wraps `task` in a [FutureTask] and hands it to [#execute].
+    public <V> FutureTask<V> submit(Callable<V> task) {
+        final FutureTask<V> t = new FutureTask<>(task);
+        execute(t);
+        return t;
+    }
+
+    /// Sets the shutdown flag and unparks every pool thread so that it re-checks the flag.
+    public void shutdown() {
+        shutdown = true;
+        for (final Thread thread : threads) {
+            LockSupport.unpark(thread);
+        }
+    }
+
+    /// Joins the pool threads, then runs every task still queued on the calling thread.
+    ///
+    /// A non-positive `timeout` joins each thread without a time limit. This method does not call [#shutdown].
+    ///
+    /// @return `true` if every thread was joined, `false` if the timeout ran out first
+    public boolean awaitTermination(final long timeout, final TimeUnit unit) throws InterruptedException {
+        final long nanos = unit.toNanos(timeout);
+        final long startTime = System.nanoTime();
+
+        boolean flag = true;
+        for (final Thread worker : threads) {
+            if (nanos <= 0L) {
+                worker.join();
+                continue;
+            }
+            final long remaining = startTime + nanos - System.nanoTime();
+            if (remaining <= 0L) {
+                flag = false;
+                break;
+            } else {
+                worker.join(remaining / 1_000_000L, (int) (remaining % 1_000_000L));
+                if (worker.isAlive()) {
+                    flag = false;
+                    break;
+                }
+            }
+        }
+        Runnable task;
+        while ((task = channel.recv()) != null) {
+            task.run();
+        }
+        return flag;
+    }
+
+    public int workerCount() {
+        return threads.length - 1;
+    }
+
+    private record Worker(ThreadPool executor) implements Runnable {
+        @Override
+        public void run() {
+            final MpmcQueue<Runnable> channel = executor.channel;
+            final MpmcQueue<Thread> park = executor.parkChannel;
+            while (true) {
+                final Runnable task = channel.recv();
+                if (task != null) {
+                    try {
+                        task.run();
+                    } catch (final Throwable e) {
+                        LOGGER.error("Task {} generated an exception in thread {}", task, Thread.currentThread().getName(), e);
+                    }
+                } else if (executor.shutdown) {
+                    break;
+                } else if (park.send(Thread.currentThread())) {
+                    LockSupport.park();
+                    if (Thread.interrupted()) {
+                        Thread.currentThread().interrupt();
+                        break;
+                    }
+                } else {
+                    Thread.yield();
+                }
+            }
+        }
+    }
+
+    private record Dispatcher(ThreadPool executor) implements Runnable {
+        @Override
+        public void run() {
+            final int threads = executor.threads.length - 1;
+            final MpmcQueue<Runnable> channel = executor.channel;
+            final MpmcQueue<Thread> park = executor.parkChannel;
+            int backoff = 0;
+            while (true) {
+                final int len = channel.length();
+                // only take a worker from the park queue when tasks outnumber the workers not registered there
+                final Thread thread = (len != 0 && threads - park.length() < len) ? park.recv() : null;
+                if (thread != null) {
+                    backoff = 0;
+                    LockSupport.unpark(thread);
+                } else if (executor.shutdown) {
+                    break;
+                } else if (backoff < 8) {
+                    // nothing to wake: back off instead of re-polling at full speed while all workers are busy
+                    backoff++;
+                    Thread.yield();
+                } else {
+                    LockSupport.parkNanos(PARK_NANOS);
+                    if (Thread.interrupted()) {
+                        Thread.currentThread().interrupt();
+                        break;
+                    }
+                }
+            }
+            Thread left;
+            while ((left = park.recv()) != null) {
+                LockSupport.unpark(left);
+            }
+        }
+    }
+}
diff --git a/leaf-server/src/main/java/org/dreeam/leaf/util/queue/MpmcQueue.java b/leaf-server/src/main/java/org/dreeam/leaf/util/queue/MpmcQueue.java
index a94fe874a..81fd9e941 100644
--- a/leaf-server/src/main/java/org/dreeam/leaf/util/queue/MpmcQueue.java
+++ b/leaf-server/src/main/java/org/dreeam/leaf/util/queue/MpmcQueue.java
@@ -11,8 +11,60 @@
 import java.lang.invoke.MethodHandles;
 import java.lang.invoke.VarHandle;
 
+@SuppressWarnings("unused")
+abstract sealed class ReadCounter permits CachePad1 {
+    protected volatile long r; // packed read counter, accessed through the READ VarHandle of MpmcQueue
+}
+
+@SuppressWarnings("unused")
+abstract sealed class CachePad1 extends ReadCounter permits WriteCounter { // 56 unused bytes declared between r and w, presumably cache-line padding
+    byte i0, i1, i2, i3, i4, i5, i6, i7,
+        j0, j1, j2, j3, j4, j5, j6, j7,
+        k0, k1, k2, k3, k4, k5, k6, k7,
+        l0, l1, l2, l3, l4, l5, l6, l7,
+        m0, m1, m2, m3, m4, m5, m6, m7,
+        n0, n1, n2, n3, n4, n5, n6, n7,
+        o0, o1, o2, o3, o4, o5, o6, o7;
+}
+
+@SuppressWarnings("unused")
+abstract sealed class WriteCounter extends CachePad1 permits CachePad2 {
+    protected volatile long w; // packed write counter, accessed through the WRITE VarHandle of MpmcQueue
+}
+
+@SuppressWarnings("unused")
+abstract sealed class CachePad2 extends WriteCounter permits MpmcQueue { // 56 unused bytes declared after w, presumably cache-line padding
+    byte i0, i1, i2, i3, i4, i5, i6, i7,
+        j0, j1, j2, j3, j4, j5, j6, j7,
+        k0, k1, k2, k3, k4, k5, k6, k7,
+        l0, l1, l2, l3, l4, l5, l6, l7,
+        m0, m1, m2, m3, m4, m5, m6, m7,
+        n0, n1, n2, n3, n4, n5, n6, n7,
+        o0, o1, o2, o3, o4, o5, o6, o7;
+}
+
+/// ```text
+/// counter layout
+/// +63------------------------------------------------16+15-----8+7------0+
+/// |                       index                        |  done  |  pend  |
+/// +----------------------------------------------------+--------+--------+
+/// ```
+///
+/// - index (48bits): current read/write position in the ring buffer (head/tail)
+/// - pend (8bits): number of pending concurrent read/writes
+/// - done (8bits): number of completed read/writes
+///
+/// For reading reads_pend is incremented first, then the content of the ring buffer is read from memory.
+/// After reading is done reads_done is incremented. reads_index advances by reads_pend once reads_done equals reads_pend (both are then reset to 0), or by one when the read at reads_index itself completes.
+///
+/// For writing first writes_pend is incremented, then the content of the ring buffer is updated.
+/// After writing writes_done is incremented. If writes_done is equal to writes_pend then both are set to 0 and writes_index is incremented.
+///
+/// In rare cases this can result in a race where multiple threads increment reads_pend in turn and reads_done never quite reaches reads_pend.
+/// If reads_pend == 16 or writes_pend == 16, a spin loop waits for it to drop below 16 before continuing.
 @NullMarked
-public final class MpmcQueue<T> {
+public final class MpmcQueue<T> extends CachePad2 {
+
     private static final long DONE_MASK = 0x0000_0000_0000_FF00L;
     private static final long PENDING_MASK = 0x0000_0000_0000_00FFL;
     private static final long DONE_PENDING_MASK = DONE_MASK | PENDING_MASK;
@@ -20,158 +72,176 @@ public final class MpmcQueue<T> {
     private static final int DONE_SHIFT = 8;
     private static final long MAX_IN_PROGRESS = 16;
     private static final int MAX_CAPACITY = 1 << 30;
-    private static final int PARALLELISM = Runtime.getRuntime().availableProcessors();
 
     private static final VarHandle READ;
     private static final VarHandle WRITE;
+    private static final VarHandle A;
 
     private final long mask;
-    private final long capacity;
-    private final @Nullable T[] buffer;
-
-    private final ReadCounter reads = new ReadCounter();
-    private final WriteCounter writes = new WriteCounter();
+    private final T[] a;
 
     static {
         try {
             MethodHandles.Lookup l = MethodHandles.lookup();
-            READ = l.findVarHandle(ReadCounter.class, "reads", long.class);
-            WRITE = l.findVarHandle(WriteCounter.class, "writes", long.class);
+            READ = l.findVarHandle(MpmcQueue.class, "r", long.class);
+            WRITE = l.findVarHandle(MpmcQueue.class, "w", long.class);
+            A = MethodHandles.arrayElementVarHandle(Object[].class);
         } catch (ReflectiveOperationException e) {
             throw new ExceptionInInitializerError(e);
         }
     }
 
-    public MpmcQueue(Class<T> clazz, int capacity) {
+    /// Creates a queue whose ring size is the smallest power of two above `capacity`, capped at `MAX_CAPACITY`.
+    ///
+    /// One slot always stays free, so the queue holds at least `capacity` items, except that a request
+    /// for exactly `MAX_CAPACITY` holds one fewer.
+    /// @throws IllegalArgumentException if `capacity` is not in `1..MAX_CAPACITY`
+    public MpmcQueue(final int capacity) {
+        super();
         if (capacity <= 0 || capacity > MAX_CAPACITY) {
             throw new IllegalArgumentException();
         }
-        this.capacity = Math.max(2, (1L << (Integer.SIZE - Integer.numberOfLeadingZeros(capacity - 1))));
-        this.mask = this.capacity - 1L;
-        //noinspection unchecked
-        this.buffer = (clazz == Object.class)
-            ? (T[]) new Object[(int) this.capacity]
-            : (T[]) java.lang.reflect.Array.newInstance(clazz, (int) this.capacity);
+        // cap at MAX_CAPACITY: for capacity == MAX_CAPACITY the shift yields 1L << 31, which is negative as an int array size
+        final long size = Math.min(MAX_CAPACITY, Math.max(2, (1L << (Integer.SIZE - Integer.numberOfLeadingZeros(capacity)))));
+        this.mask = size - 1L;
+        this.a = (T[]) new Object[(int) size];
     }
 
-    private void spinWait(final int attempts) {
-        //noinspection StatementWithEmptyBody
-        if (attempts == 0) {
-        } else if (PARALLELISM != 1 && (attempts & 31) != 31) {
-            Thread.onSpinWait();
-        } else {
-            Thread.yield();
-        }
+    /// Same as the single-argument constructor; the class argument is ignored.
+    public MpmcQueue(final Class<T> ignore, final int capacity) {
+        this(capacity);
     }
 
+    /// Tries to append `item` to the queue.
+    ///
+    /// @return `true` once the item is stored, `false` if the queue is full
     public boolean send(final T item) {
-        long write = (long) WRITE.getAcquire(this.writes);
-        boolean success;
-        long newWrite = 0L;
-        long index = 0L;
-        int attempts = 0;
+        long write = (long) WRITE.getAcquire(this);
+        long index;
         while (true) {
-            spinWait(attempts++);
             final long inProgressCnt = (write & PENDING_MASK);
-            if ((((write >>> INDEX_SHIFT) + 1L) & mask) == ((long) READ.getVolatile(this.reads) >>> INDEX_SHIFT)) {
-                success = false;
-                break;
+            if (writeFull(write >>> INDEX_SHIFT)) {
+                return false;
             }
-
             if (inProgressCnt == MAX_IN_PROGRESS) {
-                write = (long) WRITE.getAcquire(this.writes);
+                spinWait();
+                write = (long) WRITE.getAcquire(this);
                 continue;
             }
-            index = ((write >>> INDEX_SHIFT) + inProgressCnt) & mask;
-            if (((index + 1L) & mask) == ((long) READ.getVolatile(this.reads) >>> INDEX_SHIFT)) {
-                success = false;
-                break;
+            index = nextIndex(write, inProgressCnt);
+            if (writeFull(index)) {
+                return false;
             }
-            newWrite = write + 1L;
-            if (WRITE.weakCompareAndSetAcquire(this.writes, write, newWrite)) {
-                success = true;
+            final long newWrite = write + 1L;
+            final long prev = (long) WRITE.compareAndExchangeAcquire(this, write, newWrite);
+            if (prev == write) {
+                write = newWrite;
                 break;
             }
-            write = (long) WRITE.getVolatile(this.writes);
+            write = prev;
         }
-        if (!success) {
-            return false;
-        }
-        buffer[(int) index] = item;
-        write = newWrite;
+        A.setRelease(this.a, (int) index, item);
+        long expected = write;
         while (true) {
-            final long n = ((write & DONE_MASK) >>> DONE_SHIFT) + 1L == (write & PENDING_MASK)
-                ? ((write >>> INDEX_SHIFT) + (write & PENDING_MASK) & mask) << INDEX_SHIFT
-                : write >>> INDEX_SHIFT == index
-                ? write + (1L << INDEX_SHIFT) - 1L & (mask << INDEX_SHIFT | DONE_PENDING_MASK)
-                : write + (1L << DONE_SHIFT);
-            if (WRITE.weakCompareAndSetRelease(this.writes, write, n)) {
-                break;
+            final long n = nextState(expected, index);
+            final long cmp = (long) WRITE.compareAndExchangeRelease(this, expected, n);
+            if (cmp == expected) {
+                return true;
+            } else {
+                expected = cmp;
             }
-            write = (long) WRITE.getVolatile(this.writes);
-            spinWait(attempts++);
         }
-        return true;
     }
 
+    /// Takes the next item from the queue.
+    ///
+    /// @return the item, or `null` if the queue is empty
     public @Nullable T recv() {
-        long read = (long) READ.getAcquire(this.reads);
-        boolean success;
-        long index = 0;
-        long newRead = 0L;
-        int attempts = 0;
+        long read = (long) READ.getAcquire(this);
+        long index;
         while (true) {
-            spinWait(attempts++);
             final long inProgressCnt = (read & PENDING_MASK);
-            if ((read >>> INDEX_SHIFT) == ((long) WRITE.getVolatile(this.writes) >>> INDEX_SHIFT)) {
-                success = false;
-                break;
+            if (readEmpty(read >>> INDEX_SHIFT)) {
+                return null;
             }
 
             if (inProgressCnt == MAX_IN_PROGRESS) {
-                read = (long) READ.getAcquire(this.reads);
+                spinWait();
+                read = (long) READ.getAcquire(this);
                 continue;
             }
-            index = ((read >>> INDEX_SHIFT) + inProgressCnt) & mask;
-            if ((index & mask) == ((long) WRITE.getVolatile(this.writes) >>> INDEX_SHIFT)) {
-                success = false;
-                break;
+            index = nextIndex(read, inProgressCnt);
+            if (readEmpty(index)) {
+                return null;
             }
-            newRead = read + 1L;
-            if (READ.weakCompareAndSetAcquire(this.reads, read, newRead)) {
-                success = true;
+            final long newRead = read + 1L;
+            final long prev = (long) READ.compareAndExchangeAcquire(this, read, newRead);
+            if (prev == read) {
+                read = newRead;
                 break;
             }
-            read = (long) READ.getVolatile(this.reads);
+            read = prev;
         }
-        if (!success) {
-            return null;
-        }
-        final T result = buffer[(int) index];
-        buffer[(int) index] = null;
-        read = newRead;
+        // noinspection unchecked
+        final T result = (T) A.getAndSetAcquire(this.a, (int) index, null);
+        long expected = read;
         while (true) {
-            final long n = ((read & DONE_MASK) >>> DONE_SHIFT) + 1L == (read & PENDING_MASK)
-                ? ((read >>> INDEX_SHIFT) + (read & PENDING_MASK) & mask) << INDEX_SHIFT
-                : read >>> INDEX_SHIFT == index
-                ? read + (1L << INDEX_SHIFT) - 1L & (mask << INDEX_SHIFT | DONE_PENDING_MASK)
-                : read + (1L << DONE_SHIFT);
-            if (READ.weakCompareAndSetRelease(this.reads, read, n)) {
-                break;
+            final long n = nextState(expected, index);
+            final long cmp = (long) READ.compareAndExchangeRelease(this, expected, n);
+            if (cmp == expected) {
+                return result;
+            } else {
+                expected = cmp;
             }
-            read = (long) READ.getVolatile(this.reads);
-            spinWait(attempts++);
         }
-        return result;
+    }
+
+    private long nextIndex(final long read, final long pending) {
+        return ((read >>> INDEX_SHIFT) + pending) & this.mask;
+    }
+
+    private static void spinWait() {
+        Thread.onSpinWait();
+    }
+
+    /// incrementing the done count and potentially advancing the index
+    ///
+    /// if done + 1 == pending (all operations complete)
+    /// increment index by pending, zero pending and done
+    ///
+    /// if index == idx (completing in order)
+    /// increment index, decrement pending, wrapping and preserve done
+    ///
+    /// else (skip index increment)
+    /// increment done
+    private long nextState(final long c, final long idx) {
+        return (((c & DONE_MASK) >>> DONE_SHIFT) + 1L) == (c & PENDING_MASK)
+            ? (((c >>> INDEX_SHIFT) + (c & PENDING_MASK)) & this.mask) << INDEX_SHIFT
+            : (c >>> INDEX_SHIFT) == idx
+            ? (c + DONE_PENDING_MASK) & ((this.mask << INDEX_SHIFT) | DONE_PENDING_MASK)
+            : c + (1L << DONE_SHIFT);
+    }
+
+    /// write would cause the queue to become full
+    private boolean writeFull(final long wIdx) {
+        return ((wIdx + 1L) & this.mask) == ((long) READ.getVolatile(this) >>> INDEX_SHIFT);
+    }
+
+    /// read would read an empty position
+    private boolean readEmpty(final long rIdx) {
+        return (rIdx & this.mask) == ((long) WRITE.getVolatile(this) >>> INDEX_SHIFT);
     }
 
+    /// Number of stored items not yet claimed by an in-progress read; the two counters are read separately, so this is a snapshot.
     public int length() {
-        final long reads = (long) READ.getVolatile(this.reads);
-        final long writes = (long) WRITE.getVolatile(this.writes);
+        final long reads = (long) READ.getVolatile(this);
+        final long writes = (long) WRITE.getVolatile(this);
         final long readIndex = (reads >>> INDEX_SHIFT);
         final long writeIndex = (writes >>> INDEX_SHIFT);
-        return (int) (readIndex <= writeIndex ? writeIndex - readIndex : writeIndex + capacity - readIndex);
-        // (readIndex <= writeIndex ? writeIndex - readIndex : writeIndex + capacity - readIndex) - (reads & PENDING_MASK)
+        final long len = (readIndex <= writeIndex
+            ? writeIndex - readIndex
+            : writeIndex + this.mask + 1L - readIndex);
+        return (int) (len - (reads & PENDING_MASK));
     }
 
     public boolean isEmpty() {
@@ -179,35 +249,14 @@ public boolean isEmpty() {
     }
 
+    /// Number of free slots, not counting slots already claimed by in-progress writes; a snapshot of the two counters.
     public int remaining() {
-        final long reads = (long) READ.getVolatile(this.reads);
-        final long writes = (long) WRITE.getVolatile(this.writes);
+        final long reads = (long) READ.getVolatile(this);
+        final long writes = (long) WRITE.getVolatile(this);
         final long readIndex = (reads >>> INDEX_SHIFT);
         final long writeIndex = (writes >>> INDEX_SHIFT);
-        final long len = readIndex <= writeIndex ?
-            writeIndex - readIndex :
-            writeIndex + capacity - readIndex;
-        return (int) (mask - len - (writes & PENDING_MASK));
-    }
-
-    @SuppressWarnings("unused")
-    public abstract static sealed class CachePadded permits ReadCounter, WriteCounter {
-        public final byte i0 = 0, i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5 = 0, i6 = 0, i7 = 0, i8 = 0, i9 = 0, i10 = 0, i11 = 0, i12 = 0, i13 = 0, i14 = 0, i15 = 0;
-        public final byte j0 = 0, j1 = 0, j2 = 0, j3 = 0, j4 = 0, j5 = 0, j6 = 0, j7 = 0, j8 = 0, j9 = 0, j10 = 0, j11 = 0, j12 = 0, j13 = 0, j14 = 0, j15 = 0;
-        public final byte k0 = 0, k1 = 0, k2 = 0, k3 = 0, k4 = 0, k5 = 0, k6 = 0, k7 = 0, k8 = 0, k9 = 0, k10 = 0, k11 = 0, k12 = 0, k13 = 0, k14 = 0, k15 = 0;
-        public final byte l0 = 0, l1 = 0, l2 = 0, l3 = 0, l4 = 0, l5 = 0, l6 = 0, l7 = 0, l8 = 0, l9 = 0, l10 = 0, l11 = 0, l12 = 0, l13 = 0, l14 = 0, l15 = 0;
-        public final byte m0 = 0, m1 = 0, m2 = 0, m3 = 0, m4 = 0, m5 = 0, m6 = 0, m7 = 0, m8 = 0, m9 = 0, m10 = 0, m11 = 0, m12 = 0, m13 = 0, m14 = 0, m15 = 0;
-        public final byte n0 = 0, n1 = 0, n2 = 0, n3 = 0, n4 = 0, n5 = 0, n6 = 0, n7 = 0, n8 = 0, n9 = 0, n10 = 0, n11 = 0, n12 = 0, n13 = 0, n14 = 0, n15 = 0;
-        public final byte o0 = 0, o1 = 0, o2 = 0, o3 = 0, o4 = 0, o5 = 0, o6 = 0, o7 = 0, o8 = 0, o9 = 0, o10 = 0, o11 = 0, o12 = 0, o13 = 0, o14 = 0, o15 = 0;
-        public final byte p0 = 0, p1 = 0, p2 = 0, p3 = 0, p4 = 0, p5 = 0, p6 = 0, p7 = 0, p8 = 0, p9 = 0, p10 = 0, p11 = 0, p12 = 0, p13 = 0, p14 = 0, p15 = 0;
-    }
-
-    private static final class ReadCounter extends CachePadded {
-        @SuppressWarnings("unused")
-        private volatile long reads;
-    }
-
-    private static final class WriteCounter extends CachePadded {
-        @SuppressWarnings("unused")
-        private volatile long writes;
+        final long len = readIndex <= writeIndex
+            ? writeIndex - readIndex
+            : writeIndex + this.mask + 1L - readIndex;
+        return (int) (this.mask - len - (writes & PENDING_MASK));
     }
 }