Rollback ref update when global-refdb update failure

Reference update contains two updates: local repository update
and global-refdb. Failure of the global-refdb update cause split
brain because local repository is ahead of the global-refdb. To
mitigate the problem, rollback local repository update when
global-refdb update fails.

Also for split brain always rollback the ref-update and return
LOCK_FAILURE regardless of the enforcement policy.

Bug: Issue 14028
Change-Id: Ia7cc54a5b32fa2997cfb4673cb768fe43107b02a
diff --git a/pom.xml b/pom.xml
index b2583eb..222e077 100644
--- a/pom.xml
+++ b/pom.xml
@@ -4,7 +4,7 @@
 
     <groupId>com.gerritforge</groupId>
     <artifactId>global-refdb</artifactId>
-    <version>3.3.2.1</version>
+    <version>3.3.3</version>
     <packaging>jar</packaging>
 
     <name>global-refdb</name>
diff --git a/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/BatchRefUpdateValidator.java b/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/BatchRefUpdateValidator.java
index 71d6621..e91d083 100644
--- a/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/BatchRefUpdateValidator.java
+++ b/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/BatchRefUpdateValidator.java
@@ -106,7 +106,9 @@
    */
   @SuppressWarnings("JavadocReference")
   public void executeBatchUpdateWithValidation(
-      BatchRefUpdate batchRefUpdate, NoParameterVoidFunction batchRefUpdateFunction)
+      BatchRefUpdate batchRefUpdate,
+      NoParameterVoidFunction batchRefUpdateFunction,
+      OneParameterVoidFunction<List<ReceiveCommand>> batchRefUpdateRollbackFunction)
       throws IOException {
     if (refEnforcement.getPolicy(projectName) == EnforcePolicy.IGNORED
         || !isGlobalProject(projectName)) {
@@ -115,7 +117,7 @@
     }
 
     try {
-      doExecuteBatchUpdate(batchRefUpdate, batchRefUpdateFunction);
+      doExecuteBatchUpdate(batchRefUpdate, batchRefUpdateFunction, batchRefUpdateRollbackFunction);
     } catch (IOException e) {
       logger.atWarning().withCause(e).log(
           "Failed to execute Batch Update on project %s", projectName);
@@ -126,7 +128,10 @@
   }
 
   private void doExecuteBatchUpdate(
-      BatchRefUpdate batchRefUpdate, NoParameterVoidFunction delegateUpdate) throws IOException {
+      BatchRefUpdate batchRefUpdate,
+      NoParameterVoidFunction delegateUpdate,
+      OneParameterVoidFunction<List<ReceiveCommand>> delegateUpdateRollback)
+      throws IOException {
 
     List<ReceiveCommand> commands = batchRefUpdate.getCommands();
     if (commands.isEmpty()) {
@@ -148,9 +153,18 @@
     }
 
     try (CloseableSet<AutoCloseable> locks = new CloseableSet<>()) {
-      refsToUpdate = compareAndGetLatestLocalRefs(refsToUpdate, locks);
+      final List<RefPair> finalRefsToUpdate = compareAndGetLatestLocalRefs(refsToUpdate, locks);
       delegateUpdate.invoke();
-      updateSharedRefDb(batchRefUpdate.getCommands().stream(), refsToUpdate);
+      try {
+        updateSharedRefDb(batchRefUpdate.getCommands().stream(), finalRefsToUpdate);
+      } catch (Exception e) {
+        List<ReceiveCommand> receiveCommands = batchRefUpdate.getCommands();
+        logger.atWarning().withCause(e).log(
+            String.format(
+                "Batch ref-update failing because of failure during the global refdb update. Set all commands Result to LOCK_FAILURE [%d]",
+                receiveCommands.size()));
+        rollback(delegateUpdateRollback, finalRefsToUpdate, receiveCommands);
+      }
     } catch (OutOfSyncException e) {
       List<ReceiveCommand> receiveCommands = batchRefUpdate.getCommands();
       logger.atWarning().withCause(e).log(
@@ -161,6 +175,24 @@
     }
   }
 
+  private void rollback(
+      OneParameterVoidFunction<List<ReceiveCommand>> delegateUpdateRollback,
+      List<RefPair> refsBeforeUpdate,
+      List<ReceiveCommand> receiveCommands)
+      throws IOException {
+    List<ReceiveCommand> rollbackCommands =
+        refsBeforeUpdate.stream()
+            .map(
+                refBeforeUpdate ->
+                    new ReceiveCommand(
+                        refBeforeUpdate.putValue,
+                        refBeforeUpdate.compareRef.getObjectId(),
+                        refBeforeUpdate.getName()))
+            .collect(Collectors.toList());
+    delegateUpdateRollback.invoke(rollbackCommands);
+    receiveCommands.forEach(command -> command.setResult(ReceiveCommand.Result.LOCK_FAILURE));
+  }
+
   private void updateSharedRefDb(Stream<ReceiveCommand> commandStream, List<RefPair> refsToUpdate)
       throws IOException {
     if (commandStream
diff --git a/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/RefUpdateValidator.java b/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/RefUpdateValidator.java
index 16dee20..6ab430d 100644
--- a/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/RefUpdateValidator.java
+++ b/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/RefUpdateValidator.java
@@ -35,6 +35,7 @@
 import org.eclipse.jgit.lib.Ref;
 import org.eclipse.jgit.lib.RefDatabase;
 import org.eclipse.jgit.lib.RefUpdate;
+import org.eclipse.jgit.lib.RefUpdate.Result;
 
 /** Enables the detection of out-of-sync by validating ref updates against the global refdb. */
 public class RefUpdateValidator {
@@ -73,6 +74,14 @@
     void invoke() throws IOException;
   }
 
+  public interface OneParameterFunction<F, T> {
+    T invoke(F f) throws IOException;
+  }
+
+  public interface OneParameterVoidFunction<T> {
+    void invoke(T f) throws IOException;
+  }
+
   /**
    * Constructs a {@code RefUpdateValidator} able to check the validity of ref-updates against a
    * global refdb before execution.
@@ -114,9 +123,10 @@
    * Checks whether the provided refUpdate should be validated first against the shared ref-db. If
    * not it just execute the provided refUpdateFunction. If it should be validated against the
    * global refdb then it does so by executing the {@link
-   * RefUpdateValidator#doExecuteRefUpdate(RefUpdate, NoParameterFunction)} first. Upon success the
-   * refUpdate is returned, upon failure split brain metrics are incremented and a {@link
-   * SharedDbSplitBrainException} is thrown.
+   * RefUpdateValidator#doExecuteRefUpdate(RefUpdate, NoParameterFunction,
+   * OneParameterFunction<ObjectId, Result>)} first. Upon success the refUpdate is returned, upon
+   * failure split brain metrics are incremented and a {@link SharedDbSplitBrainException} is
+   * thrown.
    *
    * <p>Validation is performed when either of these condition is true
    *
@@ -135,7 +145,9 @@
    * @throws IOException Execution of ref update failed
    */
   public RefUpdate.Result executeRefUpdate(
-      RefUpdate refUpdate, NoParameterFunction<RefUpdate.Result> refUpdateFunction)
+      RefUpdate refUpdate,
+      NoParameterFunction<RefUpdate.Result> refUpdateFunction,
+      OneParameterFunction<ObjectId, Result> rollbackFunction)
       throws IOException {
     if (isRefToBeIgnored(refUpdate.getName())
         || !isGlobalProject(projectName)
@@ -143,19 +155,7 @@
       return refUpdateFunction.invoke();
     }
 
-    try {
-      return doExecuteRefUpdate(refUpdate, refUpdateFunction);
-    } catch (SharedDbSplitBrainException e) {
-      validationMetrics.incrementSplitBrain();
-
-      logger.atWarning().withCause(e).log(
-          "Unable to execute ref-update on project=%s ref=%s",
-          projectName, refUpdate.getRef().getName());
-      if (refEnforcement.getPolicy(projectName) == EnforcePolicy.REQUIRED) {
-        throw e;
-      }
-    }
-    return null;
+    return doExecuteRefUpdate(refUpdate, refUpdateFunction, rollbackFunction);
   }
 
   private Boolean isRefToBeIgnored(String refName) {
@@ -182,14 +182,27 @@
   }
 
   protected RefUpdate.Result doExecuteRefUpdate(
-      RefUpdate refUpdate, NoParameterFunction<RefUpdate.Result> refUpdateFunction)
+      RefUpdate refUpdate,
+      NoParameterFunction<Result> refUpdateFunction,
+      OneParameterFunction<ObjectId, Result> rollbackFunction)
       throws IOException {
     try (CloseableSet<AutoCloseable> locks = new CloseableSet<>()) {
       RefPair refPairForUpdate = newRefPairFrom(refUpdate);
       compareAndGetLatestLocalRef(refPairForUpdate, locks);
       RefUpdate.Result result = refUpdateFunction.invoke();
-      if (isSuccessful(result)) {
-        updateSharedDbOrThrowExceptionFor(refPairForUpdate);
+      try {
+        if (isSuccessful(result)) {
+          updateSharedDbOrThrowExceptionFor(refPairForUpdate);
+        }
+      } catch (Exception e) {
+        result = rollbackFunction.invoke(refPairForUpdate.compareRef.getObjectId());
+        if (isSuccessful(result)) {
+          result = RefUpdate.Result.LOCK_FAILURE;
+        }
+        logger.atSevere().withCause(e).log(
+            String.format(
+                "Failed to update global refdb, the local refdb has been rolled back: %s",
+                e.getMessage()));
       }
       return result;
     } catch (OutOfSyncException e) {
@@ -223,7 +236,7 @@
           projectName,
           refPair.getName(),
           e.getMessage());
-      throw new SharedDbSplitBrainException(errorMessage, e);
+      throw e;
     }
 
     if (!succeeded) {
diff --git a/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbBatchRefUpdate.java b/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbBatchRefUpdate.java
index d7c0c61..5007dc6 100644
--- a/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbBatchRefUpdate.java
+++ b/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbBatchRefUpdate.java
@@ -36,6 +36,7 @@
 public class SharedRefDbBatchRefUpdate extends BatchRefUpdate {
 
   private final BatchRefUpdate batchRefUpdate;
+  private final BatchRefUpdate batchRefUpdateRollback;
   private final String project;
   private final BatchRefUpdateValidator.Factory batchRefValidatorFactory;
   private final RefDatabase refDb;
@@ -67,6 +68,7 @@
     this.refDb = refDb;
     this.project = project;
     this.batchRefUpdate = refDb.newBatchUpdate();
+    this.batchRefUpdateRollback = refDb.newBatchUpdate();
     this.batchRefValidatorFactory = batchRefValidatorFactory;
     this.ignoredRefs = ignoredRefs;
   }
@@ -88,6 +90,7 @@
 
   @Override
   public BatchRefUpdate setAllowNonFastForwards(boolean allow) {
+    batchRefUpdateRollback.setAllowNonFastForwards(allow);
     return batchRefUpdate.setAllowNonFastForwards(allow);
   }
 
@@ -98,6 +101,7 @@
 
   @Override
   public BatchRefUpdate setRefLogIdent(PersonIdent pi) {
+    batchRefUpdateRollback.setRefLogIdent(pi);
     return batchRefUpdate.setRefLogIdent(pi);
   }
 
@@ -113,6 +117,7 @@
 
   @Override
   public BatchRefUpdate setRefLogMessage(String msg, boolean appendStatus) {
+    batchRefUpdateRollback.setRefLogMessage(msg, appendStatus);
     return batchRefUpdate.setRefLogMessage(msg, appendStatus);
   }
 
@@ -123,6 +128,7 @@
 
   @Override
   public BatchRefUpdate setForceRefLog(boolean force) {
+    batchRefUpdateRollback.setForceRefLog(force);
     return batchRefUpdate.setForceRefLog(force);
   }
 
@@ -202,7 +208,10 @@
     batchRefValidatorFactory
         .create(project, refDb, ignoredRefs)
         .executeBatchUpdateWithValidation(
-            batchRefUpdate, () -> batchRefUpdate.execute(walk, monitor, options));
+            batchRefUpdate,
+            () -> batchRefUpdate.execute(walk, monitor, options),
+            (commands) ->
+                batchRefUpdateRollback.addCommand(commands).execute(walk, monitor, options));
   }
 
   /**
@@ -224,7 +233,9 @@
     batchRefValidatorFactory
         .create(project, refDb, ignoredRefs)
         .executeBatchUpdateWithValidation(
-            batchRefUpdate, () -> batchRefUpdate.execute(walk, monitor));
+            batchRefUpdate,
+            () -> batchRefUpdate.execute(walk, monitor),
+            (commands) -> batchRefUpdateRollback.addCommand(commands).execute(walk, monitor));
   }
 
   @Override
diff --git a/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbRefUpdate.java b/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbRefUpdate.java
index 9935c8d..e9809e1 100644
--- a/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbRefUpdate.java
+++ b/src/main/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbRefUpdate.java
@@ -14,6 +14,7 @@
 
 package com.gerritforge.gerrit.globalrefdb.validation;
 
+import com.gerritforge.gerrit.globalrefdb.validation.RefUpdateValidator.NoParameterFunction;
 import com.google.common.collect.ImmutableSet;
 import com.google.inject.Inject;
 import com.google.inject.assistedinject.Assisted;
@@ -122,7 +123,10 @@
    */
   @Override
   public Result update() throws IOException {
-    return refUpdateValidator.executeRefUpdate(refUpdateBase, refUpdateBase::update);
+    return refUpdateValidator.executeRefUpdate(
+        refUpdateBase,
+        refUpdateBase::update,
+        objectId -> rollback(objectId, refUpdateBase::update));
   }
 
   /**
@@ -137,7 +141,10 @@
    */
   @Override
   public Result update(RevWalk rev) throws IOException {
-    return refUpdateValidator.executeRefUpdate(refUpdateBase, () -> refUpdateBase.update(rev));
+    return refUpdateValidator.executeRefUpdate(
+        refUpdateBase,
+        () -> refUpdateBase.update(rev),
+        objectId -> rollback(objectId, () -> refUpdateBase.update(rev)));
   }
 
   /**
@@ -150,7 +157,10 @@
    */
   @Override
   public Result delete() throws IOException {
-    return refUpdateValidator.executeRefUpdate(refUpdateBase, refUpdateBase::delete);
+    return refUpdateValidator.executeRefUpdate(
+        refUpdateBase,
+        refUpdateBase::delete,
+        objectId -> rollback(objectId, refUpdateBase::update));
   }
 
   /**
@@ -163,7 +173,10 @@
    */
   @Override
   public Result delete(RevWalk walk) throws IOException {
-    return refUpdateValidator.executeRefUpdate(refUpdateBase, () -> refUpdateBase.delete(walk));
+    return refUpdateValidator.executeRefUpdate(
+        refUpdateBase,
+        () -> refUpdateBase.delete(walk),
+        objectId -> rollback(objectId, () -> refUpdateBase.update(walk)));
   }
 
   @Override
@@ -278,7 +291,10 @@
 
   @Override
   public Result forceUpdate() throws IOException {
-    return refUpdateValidator.executeRefUpdate(refUpdateBase, refUpdateBase::forceUpdate);
+    return refUpdateValidator.executeRefUpdate(
+        refUpdateBase,
+        refUpdateBase::forceUpdate,
+        objectId -> rollback(objectId, refUpdateBase::forceUpdate));
   }
 
   @Override
@@ -290,4 +306,11 @@
   public void setCheckConflicting(boolean check) {
     refUpdateBase.setCheckConflicting(check);
   }
+
+  private Result rollback(ObjectId objectId, NoParameterFunction<Result> updateFunction)
+      throws IOException {
+    refUpdateBase.setExpectedOldObjectId(refUpdateBase.getNewObjectId());
+    refUpdateBase.setNewObjectId(objectId);
+    return updateFunction.invoke();
+  }
 }
diff --git a/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/BatchRefUpdateValidatorTest.java b/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/BatchRefUpdateValidatorTest.java
index d90f45a..bb572a8 100644
--- a/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/BatchRefUpdateValidatorTest.java
+++ b/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/BatchRefUpdateValidatorTest.java
@@ -22,9 +22,12 @@
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import com.gerritforge.gerrit.globalrefdb.GlobalRefDbSystemError;
+import com.gerritforge.gerrit.globalrefdb.validation.RefUpdateValidator.OneParameterVoidFunction;
 import com.gerritforge.gerrit.globalrefdb.validation.dfsrefdb.DefaultSharedRefEnforcement;
 import com.gerritforge.gerrit.globalrefdb.validation.dfsrefdb.RefFixture;
 import com.gerritforge.gerrit.globalrefdb.validation.dfsrefdb.SharedRefEnforcement;
@@ -46,6 +49,7 @@
 import org.eclipse.jgit.revwalk.RevCommit;
 import org.eclipse.jgit.revwalk.RevWalk;
 import org.eclipse.jgit.transport.ReceiveCommand;
+import org.eclipse.jgit.transport.ReceiveCommand.Result;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -68,10 +72,12 @@
 
   @Mock SharedRefEnforcement tmpRefEnforcement;
   @Mock ProjectsFilter projectsFilter;
+  @Mock OneParameterVoidFunction<List<ReceiveCommand>> rollbackFunction;
 
   @Before
   public void setup() throws Exception {
     super.setUp();
+    doReturn(false).when(sharedRefDatabase).isUpToDate(any(), any());
     when(projectsFilter.matches(anyString())).thenReturn(true);
     gitRepoSetup();
   }
@@ -94,7 +100,7 @@
     BatchRefUpdateValidator BatchRefUpdateValidator = newDefaultValidator(A_TEST_PROJECT_NAME);
 
     BatchRefUpdateValidator.executeBatchUpdateWithValidation(
-        batchRefUpdate, () -> execute(batchRefUpdate));
+        batchRefUpdate, () -> execute(batchRefUpdate), this::defaultRollback);
 
     verify(sharedRefDatabase, never())
         .compareAndPut(any(Project.NameKey.class), any(Ref.class), any(ObjectId.class));
