Implement batch fetching by primary key via SQL IN operator

If an entity's primary key is a single column and is a Java object
(not a primitive) we can easily override the get(Iterable<K>) so it
uses a SQL IN to fetch all requested rows in one query, rather than
making a round-trip per request.

This optimization works only for primary key queries, because I'm
being lazy and not implementing any other form.  It also only is
good on databases which don't suffer from high query plan costs.
For example on Oracle the per-statement optimization cost is high
and this IN query causes a different SQL statement to be created
for each unique number of keys requested.

Signed-off-by: Shawn O. Pearce <sop@google.com>
diff --git a/src/com/google/gwtorm/jdbc/JdbcAccess.java b/src/com/google/gwtorm/jdbc/JdbcAccess.java
index c651f37..8bb668b 100644
--- a/src/com/google/gwtorm/jdbc/JdbcAccess.java
+++ b/src/com/google/gwtorm/jdbc/JdbcAccess.java
@@ -24,6 +24,8 @@
 import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
 
 /** Internal base class for implementations of {@link Access}. */
 public abstract class JdbcAccess<T, K extends Key<?>> extends
@@ -34,6 +36,45 @@
     schema = s;
   }
 
+  @Override
+  public final com.google.gwtorm.client.ResultSet<T> get(final Iterable<K> keys)
+      throws OrmException {
+    final Collection<K> keySet;
+    if (keys instanceof Collection) {
+      keySet = (Collection<K>) keys;
+    } else {
+      keySet = new ArrayList<K>();
+      for (final K k : keys) {
+        keySet.add(k);
+      }
+    }
+
+    switch (keySet.size()) {
+      case 0:
+        // Nothing requested, nothing to return.
+        //
+        return new ListResultSet<T>(Collections.<T> emptyList());
+
+      case 1: {
+        // Only one key requested, use a faster equality lookup.
+        //
+        final T entity = get(keySet.iterator().next());
+        if (entity != null) {
+          return new ListResultSet<T>(Collections.singletonList(entity));
+        }
+        return new ListResultSet<T>(Collections.<T> emptyList());
+      }
+
+      default:
+        return getBySqlIn(keySet);
+    }
+  }
+
+  protected com.google.gwtorm.client.ResultSet<T> getBySqlIn(
+      final Collection<K> keys) throws OrmException {
+    return super.get(keys);
+  }
+
   protected PreparedStatement prepareStatement(final String sql)
       throws OrmException {
     try {
@@ -43,6 +84,22 @@
     }
   }
 
