Refresh GitHub groups upon Gerrit successful login

GitHub groups are cached and persisted on disk, for preventing
GitHub outages to impact Gerrit Code Review. The persistent cache may
however cause security and inconsistency issues with the original
copy of the groups and memebership on GitHub.

Make sure that groups are refreshed every time a user is logging
on Gerrit Code Review UI successfully.

NOTE: Git/HTTP is excluded because it could lead to bursts of
GitHub APIs that could exhaust the user's allowance.

Because of the lack of login notification from Gerrit, the solution
needs to be based on the presence of the Set-Cookie header in
the HTTP response, hence the group refresh needs to use a
ServletFilter that is invoked for every request heading to Gerrit.

Bug: Issue 308979847
Change-Id: Ie5d012fb7c34e3edbb5576b53de32fd5b140ec17
diff --git a/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/GuiceHttpModule.java b/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/GuiceHttpModule.java
index 661b0bf..ef24191 100644
--- a/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/GuiceHttpModule.java
+++ b/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/GuiceHttpModule.java
@@ -18,10 +18,13 @@
 import com.google.gerrit.extensions.registration.DynamicSet;
 import com.google.gerrit.extensions.webui.JavaScriptPlugin;
 import com.google.gerrit.extensions.webui.WebUiPlugin;
+import com.google.gerrit.httpd.AllRequestFilter;
+import com.google.inject.Scopes;
 import com.google.inject.TypeLiteral;
 import com.google.inject.assistedinject.FactoryModuleBuilder;
 import com.google.inject.name.Names;
 import com.google.inject.servlet.ServletModule;
+import com.googlesource.gerrit.plugins.github.filters.GitHubGroupCacheRefreshFilter;
 import com.googlesource.gerrit.plugins.github.filters.GitHubOAuthFilter;
 import com.googlesource.gerrit.plugins.github.git.CreateProjectStep;
 import com.googlesource.gerrit.plugins.github.git.GitCloneStep;
@@ -103,5 +106,9 @@
 
     serve("/static/*").with(VelocityViewServlet.class);
     filterRegex("(?!/webhook).*").through(GitHubOAuthFilter.class);
+
+    DynamicSet.bind(binder(), AllRequestFilter.class)
+        .to(GitHubGroupCacheRefreshFilter.class)
+        .in(Scopes.SINGLETON);
   }
 }
