Delegate batch execution to SQLDialect

To perform optimistic concurrency control, it is necessary to know
whether all rows of a statement batch have been successfully executed.
However, it is not needed to known which individual item of the batch
had been modified concurrently.

To determine this, on most databases, the update counts for the
individual rows of the batch need to be known, which is not always
possible as some databases may return a general success indicator
(SUCCESS_NO_INFO) only.

Other databases, like SAP MaxDB, allow to determine the total number of
rows updated by a batch execution directly by a vendor-specific API
call. To leverage this, this patch delegates the batch execution to the
SQLDialect. Also, the SQLDialect is equipped with a methods to indicate
whether batch execution allows to determine the total batch update
count.

Change-Id: I2692305e62dd8fa68dabcff81c2be1269a3a9689
Signed-off-by: Adrian Goerler <adrian.goerler@sap.com>
diff --git a/src/main/java/com/google/gwtorm/jdbc/JdbcAccess.java b/src/main/java/com/google/gwtorm/jdbc/JdbcAccess.java
index fa4dace..7e2a611 100644
--- a/src/main/java/com/google/gwtorm/jdbc/JdbcAccess.java
+++ b/src/main/java/com/google/gwtorm/jdbc/JdbcAccess.java
@@ -151,7 +151,7 @@
   @Override
   public void insert(final Iterable<T> instances) throws OrmException {
     try {
-      if (schema.getDialect().canDetermineIndividualBatchUpdateCounts()) {
+      if (schema.getDialect().canDetermineTotalBatchUpdateCount()) {
         insertAsBatch(instances);
       } else {
         insertIndividually(instances);
@@ -210,7 +210,7 @@
   @Override
   public void update(final Iterable<T> instances) throws OrmException {
     try {
-      if (schema.getDialect().canDetermineIndividualBatchUpdateCounts()) {
+      if (schema.getDialect().canDetermineTotalBatchUpdateCount()) {
         updateAsBatch(instances);
       } else {
         updateIndividually(instances);
@@ -373,7 +373,7 @@
   @Override
   public void delete(final Iterable<T> instances) throws OrmException {
     try {
-      if (schema.getDialect().canDetermineIndividualBatchUpdateCounts()) {
+      if (schema.getDialect().canDetermineTotalBatchUpdateCount()) {
         deleteAsBatch(instances);
       } else {
         deleteIndividually(instances);
@@ -429,20 +429,15 @@
     }
   }
 
-  private static void execute(final PreparedStatement ps, final int cnt)
+  private void execute(final PreparedStatement ps, final int cnt)
       throws SQLException, OrmConcurrencyException {
     if (cnt == 0) {
       return;
     }
 
-    final int[] states = ps.executeBatch();
-    if (states == null) {
-      throw new SQLException("No rows affected; expected " + cnt + " rows");
-    }
-    for (int i = 0; i < cnt; i++) {
-      if (states.length <= i || states[i] != 1) {
+    final int numberOfRowsUpdated = schema.getDialect().executeBatch(ps);
+    if (numberOfRowsUpdated != cnt) {
         throw new OrmConcurrencyException();
-      }
     }
   }
 
diff --git a/src/main/java/com/google/gwtorm/schema/sql/SqlDialect.java b/src/main/java/com/google/gwtorm/schema/sql/SqlDialect.java
index afc1d25..d75d233 100644
--- a/src/main/java/com/google/gwtorm/schema/sql/SqlDialect.java
+++ b/src/main/java/com/google/gwtorm/schema/sql/SqlDialect.java
@@ -23,6 +23,7 @@
 
 import java.sql.Connection;
 import java.sql.DatabaseMetaData;
+import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.sql.Statement;
@@ -340,4 +341,44 @@
     return true;
 
   }
+
+  /**
+   * Can the total number of rows updated by the PreparedStatement.executeBatch
+   * be determined exactly by the SQLDialect.executeBatch method?
+   *
+   * @return <code>true</code> if the SQlDialect.executeBatch method can exactly
+   *         determine the total number of rows updated by a batch;
+   *         <code>false</code> otherwise
+   * @see #executeBatch(PreparedStatement)
+   */
+  public boolean canDetermineTotalBatchUpdateCount() {
+    return true;
+  }
+
+  /**
+   * Executes a prepared statement batch and returns the total number of rows
+   * successfully updated or inserted. This method is intended to be overridden.
+   *
+   * If the canDetermineTotalBatchUpdateCount returns false for a particular
+   * SQLDialect, this method should throw an UnsupportedOperationException.
+   *
+   * @param ps the prepared statement with the batch to be executed
+   * @return the total number of rows affected
+   * @see #canDetermineIndividualBatchUpdateCounts()
+   */
+  public int executeBatch(PreparedStatement ps) throws SQLException {
+    final int[] updateCounts = ps.executeBatch();
+    if (updateCounts == null) {
+      throw new SQLException("No rows affected");
+    }
+    int totalUpdateCount = 0;
+    for (int i = 0; i < updateCounts.length; i++) {
+      int updateCount = updateCounts[i];
+      if (updateCount > 0) {
+        totalUpdateCount += updateCount;
+      }
+    }
+    return totalUpdateCount;
+  }
+
 }
diff --git a/src/test/java/com/google/gwtorm/jdbc/AbstractTestJdbcAccess.java b/src/test/java/com/google/gwtorm/jdbc/AbstractTestJdbcAccess.java
index 76319d6..c3893cd 100644
--- a/src/test/java/com/google/gwtorm/jdbc/AbstractTestJdbcAccess.java
+++ b/src/test/java/com/google/gwtorm/jdbc/AbstractTestJdbcAccess.java
@@ -177,8 +177,8 @@
     }
   }
 
-  protected static void assertUsedBatchingOnly(PreparedStatement ps,
-      int ...ids) throws SQLException {
+  protected static void assertUsedBatchingOnly(PreparedStatement ps, int... ids)
+      throws SQLException {
     verify(ps, times(ids.length)).addBatch();
     verify(ps).executeBatch();
     verify(ps, never()).executeUpdate();
@@ -186,7 +186,7 @@
   }
 
   protected static void assertUsedNonBatchingOnly(PreparedStatement ps,
-      int ... ids) throws SQLException {
+      int... ids) throws SQLException {
     verify(ps, never()).addBatch();
     verify(ps, never()).executeBatch();
     verify(ps, times(ids.length)).executeUpdate();
@@ -197,13 +197,16 @@
     verifyZeroInteractions(insert);
   }
 
-  protected abstract void assertCorrectUpdating(PreparedStatement ps, int ... ids)
-      throws SQLException;
+  protected abstract void assertCorrectUpdating(PreparedStatement ps,
+      int... ids) throws SQLException;
 
-  private static void assertExpectedIdsUsed(PreparedStatement statement,
-      int... ids) throws SQLException {
+  protected abstract void assertCorrectAttempting(PreparedStatement ps,
+      int... ids) throws SQLException;
+
+  private static void assertExpectedIdsUsed(PreparedStatement ps, int... ids)
+      throws SQLException {
     for (int id : ids) {
-      verify(statement).setInt(1, id);
+      verify(ps).setInt(1, id);
     }
   }
 
@@ -292,7 +295,7 @@
 
     classUnderTest.upsert(oneRow);
 
-    assertCorrectUpdating(update, 1);
+    assertCorrectAttempting(update, 1);
     assertNotUsed(insert);
   }
 
@@ -309,7 +312,7 @@
       assertSame(e.getCause(), exception);
     }
 
-    assertCorrectUpdating(update, 1);
+    assertCorrectAttempting(update, 1);
   }
 
   @Test
@@ -319,7 +322,7 @@
 
     classUnderTest.upsert(oneRow);
 
-    assertCorrectUpdating(update, 1);
+    assertCorrectAttempting(update, 1);
     assertCorrectUpdating(insert, 1);
   }
 
@@ -331,7 +334,7 @@
 
     classUnderTest.upsert(twoRows);
 
-    assertCorrectUpdating(update, 1, 2);
+    assertCorrectAttempting(update, 1, 2);
     assertCorrectUpdating(insert, 1, 2);
   }
 
@@ -342,7 +345,7 @@
 
     classUnderTest.upsert(twoRows);
 
-    assertCorrectUpdating(update, 1, 2);
+    assertCorrectAttempting(update, 1, 2);
     assertCorrectUpdating(insert, 1, 2);
   }
 
@@ -353,7 +356,7 @@
 
     classUnderTest.upsert(twoRows);
 
-    assertCorrectUpdating(update, 1, 2);
+    assertCorrectAttempting(update, 1, 2);
     assertNotUsed(insert);
   }
 
@@ -364,7 +367,7 @@
 
     classUnderTest.upsert(twoRows);
 
-    assertCorrectUpdating(update, 1, 2);
+    assertCorrectAttempting(update, 1, 2);
     assertCorrectUpdating(insert, 2);
   }
 
@@ -376,7 +379,7 @@
 
     classUnderTest.upsert(twoRows);
 
-    assertCorrectUpdating(update, 1, 2);
+    assertCorrectAttempting(update, 1, 2);
     assertCorrectUpdating(insert, 1, 2);
   }
 
