Fix kafka-events message replay feature

When the offset is reset so that messages are consumed from the
beginning, the subscriber first has to wait for partitions to be
assigned.

Failing to do so will cause the subscriber to consume zero records,
since no partitions have yet been assigned.

Make an explicit poll() call before seekToBeginning() to ensure that
the consumer has joined the group and Kafka has assigned partitions
to it.
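
In isolation, the fix boils down to the pattern below (a minimal
sketch; pollingIntervalMs stands in for the plugin's configured
polling interval):

  // Sketch only: pollingIntervalMs is a stand-in for the configured
  // polling interval. poll() drives the consumer group protocol, so
  // the assignment becomes non-empty only after at least one
  // successful poll.
  while (consumer.assignment().isEmpty() && !closed.get()) {
    consumer.poll(Duration.ofMillis(pollingIntervalMs));
  }
  // Partitions are now assigned, so the seek has an effect.
  consumer.seekToBeginning(consumer.assignment());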

Bug: Issue 14136
Change-Id: Ibc6a66507ebfc9bb6c67df9e576114bed8973e74
diff --git a/src/main/java/com/googlesource/gerrit/plugins/kafka/subscribe/KafkaEventSubscriber.java b/src/main/java/com/googlesource/gerrit/plugins/kafka/subscribe/KafkaEventSubscriber.java
index 40415a4..7ef9d7b 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/kafka/subscribe/KafkaEventSubscriber.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/kafka/subscribe/KafkaEventSubscriber.java
@@ -133,6 +133,12 @@
       try {
         while (!closed.get()) {
           if (resetOffset.getAndSet(false)) {
+            // Make sure there is an assignment for this consumer
+            while (consumer.assignment().isEmpty() && !closed.get()) {
+              logger.atInfo().log(
+                  "Resetting offset: no partitions assigned to the consumer, requesting assignment.");
+              consumer.poll(Duration.ofMillis(configuration.getPollingInterval()));
+            }
             consumer.seekToBeginning(consumer.assignment());
           }
           ConsumerRecords<byte[], byte[]> consumerRecords =
diff --git a/src/test/java/com/googlesource/gerrit/plugins/kafka/EventConsumerIT.java b/src/test/java/com/googlesource/gerrit/plugins/kafka/EventConsumerIT.java
index 06261ed..c8495f1 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/kafka/EventConsumerIT.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/kafka/EventConsumerIT.java
@@ -15,9 +15,13 @@
 package com.googlesource.gerrit.plugins.kafka;
 
 import static com.google.common.truth.Truth.assertThat;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
 import static org.junit.Assert.fail;
 
+import com.gerritforge.gerrit.eventbroker.BrokerApi;
 import com.gerritforge.gerrit.eventbroker.EventGsonProvider;
+import com.gerritforge.gerrit.eventbroker.EventMessage;
+import com.google.common.base.Stopwatch;
 import com.google.common.collect.Iterables;
 import com.google.gerrit.acceptance.GerritConfig;
 import com.google.gerrit.acceptance.LightweightPluginDaemonTest;
@@ -29,11 +33,15 @@
 import com.google.gerrit.extensions.common.ChangeMessageInfo;
 import com.google.gerrit.server.events.CommentAddedEvent;
 import com.google.gerrit.server.events.Event;
+import com.google.gerrit.server.events.ProjectCreatedEvent;
 import com.google.gson.Gson;
 import com.googlesource.gerrit.plugins.kafka.config.KafkaProperties;
+import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.UUID;
+import java.util.function.Supplier;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
@@ -121,7 +129,62 @@
     assertThat(commentAddedEvent.comment).isEqualTo(expectedMessage);
   }
 
+  @Test
+  @UseLocalDisk
+  @GerritConfig(name = "plugin.kafka-events.groupId", value = "test-consumer-group")
+  @GerritConfig(
+      name = "plugin.kafka-events.keyDeserializer",
+      value = "org.apache.kafka.common.serialization.StringDeserializer")
+  @GerritConfig(
+      name = "plugin.kafka-events.valueDeserializer",
+      value = "org.apache.kafka.common.serialization.StringDeserializer")
+  @GerritConfig(name = "plugin.kafka-events.pollingIntervalMs", value = "500")
+  public void shouldReplayAllEvents() throws InterruptedException {
+    String topic = "a_topic";
+    EventMessage eventMessage =
+        new EventMessage(
+            new EventMessage.Header(UUID.randomUUID(), UUID.randomUUID()),
+            new ProjectCreatedEvent());
+
+    Duration waitForPollTimeout = Duration.ofMillis(1000);
+
+    List<EventMessage> receivedEvents = Collections.synchronizedList(new ArrayList<>());
+
+    BrokerApi kafkaBrokerApi = kafkaBrokerApi();
+    kafkaBrokerApi.send(topic, eventMessage);
+
+    kafkaBrokerApi.receiveAsync(topic, receivedEvents::add);
+
+    waitUntil(() -> receivedEvents.size() == 1, waitForPollTimeout);
+
+    assertThat(receivedEvents.get(0).getHeader().eventId)
+        .isEqualTo(eventMessage.getHeader().eventId);
+
+    kafkaBrokerApi.replayAllEvents(topic);
+    waitUntil(() -> receivedEvents.size() == 2, waitForPollTimeout);
+
+    assertThat(receivedEvents.get(1).getHeader().eventId)
+        .isEqualTo(eventMessage.getHeader().eventId);
+  }
+
+  private BrokerApi kafkaBrokerApi() {
+    return plugin.getSysInjector().getInstance(BrokerApi.class);
+  }
+
   private KafkaProperties kafkaProperties() {
     return plugin.getSysInjector().getInstance(KafkaProperties.class);
   }
+
+  // XXX: Remove this method when merging into stable-3.3, since waitUntil is
+  // available in Gerrit core.
+  public static void waitUntil(Supplier<Boolean> waitCondition, Duration timeout)
+      throws InterruptedException {
+    Stopwatch stopwatch = Stopwatch.createStarted();
+    while (!waitCondition.get()) {
+      if (stopwatch.elapsed().compareTo(timeout) > 0) {
+        throw new InterruptedException();
+      }
+      MILLISECONDS.sleep(50);
+    }
+  }
 }