Merge "Lock on index element id rather than on index task id" into stable-3.0
diff --git a/src/main/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventHandler.java b/src/main/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventHandler.java
index b311db9..368581c 100644
--- a/src/main/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventHandler.java
+++ b/src/main/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventHandler.java
@@ -171,6 +171,8 @@
     }
 
     abstract CompletableFuture<Boolean> execute();
+
+    abstract String indexId(); // id string IndexEventLocks uses to look up this task's lock
   }
 
   class IndexChangeTask extends IndexTask {
@@ -206,6 +208,11 @@
     public String toString() {
       return String.format("[%s] Index change %s in target instance", pluginName, changeId);
     }
+
+    @Override
+    String indexId() {
+      return "change/".concat(String.valueOf(changeId)); // same id as DeleteChangeTask builds
+    }
   }
 
   class DeleteChangeTask extends IndexTask {
@@ -239,6 +246,11 @@
     public String toString() {
       return String.format("[%s] Delete change %s in target instance", pluginName, changeId);
     }
+
+    @Override
+    String indexId() {
+      return "change/".concat(String.valueOf(changeId)); // same id as IndexChangeTask builds
+    }
   }
 
   class IndexAccountTask extends IndexTask {
@@ -271,6 +283,11 @@
     public String toString() {
       return String.format("[%s] Index account %s in target instance", pluginName, accountId);
     }
+
+    @Override
+    String indexId() {
+      return "account/".concat(String.valueOf(accountId)); // prefix separates index types
+    }
   }
 
   class IndexGroupTask extends IndexTask {
@@ -303,6 +320,11 @@
     public String toString() {
       return String.format("[%s] Index group %s in target instance", pluginName, groupUUID);
     }
+
+    @Override
+    String indexId() {
+      return "group/".concat(String.valueOf(groupUUID)); // prefix separates index types
+    }
   }
 
   class IndexProjectTask extends IndexTask {
@@ -335,5 +357,10 @@
     public String toString() {
       return String.format("[%s] Index project %s in target instance", pluginName, projectName);
     }
+
+    @Override
+    String indexId() {
+      return "project/".concat(String.valueOf(projectName)); // prefix separates index types
+    }
   }
 }
diff --git a/src/main/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventLocks.java b/src/main/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventLocks.java
index 4434c34..332a7d6 100644
--- a/src/main/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventLocks.java
+++ b/src/main/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventLocks.java
@@ -38,28 +38,40 @@
     this.waitTimeout = cfg.index().waitTimeout();
   }
 
