Add support for listening to WorkQueue.Tasks from plugins

Add a TaskListener interface to the WorkQueue which enables implementors
to be called directly before WorkQueue.Tasks run and directly after they
complete. This listening can be used to potentially keep track of
resources and even delay tasks from running until resources are
available. This extension point makes it possible for plugins to define
and implement WorkQueue QOS policies.

Release-Notes: Added a WorkQueue.TaskListener extension for plugins
Change-Id: I28907e27101cd7a0bcdbe5f6aced4afaeeff97e0
diff --git a/Documentation/dev-plugins.txt b/Documentation/dev-plugins.txt
index ca72f8b..33c5bbd 100644
--- a/Documentation/dev-plugins.txt
+++ b/Documentation/dev-plugins.txt
@@ -533,6 +533,24 @@
 Certain operations in Gerrit can be validated by plugins by
 implementing the corresponding link:config-validation.html[listeners].
 
+[[taskListeners]]
+== WorkQueue.TaskListeners
+
+It is possible for plugins to listen to
+`com.google.gerrit.server.git.WorkQueue$Task`s directly before they run, and
+directly after they complete. This may be used to delay task executions based
+on custom criteria by blocking, likely on a lock or semaphore, inside
+onStart(), and a lock/semaphore release in onStop(). Plugins may listen to
+tasks by implementing a `com.google.gerrit.server.git.WorkQueue$TaskListener`
+and registering the new listener like this:
+
+[source,java]
+----
+bind(TaskListener.class)
+    .annotatedWith(Exports.named("MyListener"))
+    .to(MyListener.class);
+----
+
 [[change-message-modifier]]
 == Change Message Modifier
 
diff --git a/java/com/google/gerrit/pgm/Reindex.java b/java/com/google/gerrit/pgm/Reindex.java
index c4e185d..762d988 100644
--- a/java/com/google/gerrit/pgm/Reindex.java
+++ b/java/com/google/gerrit/pgm/Reindex.java
@@ -36,6 +36,7 @@
 import com.google.gerrit.server.cache.CacheInfo;
 import com.google.gerrit.server.change.ChangeResource;
 import com.google.gerrit.server.config.GerritServerConfig;
+import com.google.gerrit.server.git.WorkQueue.WorkQueueModule;
 import com.google.gerrit.server.index.IndexModule;
 import com.google.gerrit.server.index.change.ChangeSchemaDefinitions;
 import com.google.gerrit.server.index.options.AutoFlush;
@@ -175,6 +176,8 @@
     }
     boolean replica = ReplicaUtil.isReplica(globalConfig);
     List<Module> modules = new ArrayList<>();