@@ -109,7 +115,7 @@
     BatchRefUpdateValidator BatchRefUpdateValidator = newDefaultValidator(A_TEST_PROJECT_NAME);
 
     BatchRefUpdateValidator.executeBatchUpdateWithValidation(
-        batchRefUpdate, () -> execute(batchRefUpdate));
+        batchRefUpdate, () -> execute(batchRefUpdate), this::defaultRollback);
 
     verify(sharedRefDatabase, never())
         .compareAndPut(A_TEST_PROJECT_NAME_KEY, newRef(DRAFT_COMMENT, A.getId()), B.getId());
@@ -132,7 +138,39 @@
         .isUpToDate(A_TEST_PROJECT_NAME_KEY, newRef(AN_OUT_OF_SYNC_REF, AN_OBJECT_ID_1));
 
     batchRefUpdateValidator.executeBatchUpdateWithValidation(
-        batchRefUpdate, () -> execute(batchRefUpdate));
+        batchRefUpdate, () -> execute(batchRefUpdate), rollbackFunction);
+
+    verify(rollbackFunction, never()).invoke(any());
+
+    final List<ReceiveCommand> commands = batchRefUpdate.getCommands();
+    assertThat(commands.size()).isEqualTo(1);
+    commands.forEach(
+        (command) -> assertThat(command.getResult()).isEqualTo(ReceiveCommand.Result.LOCK_FAILURE));
+  }
+
+  @Test
+  public void shouldRollbackRefUpdateWhenRefDbIsNotUpdated() throws Exception {
+    String REF_NAME = "refs/changes/01/1/meta";
+    BatchRefUpdate batchRefUpdate =
+        newBatchUpdate(singletonList(new ReceiveCommand(A, B, REF_NAME, UPDATE)));
+    BatchRefUpdateValidator batchRefUpdateValidator =
+        getRefValidatorForEnforcement(A_TEST_PROJECT_NAME, tmpRefEnforcement);
+
+    doReturn(SharedRefEnforcement.EnforcePolicy.REQUIRED)
+        .when(batchRefUpdateValidator.refEnforcement)
+        .getPolicy(A_TEST_PROJECT_NAME, REF_NAME);
+
+    doReturn(true).when(sharedRefDatabase).isUpToDate(any(), any());
+
+    lenient()
+        .doThrow(GlobalRefDbSystemError.class)
+        .when(sharedRefDatabase)
+        .compareAndPut(any(), any(), any());
+
+    batchRefUpdateValidator.executeBatchUpdateWithValidation(
+        batchRefUpdate, () -> execute(batchRefUpdate), rollbackFunction);
+
+    verify(rollbackFunction, times(1)).invoke(any());
 
     final List<ReceiveCommand> commands = batchRefUpdate.getCommands();
     assertThat(commands.size()).isEqualTo(1);
@@ -142,7 +180,7 @@
 
   @Test
   public void shouldNotUpdateSharedRefDbWhenProjectIsLocal() throws Exception {
-    when(projectsFilter.matches(anyString())).thenReturn(true);
+    when(projectsFilter.matches(anyString())).thenReturn(false);
 
     String AN_OUT_OF_SYNC_REF = "refs/changes/01/1/1";
     BatchRefUpdate batchRefUpdate =
@@ -151,7 +189,7 @@
         getRefValidatorForEnforcement(A_TEST_PROJECT_NAME, tmpRefEnforcement);
 
     batchRefUpdateValidator.executeBatchUpdateWithValidation(
-        batchRefUpdate, () -> execute(batchRefUpdate));
+        batchRefUpdate, () -> execute(batchRefUpdate), this::defaultRollback);
 
     verify(sharedRefDatabase, never())
         .compareAndPut(any(Project.NameKey.class), any(Ref.class), any(ObjectId.class));
@@ -185,9 +223,14 @@
   private BatchRefUpdate newBatchUpdate(List<ReceiveCommand> cmds) {
     BatchRefUpdate u = refdir.newBatchUpdate();
     u.addCommand(cmds);
+    cmds.forEach(c -> c.setResult(Result.OK));
     return u;
   }
 
+  private void defaultRollback(List<ReceiveCommand> cmds) throws IOException {
+    // do nothing
+  }
+
   @Override
   public String testBranch() {
     return "branch_" + nameRule.getMethodName();
diff --git a/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/RefUpdateValidatorTest.java b/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/RefUpdateValidatorTest.java
index 268eb2d..f937286 100644
--- a/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/RefUpdateValidatorTest.java
+++ b/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/RefUpdateValidatorTest.java
@@ -20,12 +20,14 @@
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import com.gerritforge.gerrit.globalrefdb.GlobalRefDbSystemError;
+import com.gerritforge.gerrit.globalrefdb.validation.RefUpdateValidator.OneParameterFunction;
 import com.gerritforge.gerrit.globalrefdb.validation.dfsrefdb.DefaultSharedRefEnforcement;
 import com.gerritforge.gerrit.globalrefdb.validation.dfsrefdb.RefFixture;
-import com.gerritforge.gerrit.globalrefdb.validation.dfsrefdb.SharedDbSplitBrainException;
 import com.google.common.collect.ImmutableSet;
 import com.google.gerrit.entities.Project;
 import org.eclipse.jgit.lib.ObjectId;
@@ -56,6 +58,10 @@
 
   @Mock ProjectsFilter projectsFilter;
 
+  @Mock OneParameterFunction<ObjectId, Result> rollbackFunction;
+
+  @Mock AutoCloseable lock;
+
   String refName;
   Ref oldUpdateRef;
   Ref newUpdateRef;
@@ -72,12 +78,12 @@
 
     doReturn(localRef).when(localRefDb).findRef(refName);
     doReturn(localRef).when(localRefDb).exactRef(refName);
-    doReturn(oldUpdateRef).when(refUpdate).getRef();
     doReturn(newUpdateRef.getObjectId()).when(refUpdate).getNewObjectId();
     doReturn(refName).when(refUpdate).getName();
     lenient().doReturn(oldUpdateRef.getObjectId()).when(refUpdate).getOldObjectId();
 
     doReturn(true).when(projectsFilter).matches(anyString());
+    doReturn(Result.FAST_FORWARD).when(rollbackFunction).invoke(any());
 
     refUpdateValidator = newRefUpdateValidator(sharedRefDb);
   }
@@ -87,7 +93,8 @@
     SharedRefDatabaseWrapper noopSharedRefDbWrapper = new SharedRefDatabaseWrapper(sharedRefLogger);
 
     Result result =
-        newRefUpdateValidator(noopSharedRefDbWrapper).executeRefUpdate(refUpdate, () -> Result.NEW);
+        newRefUpdateValidator(noopSharedRefDbWrapper)
+            .executeRefUpdate(refUpdate, () -> Result.NEW, this::defaultRollback);
     assertThat(result).isEqualTo(Result.NEW);
   }
 