@@ -387,7 +390,7 @@
 
     classUnderTest.upsert(twoRows);
 
-    assertCorrectUpdating(update, 1, 2);
+    assertCorrectAttempting(update, 1, 2);
     assertCorrectUpdating(insert, 1);
   }
 
diff --git a/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessBatching.java b/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessBatching.java
index eebd3a7..0d687d6 100644
--- a/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessBatching.java
+++ b/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessBatching.java
@@ -40,6 +40,12 @@
   }
 
   @Override
+  protected void assertCorrectAttempting(PreparedStatement ps,
+      int ... ids) throws SQLException {
+    assertUsedBatchingOnly(ps, ids);
+  }
+
+  @Override
   protected SqlDialect createDialect() {
     return mock(SqlDialect.class, CALLS_REAL_METHODS);
   }
diff --git a/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessNonBatching.java b/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessNonBatching.java
index 165d4c3..7ab5b80 100644
--- a/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessNonBatching.java
+++ b/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessNonBatching.java
@@ -36,8 +36,14 @@
   }
 
   @Override
-  protected void assertCorrectUpdating(PreparedStatement ps,
-      int ... ids) throws SQLException {
+  protected void assertCorrectUpdating(PreparedStatement ps, int... ids)
+      throws SQLException {
+    assertUsedNonBatchingOnly(ps, ids);
+  }
+
+  @Override
+  protected void assertCorrectAttempting(PreparedStatement ps, int... ids)
+      throws SQLException {
     assertUsedNonBatchingOnly(ps, ids);
   }
 
diff --git a/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessTotalUpdateCount.java b/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessTotalUpdateCount.java
new file mode 100644
index 0000000..9e329aa
--- /dev/null
+++ b/src/test/java/com/google/gwtorm/jdbc/TestJdbcAccessTotalUpdateCount.java
@@ -0,0 +1,80 @@
+// Copyright (C) 2011 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.gwtorm.jdbc;
+
+import static java.lang.Boolean.FALSE;
+import static java.lang.Boolean.TRUE;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.gwtorm.schema.sql.SqlDialect;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.sql.PreparedStatement;
+import java.sql.SQLException;
+
+@RunWith(Parameterized.class)
+public class TestJdbcAccessTotalUpdateCount extends AbstractTestJdbcAccess {
+
+  public TestJdbcAccessTotalUpdateCount(IterableProvider<Data> dataProvider)
+      throws SQLException {
+    super(dataProvider);
+  }
+
+  @Override
+  protected void assertCorrectUpdating(PreparedStatement ps, int... ids)
+      throws SQLException {
+    verify(dialect).executeBatch(ps);
+  }
+
+  @Override
+  protected void assertCorrectAttempting(PreparedStatement ps, int... ids)
+      throws SQLException {
+    assertUsedNonBatchingOnly(ps, ids);
+  }
+
+  @Override
+  protected SqlDialect createDialect() throws SQLException {
+    SqlDialect dialect = mock(SqlDialect.class);
+    when(dialect.canDetermineIndividualBatchUpdateCounts()).thenReturn(FALSE);
+    when(dialect.canDetermineTotalBatchUpdateCount()).thenReturn(TRUE);
+    when(dialect.executeBatch(any(PreparedStatement.class))).thenAnswer(
+        new Answer<Integer>() {
+
+          @Override
+          public Integer answer(InvocationOnMock invocation) throws Throwable {
+            if (sqlException != null) {
+              throw sqlException;
+            }
+            if (totalUpdateCount == null) {
+              throw new IllegalStateException("totalCount is not set");
+            }
+            return totalUpdateCount;
+          }
+        });
+    when(
+        dialect.convertError(any(String.class), any(String.class),
+            any(SQLException.class))).thenCallRealMethod();
+
+    return dialect;
+  }
+
+}