Fix heartbeat task leak after lock reclaim

Reclaiming an existing lock (re)started the heartbeat task but it missed
to set the heartbeatTask reference. After a reclaimed lock was closed
the heartbeat task wasn't stopped because the heartbeatTask reference
was null. The heartbeat thread continued to run on the scheduled rate
indefinitely.

Make sure to set the heartbeatTask reference after a lock is reclaimed.
Extend the tests to verify the heartbeat task is running/stopped.

Change-Id: Iecb7b77f54ad416a3a4de27bf26942270cfc3c1e
diff --git a/src/main/java/com/googlesource/gerrit/plugins/spannerrefdb/Lock.java b/src/main/java/com/googlesource/gerrit/plugins/spannerrefdb/Lock.java
index d65e216..0dde4ed 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/spannerrefdb/Lock.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/spannerrefdb/Lock.java
@@ -56,7 +56,7 @@
   }
 
   private static final Long SECONDS_FOR_STALE_HEARTBEAT = 30L;
-  private static final Duration HEARTBEAT_INTERVAL = Duration.ofSeconds(2);
+  public static final Duration HEARTBEAT_INTERVAL = Duration.ofSeconds(2);
   private static final String RECLAIM_LOCK_PREFIX = "RECLAIM";
   private final DatabaseClient dbClient;
   private final String gerritInstanceId;
@@ -168,8 +168,9 @@
 
     if (success) {
       token = transactionRunner.getCommitTimestamp();
-      heartbeatExecutor.scheduleAtFixedRate(
-          this::heartbeat, 0, HEARTBEAT_INTERVAL.toSeconds(), TimeUnit.SECONDS);
+      heartbeatTask =
+          heartbeatExecutor.scheduleAtFixedRate(
+              this::heartbeat, 0, HEARTBEAT_INTERVAL.toSeconds(), TimeUnit.SECONDS);
       logger.atFine().log("Reclaimed lock for %s %s.", projectName, refName);
     }
     return success;
diff --git a/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/EmulatedSpannerRefDb.java b/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/EmulatedSpannerRefDb.java
index 5ecfcc7..c5aa09b 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/EmulatedSpannerRefDb.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/EmulatedSpannerRefDb.java
@@ -27,8 +27,7 @@
 import com.google.cloud.spanner.SpannerOptions;
 import java.util.Collections;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 import org.testcontainers.containers.SpannerEmulatorContainer;
 import org.testcontainers.utility.DockerImageName;
@@ -47,7 +46,7 @@
   private final Database spannerDatabase;
   private final DatabaseId dbId;
   private final DatabaseClient databaseClient;