@@ -106,7 +113,8 @@
         .when(sharedRefDb)
         .compareAndPut(A_TEST_PROJECT_NAME_KEY, localRef, newUpdateRef.getObjectId());
 
-    Result result = refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW);
+    Result result =
+        refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW, this::defaultRollback);
 
     assertThat(result).isEqualTo(Result.NEW);
   }
@@ -124,7 +132,8 @@
         .compareAndPut(A_TEST_PROJECT_NAME_KEY, localRef, ObjectId.zeroId());
     doReturn(localRef).doReturn(null).when(localRefDb).findRef(refName);
 
-    Result result = refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.FORCED);
+    Result result =
+        refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.FORCED, this::defaultRollback);
 
     assertThat(result).isEqualTo(Result.FORCED);
   }
@@ -143,7 +152,8 @@
         .compareAndPut(A_TEST_PROJECT_NAME_KEY, localNullRef, newUpdateRef.getObjectId());
     doReturn(localNullRef).doReturn(newUpdateRef).when(localRefDb).findRef(refName);
 
-    Result result = refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW);
+    Result result =
+        refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW, this::defaultRollback);
 
     assertThat(result).isEqualTo(Result.NEW);
   }
@@ -157,13 +167,14 @@
     doReturn(true).when(sharedRefDb).exists(A_TEST_PROJECT_NAME_KEY, refName);
     doReturn(false).when(sharedRefDb).isUpToDate(A_TEST_PROJECT_NAME_KEY, localRef);
 
