Merge "Add option to enforce auth with IdP on session expiration"
diff --git a/src/main/java/com/googlesource/gerrit/plugins/saml/HttpServletBufferedStatusResponse.java b/src/main/java/com/googlesource/gerrit/plugins/saml/HttpServletBufferedStatusResponse.java
new file mode 100644
index 0000000..6c593df
--- /dev/null
+++ b/src/main/java/com/googlesource/gerrit/plugins/saml/HttpServletBufferedStatusResponse.java
@@ -0,0 +1,107 @@
+// Copyright (C) 2023 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.googlesource.gerrit.plugins.saml;
+
+import java.io.IOException;
+import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpServletResponseWrapper;
+
+/**
+ * Response wrapper that holds back the status code and error set by downstream code.
+ *
+ * <p>Calls to {@code setStatus} and {@code sendError} are recorded instead of being forwarded to
+ * the wrapped response, and are replayed on it only when {@link #commit()} is called. This lets
+ * the caller discard the recorded outcome and send a different one (for example an error) when
+ * work done after the filter chain fails.
+ *
+ * <p>Headers, redirects and body content written by downstream code are not buffered and still
+ * go straight to the wrapped response.
+ */
+class HttpServletBufferedStatusResponse extends HttpServletResponseWrapper {
+  // Values recorded by setStatus(); 0 means no status was set.
+  private int status;
+  private String statusMsg;
+  // Values recorded by sendError(); 0 means no error was sent.
+  private int error;
+  private String errorMsg;
+
+  HttpServletBufferedStatusResponse(HttpServletResponse response) {
+    super(response);
+  }
+
+  @Override
+  public void sendError(int sc) throws IOException {
+    sendError(sc, null);
+  }
+
+  @Override
+  public void sendError(int sc, String msg) throws IOException {
+    error = sc;
+    errorMsg = msg;
+  }
+
+  @Override
+  public void setStatus(int sc) {
+    setStatus(sc, null);
+  }
+
+  @SuppressWarnings("deprecation")
+  @Override
+  public void setStatus(int sc, String msg) {
+    status = sc;
+    statusMsg = msg;
+  }
+
+  /**
+   * Returns the recorded error or status code, so that code running before {@link #commit()}
+   * sees the outcome recorded so far rather than the wrapped response's status. Falls back to
+   * the wrapped response when nothing has been recorded.
+   */
+  @Override
+  public int getStatus() {
+    if (error > 0) {
+      return error;
+    }
+    if (status > 0) {
+      return status;
+    }
+    return super.getStatus();
+  }
+
+  /**
+   * Replays the recorded error or status on the wrapped response.
+   *
+   * <p>A recorded error takes precedence over a recorded status. Nothing is forwarded when
+   * neither was recorded.
+   *
+   * @throws IOException if forwarding the error to the wrapped response fails
+   */
+  @SuppressWarnings("deprecation")
+  void commit() throws IOException {
+    if (error > 0) {
+      if (errorMsg != null) {
+        super.sendError(error, errorMsg);
+      } else {
+        super.sendError(error);
+      }
+    } else if (status > 0) {
+      if (statusMsg != null) {
+        super.setStatus(status, statusMsg);
+      } else {
+        super.setStatus(status);
+      }
+    }
+  }
+}
diff --git a/src/main/java/com/googlesource/gerrit/plugins/saml/SamlWebFilter.java b/src/main/java/com/googlesource/gerrit/plugins/saml/SamlWebFilter.java
index aca8a40..8eaf717 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/saml/SamlWebFilter.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/saml/SamlWebFilter.java
@@ -19,12 +19,15 @@
 import com.google.gerrit.entities.Account;
 import com.google.gerrit.extensions.api.GerritApi;
 import com.google.gerrit.extensions.api.accounts.Accounts;
+import com.google.gerrit.extensions.client.AccountFieldName;
 import com.google.gerrit.extensions.restapi.RestApiException;
 import com.google.gerrit.extensions.restapi.Url;