-  private final ScheduledExecutorService heartbeatExecutor;
+  private final ScheduledThreadPoolExecutor heartbeatExecutor;
   private final SpannerRefDatabase spannerRefDb;
 
   public EmulatedSpannerRefDb() throws Exception {
@@ -69,7 +68,7 @@
         spannerInstance.createDatabase(SPANNER_DATABASE_ID, Collections.emptyList()).get();
     createSchema();
     databaseClient = createDatabaseClient();
-    heartbeatExecutor = Executors.newScheduledThreadPool(2);
+    heartbeatExecutor = new ScheduledThreadPoolExecutor(2);
     Lock.Factory lockFactory =
         new Lock.Factory() {
           @Override
@@ -96,6 +95,11 @@
   }
 
   @Override
+  ScheduledThreadPoolExecutor heartbeatExecutor() {
+    return heartbeatExecutor;
+  }
+
+  @Override
   public void cleanup() {
     heartbeatExecutor.shutdownNow();
     try {
diff --git a/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/LockTest.java b/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/LockTest.java
index e8beaa8..a88052b 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/LockTest.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/LockTest.java
@@ -23,6 +23,8 @@
 import com.google.cloud.spanner.Key;
 import com.google.cloud.spanner.Mutation;
 import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.TransactionRunner;
+import com.google.cloud.spanner.Value;
 import com.google.gerrit.entities.Project;
 import java.util.Arrays;
 import java.util.Collections;
@@ -30,6 +32,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -38,6 +41,7 @@
   private SpannerTestSystem testSystem;
   private SpannerRefDatabase refDb;
   private DatabaseClient dbClient;
+  private ScheduledThreadPoolExecutor heartbeatExecutor;
 
   private String TOKEN = "token";
   private String HEARTBEAT = "heartbeat";
@@ -47,6 +51,7 @@
     testSystem = SpannerTestSystem.create();
     refDb = testSystem.database();
     dbClient = testSystem.dbClient();
+    heartbeatExecutor = testSystem.heartbeatExecutor();
   }
 
   @After
@@ -58,18 +63,11 @@
   public void lockUnlockedRef_Success() throws Exception {
     try (AutoCloseable refLock = refDb.lockRef(PROJECT_NAME_KEY, REF_NAME)) {
       assertThat(getLockTimestamp(PROJECT_NAME_KEY, REF_NAME, TOKEN)).isNotNull();
-
-      Timestamp heartbeat = getLockTimestamp(PROJECT_NAME_KEY, REF_NAME, HEARTBEAT);
-      try {
-        Thread.sleep(4_000);
-      } catch (InterruptedException ie) {
-        Thread.currentThread().interrupt();
-      }
-      assertThat(getLockTimestamp(PROJECT_NAME_KEY, REF_NAME, HEARTBEAT).compareTo(heartbeat))
-          .isGreaterThan(0);
+      assertHeartbeatTaskIsRunning();
+      assertHeatbeatUpdatesTimestamp(PROJECT_NAME_KEY, REF_NAME);
     }
-
     assertThat(getLockRow(PROJECT_NAME_KEY, REF_NAME)).isNull();
+    assertHeartbeatTaskIsStopped();
   }
 
   @Test
@@ -82,6 +80,7 @@
           .hasMessageThat()
           .contains(
               String.format("Unable to lock ref %s on project %s", REF_NAME, PROJECT_NAME_KEY));
+      assertHeartbeatTaskIsRunning();
     }
   }
 
@@ -95,6 +94,7 @@
     try (AutoCloseable refLock = refDb.lockRef(PROJECT_NAME_KEY, REF_NAME)) {
       Timestamp newLockToken = getLockTimestamp(PROJECT_NAME_KEY, REF_NAME, TOKEN);
       assertThat(lockToken).isNotEqualTo(newLockToken);
+      assertHeartbeatTaskIsRunning();
     }
   }
 
@@ -106,7 +106,9 @@
       Timestamp newLockToken = getLockTimestamp(PROJECT_NAME_KEY, REF_NAME, TOKEN);
       assertThat(newLockToken).isNotNull();
       assertThat(staleTimestamp).isLessThan(newLockToken);
+      assertHeartbeatTaskIsRunning();
     }
+    assertHeartbeatTaskIsStopped();
   }
 
   @Test
@@ -149,6 +151,22 @@
     pool.shutdown();
   }
 