-    Result result = refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW);
+    Result result =
+        refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW, this::defaultRollback);
 
     assertThat(result).isEqualTo(Result.LOCK_FAILURE);
   }
 
-  @Test(expected = SharedDbSplitBrainException.class)
-  public void shouldTrowSplitBrainWhenLocalRefDbIsUpToDateButFinalCompareAndPutIsFailing()
+  @Test
+  public void shouldRollbackWhenLocalRefDbIsUpToDateButFinalCompareAndPutIsFailing()
       throws Exception {
     lenient()
         .doReturn(false)
@@ -177,8 +188,13 @@
     doReturn(false)
         .when(sharedRefDb)
         .compareAndPut(A_TEST_PROJECT_NAME_KEY, localRef, newUpdateRef.getObjectId());
+    doReturn(lock).when(sharedRefDb).lockRef(any(), anyString());
 
-    refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW);
+    Result result =
+        refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW, rollbackFunction);
+
+    verify(rollbackFunction, times(1)).invoke(any());
+    assertThat(result).isEqualTo(Result.LOCK_FAILURE);
   }
 
   @Test
@@ -189,7 +205,9 @@
         .isUpToDate(any(Project.NameKey.class), any(Ref.class));
     doReturn(true).when(sharedRefDb).isUpToDate(A_TEST_PROJECT_NAME_KEY, localRef);
 