+  protected PreparedStatement prepareBySqlIn(final String sql,
+      final Collection<K> keys) throws OrmException {
+    final int n = keys.size();
+    final StringBuilder buf = new StringBuilder(sql.length() + n << 1 + 1);
+    buf.append(sql);
+    buf.append('(');
+    for (int i = 0; i < n; i++) {
+      if (i > 0) {
+        buf.append(',');
+      }
+      buf.append('?');
+    }
+    buf.append(')');
+    return prepareStatement(buf.toString());
+  }
+
   protected T queryOne(final PreparedStatement ps) throws OrmException {
     try {
       try {
diff --git a/src/com/google/gwtorm/jdbc/gen/AccessGen.java b/src/com/google/gwtorm/jdbc/gen/AccessGen.java
index f4001d9..908fe48 100644
--- a/src/com/google/gwtorm/jdbc/gen/AccessGen.java
+++ b/src/com/google/gwtorm/jdbc/gen/AccessGen.java
@@ -37,6 +37,7 @@
 import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Iterator;
 import java.util.List;
 
@@ -108,6 +109,11 @@
 
     if (model.getPrimaryKey() != null) {
       implementKeyQuery(model.getPrimaryKey());
+      if ((model.getPrimaryKey().getField().isNested() || !model
+          .getPrimaryKey().getField().getPrimitiveType().isPrimitive())
+          && model.getPrimaryKey().getAllLeafColumns().size() == 1) {
+        overrideGetMany();
+      }
     }
     for (final KeyModel key : model.getSecondaryKeys()) {
       implementKeyQuery(key);
@@ -492,6 +498,106 @@
     mv.visitEnd();
   }
 
+  private void overrideGetMany() {
+    final KeyModel pk = model.getPrimaryKey();
+    final StringBuilder query = new StringBuilder();
+    query.append(model.getSelectSql(dialect, REL_ALIAS));
+    query.append(" WHERE ");
+    final ColumnModel pkcol = pk.getAllLeafColumns().iterator().next();
+    query.append(REL_ALIAS);
+    query.append('.');
+    query.append(pkcol.getColumnName());
+    query.append(" IN");
+
+    final MethodVisitor mv =
+        cw.visitMethod(ACC_PUBLIC | ACC_FINAL, "getBySqlIn", Type
+            .getMethodDescriptor(Type
+                .getType(com.google.gwtorm.client.ResultSet.class),
+                new Type[] {Type.getType(Collection.class)}), null,
+            new String[] {Type.getType(OrmException.class).getInternalName()});
+    mv.visitCode();
+
+    final int keyset = 1;
+    final int psvar = 2;
+    final int itrvar = 3;
+    final int colvar = 4;
+    final int keyvar = 5;
+
+    mv.visitVarInsn(ALOAD, 0);
+    mv.visitLdcInsn(query.toString());
+    mv.visitVarInsn(ALOAD, keyset);
+    mv.visitMethodInsn(INVOKEVIRTUAL, superTypeName, "prepareBySqlIn", Type
+        .getMethodDescriptor(Type.getType(PreparedStatement.class), new Type[] {
+            Type.getType(String.class), Type.getType(Collection.class)}));
+    mv.visitVarInsn(ASTORE, psvar);
+
+    mv.visitVarInsn(ALOAD, keyset);
+    mv.visitMethodInsn(INVOKEINTERFACE, Type.getInternalName(Collection.class),
+        "iterator", Type.getMethodDescriptor(Type.getType(Iterator.class),
+            new Type[] {}));
+    mv.visitVarInsn(ASTORE, itrvar);
+
+    mv.visitInsn(ICONST_1);
+    mv.visitVarInsn(ISTORE, colvar);
+
+    final Label endbind = new Label();
+    final Label again = new Label();
+    mv.visitLabel(again);
+    mv.visitVarInsn(ALOAD, itrvar);
+    mv.visitMethodInsn(INVOKEINTERFACE, Type.getInternalName(Iterator.class),
+        "hasNext", Type.getMethodDescriptor(Type.BOOLEAN_TYPE, new Type[] {}));
+    mv.visitJumpInsn(IFEQ, endbind);
+
+    mv.visitVarInsn(ALOAD, itrvar);
+    mv.visitMethodInsn(INVOKEINTERFACE, Type.getInternalName(Iterator.class),
+        "next", Type.getMethodDescriptor(Type.getType(Object.class),
+            new Type[] {}));
+    mv.visitTypeInsn(CHECKCAST, CodeGenSupport.toType(pk.getField())
+        .getInternalName());
+    mv.visitVarInsn(ASTORE, keyvar);
+
+    final CodeGenSupport cgs = new CodeGenSupport(mv) {
+      @Override
+      public void pushSqlHandle() {
+        mv.visitVarInsn(ALOAD, psvar);
+      }
+
+      @Override
+      public void pushFieldValue() {
+        appendGetField(getFieldReference());
+      }
+
+      @Override
+      public void pushColumnIndex() {
+        mv.visitVarInsn(ILOAD, colvar);
+      }
+
+      @Override
+      protected void appendGetField(final ColumnModel c) {
+        if (c.getParent() == null) {
+          mv.visitVarInsn(ALOAD, keyvar);
+        } else {
+          super.appendGetField(c);
+        }
+      }
+    };
+
+    cgs.setFieldReference(pkcol);
+    dialect.getSqlTypeInfo(pkcol).generatePreparedStatementSet(cgs);
+    mv.visitIincInsn(colvar, 1);
+    mv.visitJumpInsn(GOTO, again);
+
+    mv.visitLabel(endbind);
+    mv.visitVarInsn(ALOAD, 0);
+    mv.visitVarInsn(ALOAD, psvar);
+    mv.visitMethodInsn(INVOKEVIRTUAL, superTypeName, "queryList", Type
+        .getMethodDescriptor(Type.getType(ListResultSet.class),
+            new Type[] {Type.getType(PreparedStatement.class)}));
+    mv.visitInsn(ARETURN);
+    mv.visitMaxs(-1, -1);
+    mv.visitEnd();
+  }
+
   private void implementQuery(final QueryModel info) {
     final List<ColumnModel> pCols = info.getParameters();
     final boolean hasLimitParam = info.hasLimitParameter();