+import com.google.gerrit.server.account.Realm;
 import com.google.gerrit.server.config.AuthConfig;
 import com.google.gerrit.server.util.ManualRequestContext;
 import com.google.gerrit.server.util.OneOffRequestContext;
 import com.google.inject.Inject;
+import com.google.inject.ProvisionException;
 import com.google.inject.Singleton;
 import java.io.IOException;
 import java.util.Arrays;
@@ -74,10 +77,12 @@
   private final GerritApi gApi;
   private final Accounts accounts;
   private final OneOffRequestContext oneOffRequestContext;
+  private final boolean realmAllowsFullNameEditing;
 
   @Inject
   SamlWebFilter(
       AuthConfig auth,
+      Realm realm,
       SamlConfig samlConfig,
       SamlMembership samlMembership,
       @AuthHeaders Set<String> authHeaders,
@@ -86,7 +91,13 @@
       SAML2Client saml2Client,
       OneOffRequestContext oneOffRequestContext) {
     this.auth = auth;
+    if (auth.getHttpDisplaynameHeader() != null) {
+      throw new ProvisionException(
+          "auth.httpdisplaynameheader is not compatible with SAML: remove the config and restart");
+    }
+
     this.samlConfig = samlConfig;
+    this.realmAllowsFullNameEditing = realm.allowsEdit(AccountFieldName.FULL_NAME);
     this.samlMembership = samlMembership;
     log.debug("Max Authentication Lifetime: " + samlConfig.getMaxAuthLifetimeAttr());
     this.saml2Client = saml2Client;
@@ -129,13 +140,22 @@
           redirectToIdentityProvider(context);
         } else {
           HttpServletRequest req = new AuthenticatedHttpRequest(httpRequest, user);
-          chain.doFilter(req, response);
-          try (ManualRequestContext ignored =
-              oneOffRequestContext.openAs(
-                  Account.id(accounts.id(user.getUsername()).get()._accountId))) {
-            gApi.accounts().id(user.getUsername()).setName(user.getDisplayName());
-          } catch (RestApiException e) {
-            log.error("Saml plugin could not set account name", e);
+
+          if (realmAllowsFullNameEditing) {
+            HttpServletBufferedStatusResponse respWrapper =
+                new HttpServletBufferedStatusResponse(httpResponse);
+            chain.doFilter(req, respWrapper);
+            try (ManualRequestContext ignored =
+                oneOffRequestContext.openAs(
+                    Account.id(accounts.id(user.getUsername()).get()._accountId))) {
+              gApi.accounts().id(user.getUsername()).setName(user.getDisplayName());
+              respWrapper.commit();
+            } catch (RestApiException e) {
+              log.error("Saml plugin could not set account name", e);
+              httpResponse.sendError(HttpServletResponse.SC_FORBIDDEN);
+            }
+          } else {
+            chain.doFilter(req, httpResponse);
           }
         }
       } else if (isGerritLogout(httpRequest)) {
diff --git a/src/test/java/com/googlesource/gerrit/plugins/saml/SamlWebFilterIT.java b/src/test/java/com/googlesource/gerrit/plugins/saml/SamlWebFilterIT.java
index 9442c3a..be9a7f7 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/saml/SamlWebFilterIT.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/saml/SamlWebFilterIT.java
@@ -16,20 +16,37 @@
 
 import static com.google.common.base.Strings.nullToEmpty;
 import static com.google.common.truth.Truth.assertThat;
+import static javax.servlet.http.HttpServletResponse.SC_FORBIDDEN;
 import static javax.servlet.http.HttpServletResponse.SC_OK;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 
 import com.google.gerrit.acceptance.AbstractDaemonTest;
+import com.google.gerrit.extensions.api.GerritApi;
+import com.google.gerrit.extensions.api.accounts.Accounts;
+import com.google.gerrit.extensions.client.AccountFieldName;
 import com.google.gerrit.extensions.common.AccountDetailInfo;
 import com.google.gerrit.extensions.restapi.RestApiException;
+import com.google.gerrit.server.ServerInitiated;
+import com.google.gerrit.server.account.AccountsUpdate;
+import com.google.gerrit.server.account.Realm;
+import com.google.gerrit.server.config.AuthConfig;
+import com.google.gerrit.server.util.OneOffRequestContext;
 import com.google.gerrit.testing.ConfigSuite;
 import com.google.gerrit.util.http.testutil.FakeHttpServletRequest;
 import com.google.gerrit.util.http.testutil.FakeHttpServletResponse;
+import com.google.inject.Inject;
+import com.google.inject.Injector;
 import com.google.inject.Module;
+import com.google.inject.Provider;
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
 import javax.servlet.FilterChain;
-import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpSession;
 import org.eclipse.jgit.errors.ConfigInvalidException;
@@ -38,6 +55,8 @@
 
 public class SamlWebFilterIT extends AbstractDaemonTest {
 
+  @Inject @ServerInitiated private Provider<AccountsUpdate> accountsUpdateProvider;
+
   @ConfigSuite.Default
   public static Config setupSaml() throws ConfigInvalidException {
     Config cfg = new Config();
@@ -57,8 +76,7 @@
   }
 
   @Test
-  public void supportAccountNamesWithNonIso88591Characters()
-      throws IOException, ServletException, RestApiException {
+  public void supportAccountNamesWithNonIso88591Characters() throws Exception {
     SamlWebFilter samlWebFilter = server.getTestInjector().getInstance(SamlWebFilter.class);
 
     String samlDisplayName = nullToEmpty(user.displayName()) + " Saml Test 合覺那加情力心";
@@ -72,8 +90,7 @@
     req.setPathInfo(SamlWebFilter.GERRIT_LOGIN);
     HttpServletResponse res = new FakeHttpServletResponse();
 
-    samlWebFilter.doFilter(req, res, mock(FilterChain.class));
-
+    samlWebFilter.doFilter(req, res, mockFilterReturningStatusOK());
     assertThat(res.getStatus()).isEqualTo(SC_OK);
 
     AccountDetailInfo account = gApi.accounts().id(user.username()).detail();
@@ -85,6 +102,122 @@
     return new com.googlesource.gerrit.plugins.saml.Module();
   }
 
+  @Test
+  public void failAuthenticationWhenAccountManipulationFails() throws Exception {
+    SamlWebFilter samlWebFilter =
+        newSamlWebFilter(
+            server.getTestInjector().getInstance(Realm.class),
+            newGerritApiMockFailingOnAccountsApi());
+    // Every status code that reaches the underlying response. The SC_OK set by the filter chain
+    // must be held back, so that only the SC_FORBIDDEN sent by SamlWebFilter gets through.
+    List<Integer> responseStatuses = new ArrayList<>();
+
+    HttpServletResponse res =
+        new FakeHttpServletResponse() {
+          @Override
+          public synchronized void setStatus(int sc) {
+            super.setStatus(sc);
+            responseStatuses.add(sc);
+          }
+
+          @SuppressWarnings("deprecation")
+          @Override
+          public synchronized void setStatus(int sc, String msg) {
+            super.setStatus(sc, msg);
+            responseStatuses.add(sc);
+          }
+
+          @Override
+          public synchronized void sendError(int sc) {
+            super.sendError(sc);
+            responseStatuses.add(sc);
+          }
+
+          @Override
+          public synchronized void sendError(int sc, String msg) {
+            super.sendError(sc, msg);
+            responseStatuses.add(sc);
+          }
+        };
+
+    samlWebFilter.doFilter(newFinalLoginFakeHttpRequest(), res, mockFilterReturningStatusOK());
+
+    assertThat(res.getStatus()).isEqualTo(SC_FORBIDDEN);
+    assertThat(responseStatuses).containsExactly(SC_FORBIDDEN);
+  }
+
+  /** Returns a filter chain that only sets {@code SC_OK} on the response it is given. */
+  private FilterChain mockFilterReturningStatusOK() {
+    return (request, response) ->
+        ((HttpServletResponse) response).setStatus(HttpServletResponse.SC_OK);
+  }
+
+  /** Returns a login request for the test user with the SAML display name "User Fullname". */
+  private FakeHttpServletRequest newFinalLoginFakeHttpRequest() {
+    return newFinalLoginFakeHttpRequest("User Fullname");
+  }
+
+  /**
+   * Returns a request for the login path whose mocked session holds an {@link AuthenticatedUser}
+   * for the test {@code user} with the given SAML display name.
+   */
+  private FakeHttpServletRequest newFinalLoginFakeHttpRequest(String displayName) {
+    HttpSession httpSession = mock(HttpSession.class);
+    AuthenticatedUser authenticatedUser =
+        new AuthenticatedUser(user.username(), displayName, user.email(), "externalId");
+    doReturn(authenticatedUser).when(httpSession).getAttribute(SamlWebFilter.SESSION_ATTR_USER);
+    FakeHttpServletRequest req = new FakeHttpServletRequestWithSession(httpSession);
+    req.setPathInfo(SamlWebFilter.GERRIT_LOGIN);
+    return req;
+  }
+
+  /** Returns a {@link GerritApi} mock whose {@code accounts().id(...)} throws. */
+  private GerritApi newGerritApiMockFailingOnAccountsApi() throws RestApiException {
+    GerritApi apiMock = mock(GerritApi.class);
+    Accounts accountsApiMock = mock(Accounts.class);
+    doReturn(accountsApiMock).when(apiMock).accounts();
+    doThrow(RestApiException.class).when(accountsApiMock).id(any());
+    return apiMock;
+  }
+
+  /**
+   * Builds a {@link SamlWebFilter} from the test injector with the given realm and Gerrit API.
+   * {@code null} is passed as the SAML2 client, so only code paths that do not use it can be
+   * exercised.
+   */
+  private SamlWebFilter newSamlWebFilter(Realm realm, GerritApi gerritApi) throws IOException {
+    Injector testInjector = server.getTestInjector();
+    return new SamlWebFilter(
+        testInjector.getInstance(AuthConfig.class),
+        realm,
+        testInjector.getInstance(SamlConfig.class),
+        testInjector.getInstance(SamlMembership.class),
+        Collections.emptySet(),
+        gerritApi,
+        testInjector.getInstance(Accounts.class),
+        null,
+        testInjector.getInstance(OneOffRequestContext.class));
+  }
+
+  @Test
+  public void shouldSucceedAndNotSetFullNameWhenNotAllowedByRealm() throws Exception {
+    Realm realmMock = mock(Realm.class);
+    doReturn(false).when(realmMock).allowsEdit(eq(AccountFieldName.FULL_NAME));
+    SamlWebFilter samlWebFilter = newSamlWebFilter(realmMock, gApi);
+
+    String nameBeforeLogin = gApi.accounts().id(user.username()).detail().name;
+    String samlDisplayName = "Test display name";
+    HttpServletResponse res = new FakeHttpServletResponse();
+
+    samlWebFilter.doFilter(
+        newFinalLoginFakeHttpRequest(samlDisplayName), res, mockFilterReturningStatusOK());
+    assertThat(res.getStatus()).isEqualTo(SC_OK);
+
+    // Check the name is unchanged, not merely different from the SAML display name.
+    AccountDetailInfo account = gApi.accounts().id(user.username()).detail();
+    assertThat(account.name).isEqualTo(nameBeforeLogin);
+  }
+
   private static class FakeHttpServletRequestWithSession extends FakeHttpServletRequest {
     HttpSession session;