Merge "Always terminate the ServletRequest associated task" into stable-3.2
diff --git a/java/com/google/gerrit/pgm/http/jetty/ProjectQoSFilter.java b/java/com/google/gerrit/pgm/http/jetty/ProjectQoSFilter.java
index 4f9d7e7..1cca789 100644
--- a/java/com/google/gerrit/pgm/http/jetty/ProjectQoSFilter.java
+++ b/java/com/google/gerrit/pgm/http/jetty/ProjectQoSFilter.java
@@ -18,6 +18,7 @@
 import static java.util.concurrent.TimeUnit.MINUTES;
 import static javax.servlet.http.HttpServletResponse.SC_SERVICE_UNAVAILABLE;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.gerrit.server.CurrentUser;
 import com.google.gerrit.server.account.AccountLimits;
 import com.google.gerrit.server.config.GerritServerConfig;
@@ -170,7 +171,7 @@
         request.setAttribute(TASK, task);
 
         Future<?> f = getExecutor().submit(task);
-        asyncContext.addListener(new Listener(f));
+        asyncContext.addListener(new Listener(f, task));
         break;
       case CANCELED:
         rsp.sendError(SC_SERVICE_UNAVAILABLE);
@@ -181,7 +182,6 @@
           task.begin(Thread.currentThread());
           chain.doFilter(req, rsp);
         } finally {
-          task.end();
           Thread.interrupted();
         }
         break;
@@ -211,29 +211,43 @@
   @Override
   public void destroy() {}
 
