Update checkpoint on regular bases

Currently checkpointing happens only at shutdown and at the end
of the stream. Allow more frequent checkpointing to avoid replaying
all the stream in case the kinesis collector is not stopped gracefully.

Bug: Issue 15321
Change-Id: I83096084a32773e585f5b7c264bc9ad286325dec
diff --git a/src/main/java/com/googlesource/gerrit/plugins/kinesis/Configuration.java b/src/main/java/com/googlesource/gerrit/plugins/kinesis/Configuration.java
index ef7e698..9c04c09 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/kinesis/Configuration.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/kinesis/Configuration.java
@@ -38,6 +38,7 @@
   private static final Long DEFAULT_PUBLISH_SINGLE_REQUEST_TIMEOUT_MS = 6000L;
   private static final Long DEFAULT_PUBLISH_TIMEOUT_MS = 6000L;
   private static final Long DEFAULT_SHUTDOWN_TIMEOUT_MS = 20000L;
+  private static final Long DEFAULT_CHECKPOINT_INTERVAL_MS = 5 * 60000L; // 5 min
   private static final Level DEFAULT_AWS_LIB_LOG_LEVEL = Level.WARN;
   private static final Boolean DEFAULT_SEND_ASYNC = true;
 
@@ -52,6 +53,7 @@
   private final Long publishTimeoutMs;
   private final Long publishSingleRequestTimeoutMs;
   private final Long shutdownTimeoutMs;
+  private final Long checkpointIntervalMs;
   private final Level awsLibLogLevel;
   private final Boolean sendAsync;
 
@@ -98,6 +100,11 @@
             .map(Long::parseLong)
             .orElse(DEFAULT_SHUTDOWN_TIMEOUT_MS);
 
+    this.checkpointIntervalMs =
+        Optional.ofNullable(getStringParam(pluginConfig, "checkpointIntervalMs", null))
+            .map(Long::parseLong)
+            .orElse(DEFAULT_CHECKPOINT_INTERVAL_MS);
+
     this.awsLibLogLevel =
         Optional.ofNullable(getStringParam(pluginConfig, "awsLibLogLevel", null))
             .map(l -> Level.toLevel(l, DEFAULT_AWS_LIB_LOG_LEVEL))
@@ -172,6 +179,10 @@
     return shutdownTimeoutMs;
   }
 
+  public Long getCheckpointIntervalMs() {
+    return checkpointIntervalMs;
+  }
+
   public Level getAwsLibLogLevel() {
     return awsLibLogLevel;
   }
diff --git a/src/main/java/com/googlesource/gerrit/plugins/kinesis/KinesisRecordProcessor.java b/src/main/java/com/googlesource/gerrit/plugins/kinesis/KinesisRecordProcessor.java
index 2786b82..29c7f60 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/kinesis/KinesisRecordProcessor.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/kinesis/KinesisRecordProcessor.java
@@ -24,11 +24,13 @@
 import java.util.function.Consumer;
 import software.amazon.kinesis.exceptions.InvalidStateException;
 import software.amazon.kinesis.exceptions.ShutdownException;
+import software.amazon.kinesis.exceptions.ThrottlingException;
 import software.amazon.kinesis.lifecycle.events.InitializationInput;
 import software.amazon.kinesis.lifecycle.events.LeaseLostInput;
 import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
 import software.amazon.kinesis.lifecycle.events.ShardEndedInput;
 import software.amazon.kinesis.lifecycle.events.ShutdownRequestedInput;
+import software.amazon.kinesis.processor.RecordProcessorCheckpointer;
 import software.amazon.kinesis.processor.ShardRecordProcessor;
 
 class KinesisRecordProcessor implements ShardRecordProcessor {
@@ -40,21 +42,29 @@
   private final Consumer<Event> recordProcessor;
   private final OneOffRequestContext oneOffCtx;
   private final EventDeserializer eventDeserializer;
+  private final Configuration configuration;
+
+  private long nextCheckpointTimeInMillis;
+  private String kinesisShardId;
 
   @Inject
   KinesisRecordProcessor(
       @Assisted Consumer<Event> recordProcessor,
       OneOffRequestContext oneOffCtx,
-      EventDeserializer eventDeserializer) {
+      EventDeserializer eventDeserializer,
+      Configuration configuration) {
     this.recordProcessor = recordProcessor;
     this.oneOffCtx = oneOffCtx;
     this.eventDeserializer = eventDeserializer;
+    this.configuration = configuration;
   }
 
   @Override
   public void initialize(InitializationInput initializationInput) {
+    kinesisShardId = initializationInput.shardId();
     logger.atInfo().log(
         "Initializing @ Sequence: %s", initializationInput.extendedSequenceNumber());
+    setNextCheckpointTime();
   }
 
   @Override
@@ -79,11 +89,21 @@
                   logger.atSevere().withCause(e).log("Could not process event '%s'", jsonMessage);
                 }
               });
+
+      if (System.currentTimeMillis() >= nextCheckpointTimeInMillis) {
+        checkpoint(processRecordsInput.checkpointer());
+        setNextCheckpointTime();
+      }
     } catch (Throwable t) {
       logger.atSevere().withCause(t).log("Caught throwable while processing records. Aborting.");
     }
   }
 