+    modules.add(new WorkQueueModule());
+
     Module indexModule;
     IndexType indexType = IndexModule.getIndexType(dbInjector);
     if (indexType.isLucene()) {
diff --git a/java/com/google/gerrit/server/git/WorkQueue.java b/java/com/google/gerrit/server/git/WorkQueue.java
index 3032bfe..715cb30 100644
--- a/java/com/google/gerrit/server/git/WorkQueue.java
+++ b/java/com/google/gerrit/server/git/WorkQueue.java
@@ -20,6 +20,7 @@
 import com.google.common.flogger.FluentLogger;
 import com.google.gerrit.entities.Project;
 import com.google.gerrit.extensions.events.LifecycleListener;
+import com.google.gerrit.extensions.registration.DynamicMap;
 import com.google.gerrit.lifecycle.LifecycleModule;
 import com.google.gerrit.metrics.Description;
 import com.google.gerrit.metrics.MetricMaker;
@@ -27,6 +28,7 @@
 import com.google.gerrit.server.config.ScheduleConfig.Schedule;
 import com.google.gerrit.server.logging.LoggingContext;
 import com.google.gerrit.server.logging.LoggingContextAwareRunnable;
+import com.google.gerrit.server.plugincontext.PluginMapContext;
 import com.google.gerrit.server.util.IdGenerator;
 import com.google.inject.Inject;
 import com.google.inject.Singleton;
@@ -58,6 +60,30 @@
 public class WorkQueue {
   private static final FluentLogger logger = FluentLogger.forEnclosingClass();
 
+  /**
+   * To register a TaskListener, which will be called directly before Tasks run, and directly after
+   * they complete, bind the TaskListener like this:
+   *
+   * <p><code>
+   *   bind(TaskListener.class)
+   *       .annotatedWith(Exports.named("MyListener"))
+   *       .to(MyListener.class);
+   * </code>
+   */
+  public interface TaskListener {
+    public static class NoOp implements TaskListener {
+      @Override
+      public void onStart(Task<?> task) {}
+
+      @Override
+      public void onStop(Task<?> task) {}
+    }
+
+    void onStart(Task<?> task);
+
+    void onStop(Task<?> task);
+  }
+
   public static class Lifecycle implements LifecycleListener {
     private final WorkQueue workQueue;
 
@@ -78,6 +104,7 @@
   public static class WorkQueueModule extends LifecycleModule {
     @Override
     protected void configure() {
+      DynamicMap.mapOf(binder(), WorkQueue.TaskListener.class);
       bind(WorkQueue.class);
       listener().to(Lifecycle.class);
     }
@@ -87,18 +114,32 @@
   private final IdGenerator idGenerator;
   private final MetricMaker metrics;
   private final CopyOnWriteArrayList<Executor> queues;
+  private final PluginMapContext<TaskListener> listeners;
 
   @Inject
-  WorkQueue(IdGenerator idGenerator, @GerritServerConfig Config cfg, MetricMaker metrics) {
-    this(idGenerator, Math.max(cfg.getInt("execution", "defaultThreadPoolSize", 2), 2), metrics);
+  WorkQueue(
+      IdGenerator idGenerator,
+      @GerritServerConfig Config cfg,
+      MetricMaker metrics,
+      PluginMapContext<TaskListener> listeners) {
+    this(
+        idGenerator,
+        Math.max(cfg.getInt("execution", "defaultThreadPoolSize", 2), 2),
+        metrics,
+        listeners);
   }
 
   /** Constructor to allow binding the WorkQueue more explicitly in a vhost setup. */
-  public WorkQueue(IdGenerator idGenerator, int defaultThreadPoolSize, MetricMaker metrics) {
+  public WorkQueue(
+      IdGenerator idGenerator,
+      int defaultThreadPoolSize,
+      MetricMaker metrics,
+      PluginMapContext<TaskListener> listeners) {
     this.idGenerator = idGenerator;
     this.metrics = metrics;
     this.queues = new CopyOnWriteArrayList<>();
     this.defaultQueue = createQueue(defaultThreadPoolSize, "WorkQueue", true);
+    this.listeners = listeners;
   }
 
   /** Get the default work queue, for miscellaneous tasks. */
@@ -438,6 +479,14 @@
     Collection<Task<?>> getTasks() {
       return all.values();
     }
+
+    public void onStart(Task<?> task) {
+      listeners.runEach(extension -> extension.getProvider().get().onStart(task));
+    }
+
+    public void onStop(Task<?> task) {
+      listeners.runEach(extension -> extension.getProvider().get().onStop(task));
+    }
   }
 
   private static void logUncaughtException(Thread t, Throwable e) {
@@ -608,10 +657,12 @@
       if (running.compareAndSet(false, true)) {
         String oldThreadName = Thread.currentThread().getName();
         try {
+          executor.onStart(this);
           Thread.currentThread().setName(oldThreadName + "[" + task.toString() + "]");
           task.run();
         } finally {
           Thread.currentThread().setName(oldThreadName);
+          executor.onStop(this);
           if (isPeriodic()) {
             running.set(false);
           } else {
diff --git a/java/com/google/gerrit/testing/InMemoryModule.java b/java/com/google/gerrit/testing/InMemoryModule.java
index b00cadb..781965e 100644
--- a/java/com/google/gerrit/testing/InMemoryModule.java
+++ b/java/com/google/gerrit/testing/InMemoryModule.java
@@ -79,6 +79,7 @@
 import com.google.gerrit.server.git.PerThreadRequestScope;
 import com.google.gerrit.server.git.SearchingChangeCacheImpl.SearchingChangeCacheImplModule;
 import com.google.gerrit.server.git.WorkQueue;
+import com.google.gerrit.server.git.WorkQueue.WorkQueueModule;
 import com.google.gerrit.server.group.testing.TestGroupBackend;
 import com.google.gerrit.server.index.account.AccountSchemaDefinitions;
 import com.google.gerrit.server.index.account.AllAccountsIndexer;
@@ -195,6 +196,7 @@
     install(new AuditModule());
     install(new SubscriptionGraphModule());
     install(new SuperprojectUpdateSubmissionListenerModule());
+    install(new WorkQueueModule());
 
     bindScope(RequestScoped.class, PerThreadRequestScope.REQUEST);
 
diff --git a/javatests/com/google/gerrit/acceptance/server/util/TaskListenerIT.java b/javatests/com/google/gerrit/acceptance/server/util/TaskListenerIT.java
new file mode 100644
index 0000000..b3094ac
--- /dev/null
+++ b/javatests/com/google/gerrit/acceptance/server/util/TaskListenerIT.java
@@ -0,0 +1,255 @@
+// Copyright (C) 2022 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.gerrit.acceptance.server.util;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.gerrit.acceptance.AbstractDaemonTest;
+import com.google.gerrit.extensions.annotations.Exports;
+import com.google.gerrit.server.git.WorkQueue;
+import com.google.gerrit.server.git.WorkQueue.Task;
+import com.google.gerrit.server.git.WorkQueue.TaskListener;
+import com.google.inject.AbstractModule;
+import com.google.inject.Inject;
+import com.google.inject.Module;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TaskListenerIT extends AbstractDaemonTest {
+  /**
+   * Use a LatchedMethod in a method to allow another thread to await the method's call. Once
+   * called, the Latch.call() method will block until another thread calls its LatchedMethods's
+   * complete() method.
+   */
+  private static class LatchedMethod {
+    private static final int AWAIT_TIMEOUT = 20;
+    private static final TimeUnit AWAIT_TIMEUNIT = TimeUnit.MILLISECONDS;
+
+    /** API class meant be used by the class whose method is being latched */
+    private class Latch {
+      /** Ensure that the latched method calls this on entry */
+      public void call() {
+        called.countDown();
+        await(complete);
+      }
+    }
+
+    public Latch latch = new Latch();
+
+    private CountDownLatch called = new CountDownLatch(1);
+    private CountDownLatch complete = new CountDownLatch(1);
+
+    /** Assert that the Latch's call() method has not yet been called */
+    public void assertUncalled() {
+      assertThat(called.getCount()).isEqualTo(1);
+    }
+
+    /**
+     * Assert that a timeout does not occur while awaiting Latch's call() method to be called. Fails
+     * if the waiting time elapses before Latch's call() method is called, otherwise passes.
+     */
+    public void assertAwait() {
+      assertThat(await(called)).isEqualTo(true);
+    }
+
+    /** Unblock the Latch's call() method so that it can complete */
+    public void complete() {
+      complete.countDown();
+    }
+
+    private static boolean await(CountDownLatch latch) {
+      try {
+        return latch.await(AWAIT_TIMEOUT, AWAIT_TIMEUNIT);
+      } catch (InterruptedException e) {
+        return false;
+      }
+    }
+  }
+
+  private static class LatchedRunnable implements Runnable {
+    public LatchedMethod run = new LatchedMethod();
+
+    @Override
+    public void run() {
+      run.latch.call();
+    }
+  }
+
+  private static class ForwardingListener implements TaskListener {
+    public volatile TaskListener delegate;
+    public volatile Task task;
+
+    public void resetDelegate(TaskListener listener) {
+      delegate = listener;
+      task = null;
+    }
+
+    @Override
+    public void onStart(Task<?> task) {
+      if (delegate != null) {
+        if (this.task == null || this.task == task) {
+          this.task = task;
+          delegate.onStart(task);
+        }
+      }
+    }
+
+    @Override
+    public void onStop(Task<?> task) {
+      if (delegate != null) {
+        if (this.task == task) {
+          delegate.onStop(task);
+        }
+      }
+    }
+  }
+
+  private static class LatchedListener implements TaskListener {
+    public LatchedMethod onStart = new LatchedMethod();
+    public LatchedMethod onStop = new LatchedMethod();
+
+    @Override
+    public void onStart(Task<?> task) {
+      onStart.latch.call();
+    }
+
+    @Override
+    public void onStop(Task<?> task) {
+      onStop.latch.call();
+    }
+  }
+
+  private static ForwardingListener forwarder;
+
+  @Inject private WorkQueue workQueue;
+  private ScheduledExecutorService executor;
+
+  private LatchedListener listener = new LatchedListener();
+  private LatchedRunnable runnable = new LatchedRunnable();
+
+  @Override
+  public Module createModule() {
+    return new AbstractModule() {
+      @Override
+      public void configure() {
+        // Forwarder.delegate is empty on start to protect test listener from non test tasks
+        // (such as the "Log File Compressor") interference
+        forwarder = new ForwardingListener(); // Only gets bound once for all tests
+        bind(TaskListener.class).annotatedWith(Exports.named("listener")).toInstance(forwarder);
+      }
+    };
+  }
+
+  @Before
+  public void setupExecutorAndForwarder() throws InterruptedException {
+    executor = workQueue.createQueue(1, "TaskListeners");
+
+    // "Log File Compressor"s are likely running and will interfere with tests
+    while (0 != workQueue.getTasks().size()) {
+      for (Task<?> t : workQueue.getTasks()) {
+        t.cancel(true);
+      }
+      TimeUnit.MILLISECONDS.sleep(1);
+    }
+
+    forwarder.resetDelegate(listener);
+
+    assertQueueSize(0);
+    assertThat(forwarder.task).isEqualTo(null);
+    listener.onStart.assertUncalled();
+    runnable.run.assertUncalled();
+    listener.onStop.assertUncalled();
+  }
+
+  @Test
+  public void onStartThenRunThenOnStopAreCalled() throws Exception {
+    int size = assertQueueBlockedOnExecution(runnable);
+
+    // onStartThenRunThenOnStopAreCalled -> onStart...Called
+    listener.onStart.assertAwait();
+    assertQueueSize(size);
+    runnable.run.assertUncalled();
+    listener.onStop.assertUncalled();
+
+    listener.onStart.complete();
+    // onStartThenRunThenOnStopAreCalled -> ...ThenRun...Called
+    runnable.run.assertAwait();
+    listener.onStop.assertUncalled();
+
+    runnable.run.complete();
+    // onStartThenRunThenOnStopAreCalled -> ...ThenOnStop...Called
+    listener.onStop.assertAwait();
+    assertQueueSize(size);
+
+    listener.onStop.complete();
+    assertAwaitQueueSize(--size);
+  }
+
+  @Test
+  public void firstBlocksSecond() throws Exception {
+    int size = assertQueueBlockedOnExecution(runnable);
+
+    // firstBlocksSecond -> first...
+    listener.onStart.assertAwait();
+    assertQueueSize(size);
+
+    LatchedRunnable runnable2 = new LatchedRunnable();
+    size = assertQueueBlockedOnExecution(runnable2);
+
+    // firstBlocksSecond -> ...BlocksSecond
+    runnable2.run.assertUncalled();
+    assertQueueSize(size); // waiting on first
+
+    listener.onStart.complete();
+    runnable.run.assertAwait();
+    assertQueueSize(size); // waiting on first
+    runnable2.run.assertUncalled();
+
+    runnable.run.complete();
+    listener.onStop.assertAwait();
+    assertQueueSize(size); // waiting on first
+    runnable2.run.assertUncalled();
+
+    listener.onStop.complete();
+    runnable2.run.assertAwait();
+    assertQueueSize(--size);
+
+    runnable2.run.complete();
+    assertAwaitQueueSize(--size);
+  }
+
+  private int assertQueueBlockedOnExecution(Runnable runnable) {
+    int expectedSize = workQueue.getTasks().size() + 1;
+    executor.execute(runnable);
+    assertQueueSize(expectedSize);
+    return expectedSize;
+  }
+
+  private void assertQueueSize(int size) {
+    assertThat(workQueue.getTasks().size()).isEqualTo(size);
+  }
+
+  /** Fails if the waiting time elapses before the count is reached, otherwise passes */
+  private void assertAwaitQueueSize(int size) throws InterruptedException {
+    long i = 0;
+    do {
+      TimeUnit.NANOSECONDS.sleep(10);
+      assertThat(i++).isLessThan(100);
+    } while (size != workQueue.getTasks().size());
+  }
+}