-  private static final class Listener implements AsyncListener {
+  /**
+   * Ends the associated {@link TaskThunk} on every terminal async event (completion, timeout or
+   * error), so the task is terminated however the request finishes.
+   */
+  @VisibleForTesting
+  static final class Listener implements AsyncListener {
     final Future<?> future;
+    final TaskThunk task;
 
-    Listener(Future<?> future) {
+    Listener(Future<?> future, TaskThunk task) {
       this.future = future;
+      this.task = task;
     }
 
     @Override
-    public void onComplete(AsyncEvent event) throws IOException {}
+    public void onComplete(AsyncEvent event) throws IOException {
+      task.end();
+    }
 
     @Override
     public void onTimeout(AsyncEvent event) throws IOException {
+      // End the task first, then cancel its future (interrupting it if it is running).
+      task.end();
       future.cancel(true);
     }
 
     @Override
-    public void onError(AsyncEvent event) throws IOException {}
+    public void onError(AsyncEvent event) throws IOException {
+      task.end();
+    }
 
     @Override
     public void onStartAsync(AsyncEvent event) throws IOException {}
   }
 
-  private final class TaskThunk implements CancelableRunnable {
+  @VisibleForTesting
+  final class TaskThunk implements CancelableRunnable {
     private final AsyncContext asyncContext;
     private final String name;
     private final Object lock = new Object();
@@ -292,6 +306,16 @@
       }
     }
 
+    /** Returns the task's {@code done} flag, which {@link #end()} is expected to set. */
+    @VisibleForTesting
+    boolean isDone() {
+      // Read under lock so a flag set by another thread (e.g. a Listener callback) is visible.
+      // NOTE(review): assumes end() writes done while holding lock; confirm.
+      synchronized (lock) {
+        return done;
+      }
+    }
+
     @Override
     public String toString() {
       return name;
diff --git a/javatests/com/google/gerrit/pgm/BUILD b/javatests/com/google/gerrit/pgm/BUILD
index 5a3a824..0fe4fad 100644
--- a/javatests/com/google/gerrit/pgm/BUILD
+++ b/javatests/com/google/gerrit/pgm/BUILD
@@ -5,6 +5,7 @@
     name = "pgm_tests",
     srcs = glob(["**/*.java"]),
     deps = [
+        "//java/com/google/gerrit/pgm/http/jetty",
         "//java/com/google/gerrit/pgm/init/api",
         "//java/com/google/gerrit/server",
         "//java/com/google/gerrit/server/securestore/testing",
@@ -15,6 +16,8 @@
         "//lib/guice",
         "//lib/mockito",
         "//lib/truth",
+        "@jetty-server//jar",
+        "@servlet-api//jar",
     ],
 )
 
diff --git a/javatests/com/google/gerrit/pgm/http/jetty/ProjectQoSFilterTest.java b/javatests/com/google/gerrit/pgm/http/jetty/ProjectQoSFilterTest.java
new file mode 100644
index 0000000..b969d68
--- /dev/null
+++ b/javatests/com/google/gerrit/pgm/http/jetty/ProjectQoSFilterTest.java
@@ -0,0 +1,173 @@
+// Copyright (C) 2021 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.gerrit.pgm.http.jetty;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
+
+import com.google.gerrit.server.CurrentUser;
+import com.google.gerrit.server.account.AccountLimits;
+import com.google.gerrit.server.account.GroupMembership;
+import com.google.gerrit.server.git.QueueProvider;
+import com.google.inject.Provider;
+import java.io.IOException;
+import java.util.Optional;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import javax.servlet.AsyncContext;
+import javax.servlet.AsyncEvent;
+import javax.servlet.ServletContext;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletRequestWrapper;
+import org.eclipse.jetty.server.Request;
+import org.eclipse.jgit.lib.Config;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+/**
+ * Tests that {@link ProjectQoSFilter.Listener} ends its task on every terminal async event,
+ * even when the callback runs on a thread other than the one passed to {@code begin}.
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class ProjectQoSFilterTest {
+  private static final String HTTP_MAX_WAIT = "1 minute";
+
+  /** Upper bound on how long a listener callback may take before the test fails. */
+  private static final long CALLBACK_TIMEOUT_SECONDS = 10;
+
+  @Mock AsyncEvent asyncEvent;
+  @Mock AsyncContext asyncContext;
+
+  @Mock AccountLimits.Factory limitsFactory;
+  @Mock Provider<CurrentUser> userProvider;
+  @Mock QueueProvider queue;
+  @Mock ServletContext context;
+
+  /** One listener callback (onComplete, onTimeout or onError) applied to an event. */
+  @FunctionalInterface
+  private interface ListenerCallback {
+    void invoke(ProjectQoSFilter.Listener listener, AsyncEvent event) throws IOException;
+  }
+
+  @Test
+  public void shouldCallTaskEndOnListenerCompleteFromDifferentThread() throws Exception {
+    assertTaskEndedByCallbackOnOtherThread(ProjectQoSFilter.Listener::onComplete);
+  }
+
+  @Test
+  public void shouldCallTaskEndOnListenerTimeoutFromDifferentThread() throws Exception {
+    assertTaskEndedByCallbackOnOtherThread(ProjectQoSFilter.Listener::onTimeout);
+  }
+
+  @Test
+  public void shouldCallTaskEndOnListenerErrorFromDifferentThread() throws Exception {
+    assertTaskEndedByCallbackOnOtherThread(ProjectQoSFilter.Listener::onError);
+  }
+
+  /**
+   * Submits a task to its own executor, invokes {@code callback} on a separate thread (not the
+   * test thread passed to {@code begin}) and asserts the task is done afterwards. An exception
+   * thrown by the callback, or a callback that does not finish in time, fails the test.
+   */
+  private void assertTaskEndedByCallbackOnOtherThread(ListenerCallback callback)
+      throws Exception {
+    ProjectQoSFilter.TaskThunk taskThunk = getTaskThunk();
+    ScheduledThreadPoolExecutor taskExecutor = new ScheduledThreadPoolExecutor(1);
+    ExecutorService callbackExecutor = Executors.newSingleThreadExecutor();
+    try {
+      Future<?> f = taskExecutor.submit(taskThunk);
+      taskThunk.begin(Thread.currentThread());
+
+      ProjectQoSFilter.Listener listener = new ProjectQoSFilter.Listener(f, taskThunk);
+      // Thread#run() would not switch threads; a dedicated executor does, and Future#get
+      // rethrows any failure raised by the callback instead of silently dropping it.
+      callbackExecutor
+          .submit(
+              () -> {
+                callback.invoke(listener, asyncEvent);
+                return null;
+              })
+          .get(CALLBACK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
+
+      assertThat(taskThunk.isDone()).isTrue();
+    } finally {
+      callbackExecutor.shutdownNow();
+      taskExecutor.shutdownNow();
+    }
+  }
+
+  private ProjectQoSFilter.TaskThunk getTaskThunk() {
+    HttpServletRequest servletRequest = new FakeHttpServletRequest();
+    Config config = new Config();
+    config.setString("httpd", null, "maxwait", HTTP_MAX_WAIT);
+
+    when(userProvider.get()).thenReturn(new FakeUser("testUser"));
+    when(asyncContext.getRequest()).thenReturn(servletRequest);
+
+    ProjectQoSFilter projectQoSFilter =
+        new ProjectQoSFilter(limitsFactory, userProvider, queue, context, config);
+    return projectQoSFilter.new TaskThunk(asyncContext, servletRequest);
+  }
+
+  private static class FakeUser extends CurrentUser {
+    private final String username;
+
+    FakeUser(String name) {
+      username = name;
+    }
+
+    @Override
+    public GroupMembership getEffectiveGroups() {
+      return null;
+    }
+
+    @Override
+    public Object getCacheKey() {
+      return new Object();
+    }
+
+    @Override
+    public Optional<String> getUserName() {
+      return Optional.ofNullable(username);
+    }
+  }
+
+  private static final class FakeHttpServletRequest extends HttpServletRequestWrapper {
+
+    FakeHttpServletRequest() {
+      super(new Request(null, null));
+    }
+
+    @Override
+    public String getRemoteHost() {
+      return "1.2.3.4";
+    }
+
+    @Override
+    public String getRemoteUser() {
+      return "bob";
+    }
+
+    @Override
+    public String getServletPath() {
+      return "http://testurl/a/plugins_replication/info/refs?service=git-upload-pack";
+    }
+  }
+}