diff --git a/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/filters/GitHubGroupCacheRefreshFilter.java b/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/filters/GitHubGroupCacheRefreshFilter.java
new file mode 100644
index 0000000..eb34962
--- /dev/null
+++ b/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/filters/GitHubGroupCacheRefreshFilter.java
@@ -0,0 +1,78 @@
+// Copyright (C) 2023 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.googlesource.gerrit.plugins.github.filters;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.flogger.FluentLogger;
+import com.google.gerrit.httpd.AllRequestFilter;
+import com.googlesource.gerrit.plugins.github.group.GitHubGroupsCache;
+import java.io.IOException;
+import java.util.Optional;
+import javax.inject.Inject;
+import javax.servlet.FilterChain;
+import javax.servlet.FilterConfig;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+public class GitHubGroupCacheRefreshFilter extends AllRequestFilter {
+  private static final FluentLogger logger = FluentLogger.forEnclosingClass();
+  private static final String LOGIN_URL = "/login";
+  private static final String LOGIN_QUERY_FINAL = "final=true";
+  private static final String ACCOUNT_COOKIE = "GerritAccount";
+  private static final String INVALIDATE_CACHED_GROUPS = "RefreshGroups";
+
+  private final GitHubGroupsCache ghGroupsCache;
+
+  @Inject
+  @VisibleForTesting
+  public GitHubGroupCacheRefreshFilter(GitHubGroupsCache ghGroupsCache) {
+    this.ghGroupsCache = ghGroupsCache;
+  }
+
+  @Override
+  public void init(FilterConfig filterConfig) throws ServletException {}
+
+  @Override
+  public void doFilter(
+      ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
+      throws IOException, ServletException {
+    filterChain.doFilter(servletRequest, servletResponse);
+
+    HttpServletRequest req = (HttpServletRequest) servletRequest;
+    if (req.getRequestURI().endsWith(LOGIN_URL) && req.getQueryString().equals(LOGIN_QUERY_FINAL)) {
+      HttpServletResponse resp = (HttpServletResponse) servletResponse;
+      String cookieResponse = resp.getHeader("Set-Cookie");
+      if (cookieResponse != null && cookieResponse.contains(ACCOUNT_COOKIE)) {
+        req.getSession().setAttribute(INVALIDATE_CACHED_GROUPS, Boolean.TRUE);
+      }
+    } else if (hasSessionFlagForInvalidatingCachedUserGroups(req)) {
+      ghGroupsCache.invalidateCurrentUserGroups();
+      req.getSession().removeAttribute(INVALIDATE_CACHED_GROUPS);
+    }
+  }
+
+  private static boolean hasSessionFlagForInvalidatingCachedUserGroups(HttpServletRequest req) {
+    return Optional.ofNullable(req.getSession(false))
+        .flatMap(session -> Optional.ofNullable(session.getAttribute(INVALIDATE_CACHED_GROUPS)))
+        .filter(refresh -> (Boolean) refresh)
+        .isPresent();
+  }
+
+  @Override
+  public void destroy() {}
+}
diff --git a/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/group/CurrentUsernameProvider.java b/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/group/CurrentUsernameProvider.java
new file mode 100644
index 0000000..028aa6b
--- /dev/null
+++ b/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/group/CurrentUsernameProvider.java
@@ -0,0 +1,41 @@
+// Copyright (C) 2023 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.googlesource.gerrit.plugins.github.group;
+
+import com.google.gerrit.server.CurrentUser;
+import com.google.gerrit.server.IdentifiedUser;
+import com.google.inject.Inject;
+import com.google.inject.Provider;
+import java.util.Optional;
+
+public class CurrentUsernameProvider implements Provider<String> {
+  public static final String CURRENT_USERNAME = "CurrentUsername";
+
+  private final Provider<CurrentUser> userProvider;
+
+  @Inject
+  CurrentUsernameProvider(Provider<CurrentUser> userProvider) {
+    this.userProvider = userProvider;
+  }
+
+  @Override
+  public String get() {
+    return Optional.ofNullable(userProvider.get())
+        .filter(CurrentUser::isIdentifiedUser)
+        .map(CurrentUser::asIdentifiedUser)
+        .flatMap(IdentifiedUser::getUserName)
+        .orElse(null);
+  }
+}
diff --git a/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/group/GitHubGroupsCache.java b/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/group/GitHubGroupsCache.java
index 3504f2c..a5301e9 100644
--- a/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/group/GitHubGroupsCache.java
+++ b/github-plugin/src/main/java/com/googlesource/gerrit/plugins/github/group/GitHubGroupsCache.java
@@ -14,19 +14,21 @@
 
 package com.googlesource.gerrit.plugins.github.group;
 
+import static com.googlesource.gerrit.plugins.github.group.CurrentUsernameProvider.CURRENT_USERNAME;
 import static java.time.temporal.ChronoUnit.MINUTES;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
 import com.google.common.collect.ImmutableSet;
 import com.google.gerrit.entities.AccountGroup.UUID;
-import com.google.gerrit.server.IdentifiedUser;
 import com.google.gerrit.server.cache.CacheModule;
 import com.google.inject.Inject;
 import com.google.inject.Module;
 import com.google.inject.Provider;
 import com.google.inject.Singleton;
 import com.google.inject.name.Named;
+import com.google.inject.name.Names;
 import com.googlesource.gerrit.plugins.github.groups.OrganizationStructure;
 import com.googlesource.gerrit.plugins.github.oauth.GitHubLogin;
 import com.googlesource.gerrit.plugins.github.oauth.UserScopedProvider;
@@ -106,6 +108,9 @@
     return new CacheModule() {
       @Override
       protected void configure() {
+        bind(String.class)
+            .annotatedWith(Names.named(CurrentUsernameProvider.CURRENT_USERNAME))
+            .toProvider(CurrentUsernameProvider.class);
         persist(ORGS_CACHE_NAME, String.class, OrganizationStructure.class)
             .expireAfterWrite(Duration.of(GROUPS_CACHE_TTL_MINS, MINUTES))
             .loader(OrganisationLoader.class);
@@ -115,14 +120,15 @@
   }
 
   private final LoadingCache<String, OrganizationStructure> orgTeamsByUsername;
-  private final Provider<IdentifiedUser> userProvider;
+  private final Provider<String> usernameProvider;
 
   @Inject
-  GitHubGroupsCache(
+  @VisibleForTesting
+  public GitHubGroupsCache(
       @Named(ORGS_CACHE_NAME) LoadingCache<String, OrganizationStructure> byUsername,
-      Provider<IdentifiedUser> userProvider) {
+      @Named(CURRENT_USERNAME) Provider<String> usernameProvider) {
     this.orgTeamsByUsername = byUsername;
-    this.userProvider = userProvider;
+    this.usernameProvider = usernameProvider;
   }
 
   Set<String> getOrganizationsForUser(String username) {
@@ -135,7 +141,7 @@
   }
 
   Set<String> getOrganizationsForCurrentUser() throws ExecutionException {
-    return orgTeamsByUsername.get(userProvider.get().getUserName().get()).keySet();
+    return orgTeamsByUsername.get(usernameProvider.get()).keySet();
   }
 
   Set<String> getTeamsForUser(String organizationName, String username) {
@@ -156,7 +162,7 @@
   }
 
   Set<String> getTeamsForCurrentUser(String organizationName) {
-    return getTeamsForUser(organizationName, userProvider.get().getUserName().get());
+    return getTeamsForUser(organizationName, usernameProvider.get());
   }
 
   public Set<UUID> getGroupsForUser(String username) {
@@ -170,4 +176,8 @@
     }
     return groupsBuilder.build();
   }
+
+  public void invalidateCurrentUserGroups() {
+    orgTeamsByUsername.invalidate(usernameProvider.get());
+  }
 }
diff --git a/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/FakeHttpServletRequest.java b/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/FakeHttpServletRequest.java
new file mode 100644
index 0000000..cdb229d
--- /dev/null
+++ b/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/FakeHttpServletRequest.java
@@ -0,0 +1,465 @@
+// Copyright (C) 2023 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.googlesource.gerrit.plugins.github;
+
+import static com.google.common.base.Strings.nullToEmpty;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.Objects.requireNonNull;
+import static java.util.stream.Collectors.toList;
+
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.LinkedListMultimap;
+import com.google.common.collect.ListMultimap;
+import com.google.common.collect.Maps;
+import com.google.gerrit.common.Nullable;
+import java.io.BufferedReader;
+import java.io.UnsupportedEncodingException;
+import java.net.URLDecoder;
+import java.security.Principal;
+import java.time.format.DateTimeFormatter;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import javax.servlet.AsyncContext;
+import javax.servlet.DispatcherType;
+import javax.servlet.RequestDispatcher;
+import javax.servlet.ServletContext;
+import javax.servlet.ServletInputStream;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.Cookie;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
+import javax.servlet.http.HttpUpgradeHandler;
+import javax.servlet.http.Part;
+
+/** Simple fake implementation of {@link HttpServletRequest}. */
+public class FakeHttpServletRequest implements HttpServletRequest {
+  public static final String SERVLET_PATH = "/b";
+  public static final DateTimeFormatter rfcDateformatter =
+      DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss ZZZ");
+
+  private final Map<String, Object> attributes;
+  private final ListMultimap<String, String> headers;
+
+  private ListMultimap<String, String> parameters;
+  private String queryString;
+  private String servletPath;
+  private String path;
+  private String method;
+  private HttpSession session;
+
+  public FakeHttpServletRequest() {
+    this(SERVLET_PATH, null);
+  }
+
+  public FakeHttpServletRequest(String servletPath, HttpSession existingSession) {
+    this.servletPath = requireNonNull(servletPath, "servletPath");
+    attributes = Maps.newConcurrentMap();
+    parameters = LinkedListMultimap.create();
+    headers = LinkedListMultimap.create();
+    session = existingSession;
+  }
+
+  @Override
+  public Object getAttribute(String name) {
+    return attributes.get(name);
+  }
+
+  @Override
+  public Enumeration<String> getAttributeNames() {
+    return Collections.enumeration(attributes.keySet());
+  }
+
+  @Override
+  public String getCharacterEncoding() {
+    return UTF_8.name();
+  }
+
+  @Override
+  public int getContentLength() {
+    return 0;
+  }
+
+  @Nullable
+  @Override
+  public String getContentType() {
+    return null;
+  }
+
+  @Override
+  public ServletInputStream getInputStream() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public String getLocalAddr() {
+    return "1.2.3.4";
+  }
+
+  @Override
+  public String getLocalName() {
+    return "localhost";
+  }
+
+  @Override
+  public int getLocalPort() {
+    return 80;
+  }
+
+  @Override
+  public Locale getLocale() {
+    return Locale.US;
+  }
+
+  @Override
+  public Enumeration<Locale> getLocales() {
+    return Collections.enumeration(Collections.singleton(Locale.US));
+  }
+
+  @Override
+  public String getParameter(String name) {
+    return Iterables.getFirst(parameters.get(name), null);
+  }
+
+  @Override
+  public Map<String, String[]> getParameterMap() {
+    return Collections.unmodifiableMap(
+        Maps.transformValues(parameters.asMap(), vs -> vs.toArray(new String[0])));
+  }
+
+  @Override
+  public Enumeration<String> getParameterNames() {
+    return Collections.enumeration(parameters.keySet());
+  }
+
+  @Override
+  public String[] getParameterValues(String name) {
+    return parameters.get(name).toArray(new String[0]);
+  }
+
+  public void setQueryString(String qs) {
+    this.queryString = qs;
+    ListMultimap<String, String> params = LinkedListMultimap.create();
+    for (String entry : Splitter.on('&').split(qs)) {
+      List<String> kv = Splitter.on('=').limit(2).splitToList(entry);
+      try {
+        params.put(
+            URLDecoder.decode(kv.get(0), UTF_8.name()),
+            kv.size() == 2 ? URLDecoder.decode(kv.get(1), UTF_8.name()) : "");
+      } catch (UnsupportedEncodingException e) {
+        throw new IllegalArgumentException(e);
+      }
+    }
+    parameters = params;
+  }
+
+  @Override
+  public String getProtocol() {
+    return "HTTP/1.1";
+  }
+
+  @Override
+  public BufferedReader getReader() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  @Deprecated
+  public String getRealPath(String path) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public String getRemoteAddr() {
+    return "5.6.7.8";
+  }
+
+  @Override
+  public String getRemoteHost() {
+    return "remotehost";
+  }
+
+  @Override
+  public int getRemotePort() {
+    return 1234;
+  }
+
+  @Override
+  public RequestDispatcher getRequestDispatcher(String path) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public String getScheme() {
+    return "http";
+  }
+
+  @Override
+  public String getServerName() {
+    return "localhost";
+  }
+
+  @Override
+  public int getServerPort() {
+    return 80;
+  }
+
+  @Override
+  public boolean isSecure() {
+    return false;
+  }
+
+  @Override
+  public void removeAttribute(String name) {
+    attributes.remove(name);
+  }
+
+  @Override
+  public void setAttribute(String name, Object value) {
+    attributes.put(name, value);
+  }
+
+  @Override
+  public void setCharacterEncoding(String env) throws UnsupportedOperationException {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public String getAuthType() {
+    return null;
+  }
+
+  @Override
+  public String getContextPath() {
+    return "";
+  }
+
+  @Override
+  public Cookie[] getCookies() {
+    return Splitter.on(";").splitToList(nullToEmpty(getHeader("Cookie"))).stream()
+        .filter(s -> !s.isEmpty())
+        .map(
+            (String cookieValue) -> {
+              List<String> kv = Splitter.on("=").splitToList(cookieValue);
+              return new Cookie(kv.get(0), kv.get(1));
+            })
+        .collect(toList())
+        .toArray(new Cookie[0]);
+  }
+
+  @Override
+  public long getDateHeader(String name) {
+    return 0L;
+  }
+
+  @Override
+  public String getHeader(String name) {
+    return Iterables.getFirst(headers.get(name), null);
+  }
+
+  @Override
+  public Enumeration<String> getHeaderNames() {
+    return Collections.enumeration(headers.keySet());
+  }
+
+  @Override
+  public Enumeration<String> getHeaders(String name) {
+    return Collections.enumeration(headers.get(name));
+  }
+
+  @Override
+  public int getIntHeader(String name) {
+    return Integer.parseInt(getHeader(name));
+  }
+
+  @Override
+  public String getMethod() {
+    return method;
+  }
+
+  public void setMethod(String method) {
+    this.method = method;
+  }
+
+  @Override
+  public String getPathInfo() {
+    return path;
+  }
+
+  public FakeHttpServletRequest setPathInfo(String path) {
+    this.path = path;
+    return this;
+  }
+
+  @Override
+  public String getPathTranslated() {
+    return path;
+  }
+
+  @Override
+  public String getQueryString() {
+    return queryString;
+  }
+
+  @Override
+  public String getRemoteUser() {
+    return null;
+  }
+
+  @Override
+  public String getRequestURI() {
+    return nullToEmpty(servletPath) + nullToEmpty(path);
+  }
+
+  @Override
+  public StringBuffer getRequestURL() {
+    return new StringBuffer("http://localhost" + getRequestURI());
+  }
+
+  @Override
+  public String getRequestedSessionId() {
+    return null;
+  }
+
+  @Override
+  public String getServletPath() {
+    return servletPath;
+  }
+
+  @Override
+  public HttpSession getSession() {
+    return getSession(true);
+  }
+
+  @Override
+  public HttpSession getSession(boolean create) {
+    if (session == null && create) {
+      session = new FakeHttpSession();
+    }
+    return session;
+  }
+
+  @Override
+  public Principal getUserPrincipal() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean isRequestedSessionIdFromCookie() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean isRequestedSessionIdFromURL() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  @Deprecated
+  public boolean isRequestedSessionIdFromUrl() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean isRequestedSessionIdValid() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean isUserInRole(String role) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public AsyncContext getAsyncContext() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public DispatcherType getDispatcherType() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public ServletContext getServletContext() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean isAsyncStarted() {
+    return false;
+  }
+
+  @Override
+  public boolean isAsyncSupported() {
+    return false;
+  }
+
+  @Override
+  public AsyncContext startAsync() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public AsyncContext startAsync(ServletRequest req, ServletResponse res) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean authenticate(HttpServletResponse res) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Part getPart(String name) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Collection<Part> getParts() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void login(String username, String password) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void logout() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public long getContentLengthLong() {
+    return getContentLength();
+  }
+
+  @Override
+  public String changeSessionId() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public <T extends HttpUpgradeHandler> T upgrade(Class<T> httpUpgradeHandlerClass) {
+    throw new UnsupportedOperationException();
+  }
+}
diff --git a/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/FakeHttpServletResponse.java b/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/FakeHttpServletResponse.java
new file mode 100644
index 0000000..6824a15
--- /dev/null
+++ b/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/FakeHttpServletResponse.java
@@ -0,0 +1,249 @@
+// Copyright (C) 2015 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.googlesource.gerrit.plugins.github;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.Objects.requireNonNull;
+
+import com.google.common.collect.Iterables;
+import com.google.common.collect.LinkedListMultimap;
+import com.google.common.collect.ListMultimap;
+import com.google.common.net.HttpHeaders;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.nio.charset.Charset;
+import java.util.Collection;
+import java.util.Locale;
+import javax.servlet.ServletOutputStream;
+import javax.servlet.http.Cookie;
+import javax.servlet.http.HttpServletResponse;
+
+/** Simple fake implementation of {@link HttpServletResponse}. */
+public class FakeHttpServletResponse implements HttpServletResponse {
+  private final ByteArrayOutputStream actualBody = new ByteArrayOutputStream();
+  private final ListMultimap<String, String> headers = LinkedListMultimap.create();
+
+  private int status = SC_OK;
+  private boolean committed;
+  private ServletOutputStream outputStream;
+  private PrintWriter writer;
+
+  public FakeHttpServletResponse() {}
+
+  @Override
+  public synchronized void flushBuffer() throws IOException {
+    if (outputStream != null) {
+      outputStream.flush();
+    }
+    if (writer != null) {
+      writer.flush();
+    }
+  }
+
+  @Override
+  public int getBufferSize() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public String getCharacterEncoding() {
+    return UTF_8.name();
+  }
+
+  @Override
+  public String getContentType() {
+    return null;
+  }
+
+  @Override
+  public Locale getLocale() {
+    return Locale.US;
+  }
+
+  @Override
+  public synchronized ServletOutputStream getOutputStream() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public synchronized PrintWriter getWriter() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public synchronized boolean isCommitted() {
+    return committed;
+  }
+
+  @Override
+  public void reset() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void resetBuffer() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void setBufferSize(int sz) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void setCharacterEncoding(String name) {
+    checkArgument(UTF_8.equals(Charset.forName(name)), "unsupported charset: %s", name);
+  }
+
+  @Override
+  public void setContentLength(int length) {
+    setContentLengthLong(length);
+  }
+
+  @Override
+  public void setContentLengthLong(long length) {
+    headers.removeAll(HttpHeaders.CONTENT_LENGTH);
+    addHeader(HttpHeaders.CONTENT_LENGTH, Long.toString(length));
+  }
+
+  @Override
+  public void setContentType(String type) {
+    headers.removeAll(HttpHeaders.CONTENT_TYPE);
+    addHeader(HttpHeaders.CONTENT_TYPE, type);
+  }
+
+  @Override
+  public void setLocale(Locale locale) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void addCookie(Cookie cookie) {
+    addHeader("Set-Cookie", cookie.getName() + "=" + cookie.getValue());
+  }
+
+  @Override
+  public void addDateHeader(String name, long value) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void addHeader(String name, String value) {
+    headers.put(name.toLowerCase(Locale.US), value);
+  }
+
+  @Override
+  public void addIntHeader(String name, int value) {
+    addHeader(name, Integer.toString(value));
+  }
+
+  @Override
+  public boolean containsHeader(String name) {
+    return headers.containsKey(name.toLowerCase(Locale.US));
+  }
+
+  @Override
+  public String encodeRedirectURL(String url) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  @Deprecated
+  public String encodeRedirectUrl(String url) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public String encodeURL(String url) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  @Deprecated
+  public String encodeUrl(String url) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public synchronized void sendError(int sc) {
+    status = sc;
+    committed = true;
+  }
+
+  @Override
+  public synchronized void sendError(int sc, String msg) {
+    status = sc;
+    committed = true;
+  }
+
+  @Override
+  public synchronized void sendRedirect(String loc) {
+    status = SC_FOUND;
+    setHeader(HttpHeaders.LOCATION, loc);
+    committed = true;
+  }
+
+  @Override
+  public void setDateHeader(String name, long value) {
+    setHeader(name, Long.toString(value));
+  }
+
+  @Override
+  public void setHeader(String name, String value) {
+    headers.removeAll(name.toLowerCase(Locale.US));
+    addHeader(name, value);
+  }
+
+  @Override
+  public void setIntHeader(String name, int value) {
+    headers.removeAll(name.toLowerCase(Locale.US));
+    addIntHeader(name, value);
+  }
+
+  @Override
+  public synchronized void setStatus(int sc) {
+    status = sc;
+    committed = true;
+  }
+
+  @Override
+  @Deprecated
+  public synchronized void setStatus(int sc, String msg) {
+    status = sc;
+    committed = true;
+  }
+
+  @Override
+  public synchronized int getStatus() {
+    return status;
+  }
+
+  @Override
+  public String getHeader(String name) {
+    return Iterables.getFirst(headers.get(requireNonNull(name.toLowerCase(Locale.US))), null);
+  }
+
+  @Override
+  public Collection<String> getHeaderNames() {
+    return headers.keySet();
+  }
+
+  @Override
+  public Collection<String> getHeaders(String name) {
+    return headers.get(requireNonNull(name.toLowerCase(Locale.US)));
+  }
+}
diff --git a/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/FakeHttpSession.java b/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/FakeHttpSession.java
new file mode 100644
index 0000000..6a820e7
--- /dev/null
+++ b/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/FakeHttpSession.java
@@ -0,0 +1,112 @@
+// Copyright (C) 2023 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.googlesource.gerrit.plugins.github;
+
+import java.util.Enumeration;
+import java.util.HashMap;
+import javax.servlet.ServletContext;
+import javax.servlet.http.HttpSession;
+import javax.servlet.http.HttpSessionContext;
+
+public class FakeHttpSession implements HttpSession {
+  private final HashMap<String, Object> attributes;
+
+  public FakeHttpSession() {
+    this.attributes = new HashMap<>();
+  }
+
+  @Override
+  public long getCreationTime() {
+    return 0;
+  }
+
+  @Override
+  public String getId() {
+    return null;
+  }
+
+  @Override
+  public long getLastAccessedTime() {
+    return 0;
+  }
+
+  @Override
+  public ServletContext getServletContext() {
+    return null;
+  }
+
+  @Override
+  public void setMaxInactiveInterval(int i) {}
+
+  @Override
+  public int getMaxInactiveInterval() {
+    return 0;
+  }
+
+  @Override
+  public HttpSessionContext getSessionContext() {
+    return null;
+  }
+
+  @Override
+  public Object getAttribute(String s) {
+    return attributes.get(s);
+  }
+
+  @Override
+  public Object getValue(String s) {
+    return getAttribute(s);
+  }
+
+  @Override
+  public Enumeration<String> getAttributeNames() {
+    return java.util.Collections.enumeration(attributes.keySet());
+  }
+
+  @Override
+  public String[] getValueNames() {
+    return attributes.keySet().toArray(new String[0]);
+  }
+
+  @Override
+  public void setAttribute(String s, Object o) {
+    attributes.put(s, o);
+  }
+
+  @Override
+  public void putValue(String s, Object o) {
+    setAttribute(s, o);
+  }
+
+  @Override
+  public void removeAttribute(String s) {
+    attributes.remove(s);
+  }
+
+  @Override
+  public void removeValue(String s) {
+    removeAttribute(s);
+  }
+
+  @Override
+  public void invalidate() {
+    attributes.clear();
+  }
+
+  @Override
+  public boolean isNew() {
+    return false;
+  }
+}
diff --git a/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/GitHubGroupCacheRefreshFilterTest.java b/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/GitHubGroupCacheRefreshFilterTest.java
new file mode 100644
index 0000000..59a8e3a
--- /dev/null
+++ b/github-plugin/src/test/java/com/googlesource/gerrit/plugins/github/GitHubGroupCacheRefreshFilterTest.java
@@ -0,0 +1,120 @@
+// Copyright (C) 2023 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.googlesource.gerrit.plugins.github;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.googlesource.gerrit.plugins.github.filters.GitHubGroupCacheRefreshFilter;
+import com.googlesource.gerrit.plugins.github.group.GitHubGroupsCache;
+import com.googlesource.gerrit.plugins.github.groups.OrganizationStructure;
+import javax.servlet.FilterChain;
+import javax.servlet.ServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
+import org.junit.Before;
+import org.junit.Test;
+
+public class GitHubGroupCacheRefreshFilterTest {
+  private static final FilterChain NOOP_FILTER_CHAIN_TEST = (req, res) -> {};
+  private static final String GITHUB_USERNAME_TEST = "somegithubuser";
+  private static final OrganizationStructure GITHUB_USER_ORGANIZATION = new OrganizationStructure();
+
+  private LoadingCache<String, OrganizationStructure> groupsByUsernameCache;
+  private GitHubGroupCacheRefreshFilter filter;
+  private FakeGroupCacheLoader groupsCacheLoader;
+  private int initialLoadCount;
+
+  private static class FakeGroupCacheLoader extends CacheLoader<String, OrganizationStructure> {
+    private final String username;
+    private final OrganizationStructure organizationStructure;
+    private int loadCount;
+
+    FakeGroupCacheLoader(String username, OrganizationStructure organizationStructure) {
+      this.username = username;
+      this.organizationStructure = organizationStructure;
+    }
+
+    @Override
+    public OrganizationStructure load(String u) throws Exception {
+      if (u.equals(username)) {
+        loadCount++;
+        return organizationStructure;
+      } else {
+        return null;
+      }
+    }
+
+    public int getLoadCount() {
+      return loadCount;
+    }
+  }
+
+  @Before
+  public void setUp() throws Exception {
+    groupsCacheLoader = new FakeGroupCacheLoader(GITHUB_USERNAME_TEST, GITHUB_USER_ORGANIZATION);
+    groupsByUsernameCache = CacheBuilder.newBuilder().build(groupsCacheLoader);
+    filter =
+        new GitHubGroupCacheRefreshFilter(
+            new GitHubGroupsCache(groupsByUsernameCache, () -> GITHUB_USERNAME_TEST));
+    // Trigger the initial load of the groups cache
+    assertThat(groupsByUsernameCache.get(GITHUB_USERNAME_TEST)).isEqualTo(GITHUB_USER_ORGANIZATION);
+    initialLoadCount = groupsCacheLoader.getLoadCount();
+  }
+
+  @Test
+  public void shouldReloadGroupsUponSuccessfulLogin() throws Exception {
+    FakeHttpServletRequest finalLoginRequest = newFinalLoginRequest();
+    filter.doFilter(finalLoginRequest, newFinalLoginRedirectWithCookie(), NOOP_FILTER_CHAIN_TEST);
+    filter.doFilter(
+        newHomepageRequest(finalLoginRequest.getSession()),
+        new FakeHttpServletResponse(),
+        NOOP_FILTER_CHAIN_TEST);
+
+    assertThat(groupsByUsernameCache.get(GITHUB_USERNAME_TEST)).isEqualTo(GITHUB_USER_ORGANIZATION);
+    assertThat(groupsCacheLoader.getLoadCount()).isEqualTo(initialLoadCount + 1);
+  }
+
+  @Test
+  public void shouldNotReloadGroupsOnRegularRequests() throws Exception {
+    FakeHttpServletRequest regularRequest = new FakeHttpServletRequest();
+    filter.doFilter(regularRequest, new FakeHttpServletResponse(), NOOP_FILTER_CHAIN_TEST);
+    filter.doFilter(
+        newHomepageRequest(regularRequest.getSession()),
+        new FakeHttpServletResponse(),
+        NOOP_FILTER_CHAIN_TEST);
+
+    assertThat(groupsByUsernameCache.get(GITHUB_USERNAME_TEST)).isEqualTo(GITHUB_USER_ORGANIZATION);
+    assertThat(groupsCacheLoader.getLoadCount()).isEqualTo(initialLoadCount);
+  }
+
+  private ServletRequest newHomepageRequest(HttpSession session) {
+    return new FakeHttpServletRequest("/", session);
+  }
+
+  private static HttpServletResponse newFinalLoginRedirectWithCookie() {
+    HttpServletResponse res = new FakeHttpServletResponse();
+    res.setHeader("Set-Cookie", "GerritAccount=foo");
+    return res;
+  }
+
+  private static FakeHttpServletRequest newFinalLoginRequest() {
+    FakeHttpServletRequest req = new FakeHttpServletRequest("/login", null);
+    req.setQueryString("final=true");
+    return req;
+  }
+}