-  public void withLock(
+  /** Runs {@code function} under the lock of {@code id.indexId()}, released on completion. */
+  public CompletableFuture<?> withLock(
       IndexTask id, IndexCallFunction function, VoidFunction lockAcquireTimeoutCallback) {
-    Lock idLock = getLock(id);
+    String indexId = id.indexId();
+    Lock idLock = getLock(indexId);
+    CompletableFuture<?> failureFuture = new CompletableFuture<>(); // returned on timeout/interrupt
     try {
       if (idLock.tryLock(waitTimeout, TimeUnit.MILLISECONDS)) {
-        function
-            .invoke()
-            .whenComplete(
-                (result, error) -> {
-                  idLock.unlock();
-                });
-      } else {
-        lockAcquireTimeoutCallback.invoke();
+        try {
+          return function.invoke().whenComplete((result, error) -> idLock.unlock());
+        } catch (Throwable t) {
+          idLock.unlock(); // invoke() threw synchronously: nobody else would release the lock
+          throw t;
+        }
       }
+
+      String timeoutMessage =
+          String.format(
+              "Acquisition of the locking of %s timed out after %d msec: consider increasing the number of shards",
+              indexId, waitTimeout);
+      log.atWarning().log(timeoutMessage);
+      lockAcquireTimeoutCallback.invoke();
+      failureFuture.completeExceptionally(new InterruptedException(timeoutMessage));
     } catch (InterruptedException e) {
-      log.atSevere().withCause(e).log("%s was interrupted; giving up", id);
+      Thread.currentThread().interrupt(); // restore the flag cleared when the exception was thrown
+      log.atSevere().withCause(e).log("Locking of %s was interrupted; giving up", indexId);
+      failureFuture.completeExceptionally(e);
     }
+    return failureFuture;
   }
 
   @VisibleForTesting
-  protected Lock getLock(IndexTask id) {
-    return locks.get(id);
+  protected Lock getLock(String indexId) {
+    return locks.get(indexId); // keyed by the id string, not the task object
   }
 
   @FunctionalInterface
diff --git a/src/test/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventHandlerTest.java b/src/test/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventHandlerTest.java
index 7097a3a..ffad41f 100644
--- a/src/test/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventHandlerTest.java
+++ b/src/test/java/com/ericsson/gerrit/plugins/highavailability/index/IndexEventHandlerTest.java
@@ -16,6 +16,7 @@
 
 import static com.google.common.truth.Truth.assertThat;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -32,7 +33,8 @@
 import com.ericsson.gerrit.plugins.highavailability.index.IndexEventHandler.IndexAccountTask;
 import com.ericsson.gerrit.plugins.highavailability.index.IndexEventHandler.IndexChangeTask;
 import com.ericsson.gerrit.plugins.highavailability.index.IndexEventHandler.IndexGroupTask;
-import com.ericsson.gerrit.plugins.highavailability.index.IndexEventHandler.IndexProjectTask;
+import com.ericsson.gerrit.plugins.highavailability.index.IndexEventHandler.IndexTask;
+import com.ericsson.gerrit.plugins.highavailability.index.IndexEventLocks.VoidFunction;
 import com.google.gerrit.reviewdb.client.Account;
 import com.google.gerrit.reviewdb.client.AccountGroup;
 import com.google.gerrit.reviewdb.client.Change;
@@ -44,15 +46,19 @@
 import java.util.Optional;
 import java.util.concurrent.Callable;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.locks.Lock;
 import java.util.function.Consumer;
+import java.util.function.Supplier;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -68,6 +74,8 @@
   private static final int ACCOUNT_ID = 2;
   private static final String UUID = "3";
   private static final String OTHER_UUID = "4";
+  private static final int INDEX_WAIT_TIMEOUT_MS = 5; // lock wait timeout stubbed into the config
+  private static final int MAX_TEST_PARALLELISM = 4; // thread count of the test scheduler pool
 
   private IndexEventHandler indexEventHandler;
   @Mock private Forwarder forwarder;
@@ -77,6 +85,8 @@
   private Account.Id accountId;
   private AccountGroup.UUID accountGroupUUID;
   private ScheduledExecutorService executor = new CurrentThreadScheduledExecutorService();
+  private ScheduledExecutorService testExecutor =
+      Executors.newScheduledThreadPool(MAX_TEST_PARALLELISM);
   @Mock private RequestContext mockCtx;
   @Mock private Configuration configuration;
   private IndexEventLocks idLocks;
@@ -100,7 +110,7 @@
     Configuration.Index cfgIndex = mock(Configuration.Index.class);
     when(configuration.index()).thenReturn(cfgIndex);
     when(cfgIndex.numStripedLocks()).thenReturn(Configuration.DEFAULT_NUM_STRIPED_LOCKS);
-    when(cfgIndex.waitTimeout()).thenReturn(Configuration.DEFAULT_TIMEOUT_MS);
+    when(cfgIndex.waitTimeout()).thenReturn(INDEX_WAIT_TIMEOUT_MS);
 
     Configuration.Http http = mock(Configuration.Http.class);
     when(configuration.http()).thenReturn(http);
@@ -164,7 +174,7 @@
   public void shouldNotIndexChangeWhenCannotAcquireLock() throws Exception {
     IndexEventLocks locks = mock(IndexEventLocks.class);
     Lock lock = mock(Lock.class);
-    when(locks.getLock(any(IndexChangeTask.class))).thenReturn(lock);
+    when(locks.getLock(anyString())).thenReturn(lock);
     Mockito.doCallRealMethod().when(locks).withLock(any(), any(), any());
     when(lock.tryLock(Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS))).thenReturn(false);
     setUpIndexEventHandler(currCtx, locks);
@@ -178,7 +188,7 @@
   public void shouldNotIndexAccountWhenCannotAcquireLock() throws Exception {
     IndexEventLocks locks = mock(IndexEventLocks.class);
     Lock lock = mock(Lock.class);
-    when(locks.getLock(any(IndexAccountTask.class))).thenReturn(lock);
+    when(locks.getLock(anyString())).thenReturn(lock);
     when(lock.tryLock(Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS))).thenReturn(false);
     Mockito.doCallRealMethod().when(locks).withLock(any(), any(), any());
     setUpIndexEventHandler(currCtx, locks);
@@ -192,7 +202,7 @@
   public void shouldNotDeleteChangeWhenCannotAcquireLock() throws Exception {
     IndexEventLocks locks = mock(IndexEventLocks.class);
     Lock lock = mock(Lock.class);
-    when(locks.getLock(any(DeleteChangeTask.class))).thenReturn(lock);
+    when(locks.getLock(anyString())).thenReturn(lock);
     when(lock.tryLock(Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS))).thenReturn(false);
     Mockito.doCallRealMethod().when(locks).withLock(any(), any(), any());
     setUpIndexEventHandler(currCtx, locks);
@@ -206,7 +216,7 @@
   public void shouldNotIndexGroupWhenCannotAcquireLock() throws Exception {
     IndexEventLocks locks = mock(IndexEventLocks.class);
     Lock lock = mock(Lock.class);
-    when(locks.getLock(any(IndexGroupTask.class))).thenReturn(lock);
+    when(locks.getLock(anyString())).thenReturn(lock);
     when(lock.tryLock(Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS))).thenReturn(false);
     Mockito.doCallRealMethod().when(locks).withLock(any(), any(), any());
     setUpIndexEventHandler(currCtx, locks);
@@ -220,7 +230,7 @@
   public void shouldNotIndexProjectWhenCannotAcquireLock() throws Exception {
     IndexEventLocks locks = mock(IndexEventLocks.class);
     Lock lock = mock(Lock.class);
-    when(locks.getLock(any(IndexProjectTask.class))).thenReturn(lock);
+    when(locks.getLock(anyString())).thenReturn(lock);
     when(lock.tryLock(Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS))).thenReturn(false);
     Mockito.doCallRealMethod().when(locks).withLock(any(), any(), any());
     setUpIndexEventHandler(currCtx, locks);
@@ -234,7 +244,7 @@
   public void shouldRetryIndexChangeWhenCannotAcquireLock() throws Exception {
     IndexEventLocks locks = mock(IndexEventLocks.class);
     Lock lock = mock(Lock.class);
-    when(locks.getLock(any(IndexChangeTask.class))).thenReturn(lock);
+    when(locks.getLock(anyString())).thenReturn(lock);
     Mockito.doCallRealMethod().when(locks).withLock(any(), any(), any());
     when(lock.tryLock(Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)))
         .thenReturn(false, true);
@@ -250,7 +260,7 @@
   public void shouldRetryUpToMaxTriesWhenCannotAcquireLock() throws Exception {
     IndexEventLocks locks = mock(IndexEventLocks.class);
     Lock lock = mock(Lock.class);
-    when(locks.getLock(any(IndexChangeTask.class))).thenReturn(lock);
+    when(locks.getLock(anyString())).thenReturn(lock);
     Mockito.doCallRealMethod().when(locks).withLock(any(), any(), any());
     when(lock.tryLock(Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS))).thenReturn(false);
 
@@ -270,7 +280,7 @@
   public void shouldNotRetryWhenMaxTriesLowerThanOne() throws Exception {
     IndexEventLocks locks = mock(IndexEventLocks.class);
     Lock lock = mock(Lock.class);
-    when(locks.getLock(any(IndexChangeTask.class))).thenReturn(lock);
+    when(locks.getLock(anyString())).thenReturn(lock);
     Mockito.doCallRealMethod().when(locks).withLock(any(), any(), any());
     when(lock.tryLock(Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS))).thenReturn(false);
 
@@ -295,8 +305,8 @@
     Lock accountChangeLock = mock(Lock.class);
     when(accountChangeLock.tryLock(Mockito.anyLong(), Mockito.eq(TimeUnit.MILLISECONDS)))
         .thenReturn(true);
-    when(locks.getLock(any(IndexChangeTask.class))).thenReturn(indexChangeLock);
-    when(locks.getLock(any(IndexAccountTask.class))).thenReturn(accountChangeLock);
+    when(locks.getLock(eq("change/" + CHANGE_ID))).thenReturn(indexChangeLock);
+    when(locks.getLock(eq("account/" + ACCOUNT_ID))).thenReturn(accountChangeLock);
     Mockito.doCallRealMethod().when(locks).withLock(any(), any(), any());
     setUpIndexEventHandler(currCtx, locks);
 
@@ -531,6 +541,134 @@
     assertThat(task.hashCode()).isNotEqualTo(differentGroupIdTask.hashCode());
   }
 
+  /** Runs one {@link IndexTask} under {@code idLocks} after all parties reached the barrier. */
+  class TestTask<T> implements Runnable {
+    private final IndexTask task;
+    private final CyclicBarrier testBarrier;
+    private final Supplier<T> successFunc;
+    private final VoidFunction failureFunc;
+    private final CompletableFuture<T> future = new CompletableFuture<>(); // completed by run()
+
+    public TestTask(
+        IndexTask task,
+        CyclicBarrier testBarrier,
+        Supplier<T> successFunc,
+        VoidFunction failureFunc) {
+      this.task = task;
+      this.testBarrier = testBarrier;
+      this.successFunc = successFunc;
+      this.failureFunc = failureFunc;
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public void run() {
+      try {
+        testBarrier.await();
+        idLocks
+            .withLock(
+                task,
+                () ->
+                    runLater(
+                        INDEX_WAIT_TIMEOUT_MS * 2,
+                        () -> CompletableFuture.completedFuture(successFunc.get())),
+                failureFunc)
+            .whenComplete(
+                (v, t) -> {
+                  if (t == null) {
+                    future.complete((T) v);
+                  } else {
+                    future.completeExceptionally(t);
+                  }
+                });
+      } catch (Throwable t) {
+        future.completeExceptionally(t); // fail the very future join() may already wait on
+      }
+    }
+
+    public void join() {
+      try {
+        future.join();
+      } catch (Exception ignored) { // a failed task is reported through failureFunc instead
+      }
+    }
+
+    private CompletableFuture<T> runLater(
+        long scheduledTimeMsec, Supplier<CompletableFuture<T>> supplier) {
+      CompletableFuture<T> resFuture = new CompletableFuture<>();
+      testExecutor.schedule(
+          () -> {
+            try {
+              return supplier
+                  .get()
+                  .whenComplete(
+                      (v, t) -> {
+                        if (t == null) {
+                          resFuture.complete(v);
+                        } else {
+                          resFuture.completeExceptionally(t);
+                        }
+                      });
+            } catch (Throwable t) {
+              return resFuture.completeExceptionally(t);
+            }
+          },
+          scheduledTimeMsec,
+          TimeUnit.MILLISECONDS);
+      return resFuture;
+    }
+  }
+
+  @Test
+  public void indexLocksShouldBlockConcurrentIndexChange() throws Exception {
+    IndexChangeTask indexTask1 =
+        indexEventHandler.new IndexChangeTask(PROJECT_NAME, CHANGE_ID, new IndexEvent());
+    IndexChangeTask indexTask2 =
+        indexEventHandler.new IndexChangeTask(PROJECT_NAME, CHANGE_ID, new IndexEvent());
+    testIsolationOfConcurrentIndexTasks(indexTask1, indexTask2);
+  }
+
+  @Test
+  public void indexLocksShouldBlockConcurrentIndexAndDeleteChange() throws Exception {
+    IndexChangeTask indexTask =
+        indexEventHandler.new IndexChangeTask(PROJECT_NAME, CHANGE_ID, new IndexEvent());
+    DeleteChangeTask deleteTask =
+        indexEventHandler.new DeleteChangeTask(CHANGE_ID, new IndexEvent());
+    testIsolationOfConcurrentIndexTasks(indexTask, deleteTask);
+  }
+
+  /** Races two tasks for the same index lock and expects exactly one winner and one timeout. */
+  private void testIsolationOfConcurrentIndexTasks(IndexTask indexTask1, IndexTask indexTask2)
+      throws Exception {
+    AtomicInteger changeIndexedCount = new AtomicInteger();
+    AtomicInteger lockFailedCounts = new AtomicInteger();
+    CyclicBarrier changeThreadsSync = new CyclicBarrier(2);
+
+    TestTask<Integer> task1 =
+        new TestTask<>(
+            indexTask1,
+            changeThreadsSync,
+            changeIndexedCount::incrementAndGet,
+            lockFailedCounts::incrementAndGet);
+    TestTask<Integer> task2 =
+        new TestTask<>(
+            indexTask2,
+            changeThreadsSync,
+            changeIndexedCount::incrementAndGet,
+            lockFailedCounts::incrementAndGet);
+
+    new Thread(task1).start();
+    new Thread(task2).start();
+    task1.join();
+    task2.join();
+
+    // Both assertions need to hold; which of the two tasks wins the lock does not matter:
+    // - only one of the two tasks should succeed
+    // - only one of the two tasks should fail to acquire the lock
+    assertThat(changeIndexedCount.get()).isEqualTo(1);
+    assertThat(lockFailedCounts.get()).isEqualTo(1);
+  }
+
   private class CurrentThreadScheduledExecutorService implements ScheduledExecutorService {
 
     @Override