+  private void setNextCheckpointTime() {
+    nextCheckpointTimeInMillis =
+        System.currentTimeMillis() + configuration.getCheckpointIntervalMs();
+  }
+
   @Override
   public void leaseLost(LeaseLostInput leaseLostInput) {
     logger.atInfo().log("Lost lease, so terminating.");
@@ -91,22 +111,28 @@
 
   @Override
   public void shardEnded(ShardEndedInput shardEndedInput) {
-    try {
-      logger.atInfo().log("Reached shard end checkpointing.");
-      shardEndedInput.checkpointer().checkpoint();
-    } catch (ShutdownException | InvalidStateException e) {
-      logger.atSevere().withCause(e).log("Exception while checkpointing at shard end. Giving up.");
-    }
+    logger.atInfo().log("Reached shard end checkpointing.");
+    checkpoint(shardEndedInput.checkpointer());
   }
 
   @Override
   public void shutdownRequested(ShutdownRequestedInput shutdownRequestedInput) {
+    logger.atInfo().log("Scheduler is shutting down, checkpointing.");
+    checkpoint(shutdownRequestedInput.checkpointer());
+  }
+
+  private void checkpoint(RecordProcessorCheckpointer checkpointer) {
+    logger.atInfo().log("Checkpointing shard: " + kinesisShardId);
     try {
-      logger.atInfo().log("Scheduler is shutting down, checkpointing.");
-      shutdownRequestedInput.checkpointer().checkpoint();
-    } catch (ShutdownException | InvalidStateException e) {
+      checkpointer.checkpoint();
+    } catch (ShutdownException se) {
+      logger.atInfo().log("Caught shutdown exception, skipping checkpoint.", se);
+    } catch (ThrottlingException e) {
+      logger.atSevere().withCause(e).log("Caught throttling exception, skipping checkpoint.", e);
+    } catch (InvalidStateException e) {
       logger.atSevere().withCause(e).log(
-          "Exception while checkpointing at requested shutdown. Giving up.");
+          "Cannot save checkpoint to the DynamoDB table used by the Amazon Kinesis Client Library.",
+          e);
     }
   }
 }
diff --git a/src/main/resources/Documentation/Config.md b/src/main/resources/Documentation/Config.md
index 03e3650..004ce0b 100644
--- a/src/main/resources/Documentation/Config.md
+++ b/src/main/resources/Documentation/Config.md
@@ -81,6 +81,10 @@
   kinesis consumers.
   Default: 20000
 
+`plugin.events-aws-kinesis.checkpointIntervalMs`
+: Optional. The interval between checkpoints (milliseconds).
+Default: 300000 (5 minutes)
+
 `plugin.events-aws-kinesis.awsLibLogLevel`
 : Optional. Which level AWS libraries should log at.
   This plugin delegates most complex tasks associated to the production and
diff --git a/src/test/java/com/googlesource/gerrit/plugins/kinesis/KinesisRecordProcessorTest.java b/src/test/java/com/googlesource/gerrit/plugins/kinesis/KinesisRecordProcessorTest.java
index d488ab8..24027fc 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/kinesis/KinesisRecordProcessorTest.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/kinesis/KinesisRecordProcessorTest.java
@@ -38,11 +38,14 @@
 import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
 import org.mockito.Mock;
+import org.mockito.Mockito;
 import org.mockito.junit.MockitoJUnitRunner;
 import software.amazon.awssdk.core.SdkBytes;
 import software.amazon.awssdk.services.kinesis.model.Record;
+import software.amazon.kinesis.lifecycle.events.InitializationInput;
 import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
 import software.amazon.kinesis.retrieval.KinesisClientRecord;
+import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber;
 
 @RunWith(MockitoJUnitRunner.class)
 public class KinesisRecordProcessorTest {
@@ -54,11 +57,41 @@
   @Captor ArgumentCaptor<Event> eventMessageCaptor;
   @Mock OneOffRequestContext oneOffCtx;
   @Mock ManualRequestContext requestContext;
+  @Mock Configuration configuration;
 
   @Before
   public void setup() {
     when(oneOffCtx.open()).thenReturn(requestContext);
-    objectUnderTest = new KinesisRecordProcessor(succeedingConsumer, oneOffCtx, eventDeserializer);
+    objectUnderTest =
+        new KinesisRecordProcessor(succeedingConsumer, oneOffCtx, eventDeserializer, configuration);
+  }
+
+  @Test
+  public void shouldNotCheckpointBeforeIntervalIsExpired() {
+    when(configuration.getCheckpointIntervalMs()).thenReturn(10000L);
+    Event event = new ProjectCreatedEvent();
+
+    initializeRecordProcessor();
+
+    ProcessRecordsInput kinesisInput = sampleMessage(gson.toJson(event));
+    ProcessRecordsInput processRecordsInputSpy = Mockito.spy(kinesisInput);
+    objectUnderTest.processRecords(processRecordsInputSpy);
+
+    verify(processRecordsInputSpy, never()).checkpointer();
+  }
+
+  @Test
+  public void shouldCheckpointAfterIntervalIsExpired() throws InterruptedException {
+    when(configuration.getCheckpointIntervalMs()).thenReturn(0L);
+    Event event = new ProjectCreatedEvent();
+
+    initializeRecordProcessor();
+
+    ProcessRecordsInput kinesisInput = sampleMessage(gson.toJson(event));
+    ProcessRecordsInput processRecordsInputSpy = Mockito.spy(kinesisInput);
+    objectUnderTest.processRecords(processRecordsInputSpy);
+
+    verify(processRecordsInputSpy).checkpointer();
   }
 
   @Test
@@ -153,4 +186,13 @@
             .build();
     return kinesisInput;
   }
+
+  private void initializeRecordProcessor() {
+    InitializationInput initializationInput =
+        InitializationInput.builder()
+            .shardId("shard-0000")
+            .extendedSequenceNumber(new ExtendedSequenceNumber("0000"))
+            .build();
+    objectUnderTest.initialize(initializationInput);
+  }
 }