-    Result result = refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.LOCK_FAILURE);
+    Result result =
+        refUpdateValidator.executeRefUpdate(
+            refUpdate, () -> Result.LOCK_FAILURE, this::defaultRollback);
 
     verify(sharedRefDb, never())
         .compareAndPut(any(Project.NameKey.class), any(Ref.class), any(ObjectId.class));
@@ -197,15 +215,37 @@
   }
 
   @Test
+  public void shouldRollbackRefUpdateCompareAndPutIsFailing() throws Exception {
+    lenient()
+        .doReturn(false)
+        .when(sharedRefDb)
+        .isUpToDate(any(Project.NameKey.class), any(Ref.class));
+    doReturn(true).when(sharedRefDb).isUpToDate(A_TEST_PROJECT_NAME_KEY, localRef);
+
+    when(sharedRefDb.compareAndPut(any(Project.NameKey.class), any(Ref.class), any(ObjectId.class)))
+        .thenThrow(GlobalRefDbSystemError.class);
+    when(rollbackFunction.invoke(any())).thenReturn(Result.LOCK_FAILURE);
+
+    Result result =
+        refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW, rollbackFunction);
+
+    verify(rollbackFunction, times(1)).invoke(any());
+  }
+
+  @Test
   public void shouldNotUpdateSharedRefDbWhenProjectIsLocal() throws Exception {
     when(projectsFilter.matches(anyString())).thenReturn(false);
 
-    refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW);
+    refUpdateValidator.executeRefUpdate(refUpdate, () -> Result.NEW, this::defaultRollback);
 
     verify(sharedRefDb, never())
         .compareAndPut(any(Project.NameKey.class), any(Ref.class), any(ObjectId.class));
   }
 
+  private Result defaultRollback(ObjectId objectId) {
+    return Result.NO_CHANGE;
+  }
+
   private RefUpdateValidator newRefUpdateValidator(SharedRefDatabaseWrapper refDbWrapper) {
     return new RefUpdateValidator(
         refDbWrapper,
diff --git a/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbBatchRefUpdateTest.java b/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbBatchRefUpdateTest.java
index 576e92e..365f382 100644
--- a/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbBatchRefUpdateTest.java
+++ b/src/test/java/com/gerritforge/gerrit/globalrefdb/validation/SharedRefDbBatchRefUpdateTest.java
@@ -139,7 +139,7 @@
     sharedRefDbRefUpdate = getSharedRefDbBatchRefUpdateWithMockedValidator();
     doThrow(new IOException("IO Test Exception"))
         .when(batchRefUpdateValidator)
-        .executeBatchUpdateWithValidation(any(), any());
+        .executeBatchUpdateWithValidation(any(), any(), any());
 
     sharedRefDbRefUpdate.execute(revWalk, progressMonitor, EMPTY_LIST);
   }