+  @Test
+  public void heartbeatStopsWhenLockNotFound() throws Exception {
+    try (AutoCloseable refLock = refDb.lockRef(PROJECT_NAME_KEY, REF_NAME)) {
+      deleteLockRow(PROJECT_NAME_KEY, REF_NAME);
+      assertHeartbeatTaskIsStopped();
+    }
+  }
+
+  @Test
+  public void heartbeatStopsWhenTokenDoesNotMatch() throws Exception {
+    try (AutoCloseable refLock = refDb.lockRef(PROJECT_NAME_KEY, REF_NAME)) {
+      updateLockToken(PROJECT_NAME_KEY, REF_NAME);
+      assertHeartbeatTaskIsStopped();
+    }
+  }
+
   private boolean tryLockAwaitBarrier(CyclicBarrier barrier) throws Exception {
     try (AutoCloseable refLock = refDb.lockRef(PROJECT_NAME_KEY, REF_NAME)) {
       barrier.await();
@@ -186,4 +204,55 @@
             .build();
     dbClient.write(Collections.singletonList(insertLock));
   }
+
+  private void deleteLockRow(Project.NameKey project, String refName) {
+    Mutation deleteLock = Mutation.delete("locks", Key.of(project.get(), refName));
+    dbClient.write(Collections.singletonList(deleteLock));
+  }
+
+  private void updateLockToken(Project.NameKey project, String refName) {
+    Timestamp currentToken = getLockTimestamp(project, refName, TOKEN);
+    Timestamp newToken;
+    do {
+      Mutation updateToken =
+          Mutation.newUpdateBuilder("locks")
+              .set("project")
+              .to(project.get())
+              .set("ref")
+              .to(refName)
+              .set("token")
+              .to(Value.COMMIT_TIMESTAMP)
+              .build();
+      TransactionRunner transactionRunner = dbClient.readWriteTransaction();
+      transactionRunner.run(
+          transaction -> {
+            transaction.buffer(updateToken);
+            return true;
+          });
+      newToken = transactionRunner.getCommitTimestamp();
+    } while (newToken.equals(currentToken));
+  }
+
+  private void assertHeatbeatUpdatesTimestamp(Project.NameKey project, String refName)
+      throws Exception {
+    Timestamp a = getLockTimestamp(project, refName, HEARTBEAT);
+    Thread.sleep(2 * Lock.HEARTBEAT_INTERVAL.toMillis());
+    Timestamp b = getLockTimestamp(project, refName, HEARTBEAT);
+    assertThat(b).isGreaterThan(a);
+  }
+
+  private void assertHeartbeatTaskIsRunning() throws Exception {
+    long a = heartbeatExecutor.getCompletedTaskCount();
+    Thread.sleep(2 * Lock.HEARTBEAT_INTERVAL.toMillis());
+    long b = heartbeatExecutor.getCompletedTaskCount();
+    assertThat(b).isGreaterThan(a);
+  }
+
+  private void assertHeartbeatTaskIsStopped() throws Exception {
+    Thread.sleep(2 * Lock.HEARTBEAT_INTERVAL.toMillis());
+    long a = heartbeatExecutor.getCompletedTaskCount();
+    Thread.sleep(2 * Lock.HEARTBEAT_INTERVAL.toMillis());
+    long b = heartbeatExecutor.getCompletedTaskCount();
+    assertThat(b).isEqualTo(a);
+  }
 }
diff --git a/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/RealSpannerRefDb.java b/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/RealSpannerRefDb.java
index 499ac57..1712196 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/RealSpannerRefDb.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/RealSpannerRefDb.java
@@ -25,8 +25,7 @@
 import java.io.FileInputStream;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 
 public class RealSpannerRefDb extends SpannerTestSystem {
@@ -49,6 +48,7 @@
   private boolean dbInitialized;
   private SpannerRefDatabase refdb;
   private DatabaseClient dbClient;
+  private ScheduledThreadPoolExecutor heartbeatExecutor;
 
   private RealSpannerRefDb(String keyPath, String instance) {
     this.keyPath = keyPath;
@@ -77,6 +77,11 @@
   }
 
   @Override
+  ScheduledThreadPoolExecutor heartbeatExecutor() {
+    return heartbeatExecutor;
+  }
+
+  @Override
   void cleanup() {
     // do nothing
   }
@@ -97,7 +102,7 @@
     DatabaseSchemaCreator databaseSchemaCreator = new DatabaseSchemaCreator(dbAdminClient, dbId);
     databaseSchemaCreator.start();
     dbClient = options.getService().getDatabaseClient(dbId);
-    ScheduledExecutorService heartbeatExecutor = Executors.newScheduledThreadPool(2);
+    heartbeatExecutor = new ScheduledThreadPoolExecutor(2);
     Lock.Factory lockFactory =
         new Lock.Factory() {
           @Override
diff --git a/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/SpannerTestSystem.java b/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/SpannerTestSystem.java
index 3eb6d0c..0cadd3d 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/SpannerTestSystem.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/spannerrefdb/SpannerTestSystem.java
@@ -15,6 +15,7 @@
 package com.googlesource.gerrit.plugins.spannerrefdb;
 
 import com.google.cloud.spanner.DatabaseClient;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
 
 public abstract class SpannerTestSystem {
 
@@ -37,5 +38,7 @@
 
   abstract DatabaseClient dbClient();
 
+  abstract ScheduledThreadPoolExecutor heartbeatExecutor();
+
   abstract void cleanup();
 }