protobuf: Support encoding/decoding from any stream

When encoding to a stream we prefix the message with a varint
listing its length.  When decoding we read that varint, then use an
input stream wrapper which limits the number of bytes we consume
to that length.  This way the stream is left unaffected for the
next data segment, which might be something else.

Callers can use these to implement java.io.Serializable using their
own ProtobufCodec rather than relying on Java serialization.

Change-Id: I17fc5863b525fad22d14653831024eaf5b4ff9da
Signed-off-by: Shawn O. Pearce <sop@google.com>
diff --git a/src/main/java/com/google/gwtorm/protobuf/CappedInputStream.java b/src/main/java/com/google/gwtorm/protobuf/CappedInputStream.java
new file mode 100644
index 0000000..5bb3bea
--- /dev/null
+++ b/src/main/java/com/google/gwtorm/protobuf/CappedInputStream.java
@@ -0,0 +1,66 @@
+// Copyright 2010 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.gwtorm.protobuf;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+class CappedInputStream extends InputStream {
+  private final InputStream src;
+  private int remaining;
+
+  CappedInputStream(InputStream src, int limit) {
+    this.src = src;
+    this.remaining = limit;
+  }
+
+  @Override
+  public int read() throws IOException {
+    if (0 < remaining) {
+      int r = src.read();
+      if (r < 0) {
+        remaining = 0;
+      } else {
+        remaining--;
+      }
+      return r;
+    } else {
+      return -1;
+    }
+  }
+
+  @Override
+  public int read(byte[] b, int off, int len) throws IOException {
+    if (len == 0) {
+      return 0;
+    } else if (0 < remaining) {
+      int n = src.read(b, off, Math.min(len, remaining));
+      if (n < 0) {
+        remaining = 0;
+      } else {
+        remaining -= n;
+      }
+      return n;
+    } else {
+      return -1;
+    }
+  }
+
+  @Override
+  public void close() throws IOException {
+    remaining = 0;
+    src.close();
+  }
+}
diff --git a/src/main/java/com/google/gwtorm/protobuf/ProtobufCodec.java b/src/main/java/com/google/gwtorm/protobuf/ProtobufCodec.java
index e024be3..43202b1 100644
--- a/src/main/java/com/google/gwtorm/protobuf/ProtobufCodec.java
+++ b/src/main/java/com/google/gwtorm/protobuf/ProtobufCodec.java
@@ -18,8 +18,11 @@
 import com.google.protobuf.ByteString;
 import com.google.protobuf.CodedInputStream;
 import com.google.protobuf.CodedOutputStream;
+import com.google.protobuf.InvalidProtocolBufferException;
 
 import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
 import java.nio.ByteBuffer;
 
 /**
@@ -91,6 +94,22 @@
     }
   }
 
+  /**
+   * Encodes the object, prefixed by its encoded length.
+   * <p>
+   * The length is encoded as a raw varint with no tag.
+   *
+   * @param obj the object to encode.
+   * @param out stream that will receive the object's data.
+   * @throws IOException the stream failed to write data.
+   */
+  public void encodeWithSize(T obj, OutputStream out) throws IOException {
+    CodedOutputStream cos = CodedOutputStream.newInstance(out);
+    cos.writeRawVarint32(sizeof(obj));
+    encode(obj, cos);
+    cos.flush();
+  }
+
   private static ByteBufferOutputStream newStream(ByteBuffer buf) {
     return new ByteBufferOutputStream(buf);
   }
@@ -116,45 +135,30 @@
 
   /** Decode a byte string into an object instance. */
   public T decode(ByteString buf) {
-    try {
-      return decode(buf.newCodedInput());
-    } catch (IOException err) {
-      throw new RuntimeException("Cannot decode message", err);
-    }
+    T obj = newInstance();
+    mergeFrom(buf, obj);
+    return obj;
   }
 
   /** Decode a byte array into an object instance. */
   public T decode(byte[] data) {
-    return decode(data, 0, data.length);
+    T obj = newInstance();
+    mergeFrom(data, obj);
+    return obj;
   }
 
   /** Decode a byte array into an object instance. */
   public T decode(byte[] data, int offset, int length) {
-    try {
-      return decode(CodedInputStream.newInstance(data, offset, length));
-    } catch (IOException err) {
-      throw new RuntimeException("Cannot decode message", err);
-    }
+    T obj = newInstance();
+    mergeFrom(data, offset, length, obj);
+    return obj;
   }
 
   /** Decode a byte buffer into an object instance. */
   public T decode(ByteBuffer buf) {
-    if (buf.hasArray()) {
-      CodedInputStream in = CodedInputStream.newInstance( //
-          buf.array(), //
-          buf.position(), //
-          buf.remaining());
-      T obj;
-      try {
-        obj = decode(in);
-      } catch (IOException err) {
-        throw new RuntimeException("Cannot decode message", err);
-      }
-      buf.position(buf.position() + in.getTotalBytesRead());
-      return obj;
-    } else {
-      return decode(ByteString.copyFrom(buf));
-    }
+    T obj = newInstance();
+    mergeFrom(buf, obj);
+    return obj;
   }
 
   /**
@@ -168,10 +172,101 @@
     return obj;
   }
 
+  /** Decode an object that is prefixed by its encoded length. */
+  public T decodeWithSize(InputStream in) throws IOException {
+    T obj = newInstance();
+    mergeFromWithSize(in, obj);
+    return obj;
+  }
+
+  /** Decode a byte string into an existing object instance. */
+  public void mergeFrom(ByteString buf, T obj) {
+    try {
+      mergeFrom(buf.newCodedInput(), obj);
+    } catch (IOException err) {
+      throw new RuntimeException("Cannot decode message", err);
+    }
+  }
+
+  /** Decode a byte array into an existing object instance. */
+  public void mergeFrom(byte[] data, T obj) {
+    mergeFrom(data, 0, data.length, obj);
+  }
+
+  /** Decode a byte array into an existing object instance. */
+  public void mergeFrom(byte[] data, int offset, int length, T obj) {
+    try {
+      mergeFrom(CodedInputStream.newInstance(data, offset, length), obj);
+    } catch (IOException err) {
+      throw new RuntimeException("Cannot decode message", err);
+    }
+  }
+
+  /** Decode a byte buffer into an existing object instance. */
+  public void mergeFrom(ByteBuffer buf, T obj) {
+    if (buf.hasArray()) {
+      CodedInputStream in = CodedInputStream.newInstance( //
+          buf.array(), //
+          buf.position(), //
+          buf.remaining());
+      try {
+        mergeFrom(in, obj);
+      } catch (IOException err) {
+        throw new RuntimeException("Cannot decode message", err);
+      }
+      buf.position(buf.position() + in.getTotalBytesRead());
+    } else {
+      mergeFrom(ByteString.copyFrom(buf), obj);
+    }
+  }
+
+  /** Decode an object that is prefixed by its encoded length. */
+  public void mergeFromWithSize(InputStream in, T obj) throws IOException {
+    int sz = readRawVarint32(in);
+    mergeFrom(CodedInputStream.newInstance(new CappedInputStream(in, sz)), obj);
+  }
+
   /**
    * Decode an input stream into an existing object instance.
    *
    * @throws IOException the underlying stream cannot be read.
    */
   public abstract void mergeFrom(CodedInputStream in, T obj) throws IOException;
+
+  private static int readRawVarint32(InputStream in) throws IOException {
+    int b = in.read();
+    if (b == -1) {
+      throw new InvalidProtocolBufferException("Truncated input");
+    }
+
+    if ((b & 0x80) == 0) {
+      return b;
+    }
+
+    int result = b & 0x7f;
+    int offset = 7;
+    for (; offset < 32; offset += 7) {
+      b = in.read();
+      if (b == -1) {
+        throw new InvalidProtocolBufferException("Truncated input");
+      }
+      result |= (b & 0x7f) << offset;
+      if ((b & 0x80) == 0) {
+        return result;
+      }
+    }
+
+    // Keep reading up to 64 bits.
+    for (; offset < 64; offset += 7) {
+      b = in.read();
+      if (b == -1) {
+        throw new InvalidProtocolBufferException("Truncated input");
+      }
+      if ((b & 0x80) == 0) {
+        return result;
+      }
+    }
+
+    throw new InvalidProtocolBufferException("Malformed varint");
+  }
 }
diff --git a/src/test/java/com/google/gwtorm/protobuf/ProtobufEncoderTest.java b/src/test/java/com/google/gwtorm/protobuf/ProtobufEncoderTest.java
index 21007f6..7c9ce10 100644
--- a/src/test/java/com/google/gwtorm/protobuf/ProtobufEncoderTest.java
+++ b/src/test/java/com/google/gwtorm/protobuf/ProtobufEncoderTest.java
@@ -22,6 +22,8 @@
 
 import junit.framework.TestCase;
 
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.UnsupportedEncodingException;
 import java.nio.ByteBuffer;
@@ -215,6 +217,26 @@
     assertEquals(asString(exp), asString(act));
   }
 
+  public void testEncodeToStream()throws IOException {
+    ProtobufCodec<ThingWithEnum> e = CodecFactory.encoder(ThingWithEnum.class);
+
+    ThingWithEnum thing = new ThingWithEnum();
+    thing.type = ThingWithEnum.Type.B;
+
+    ByteArrayOutputStream out = new ByteArrayOutputStream();
+    e.encodeWithSize(thing, out);
+    byte[] exp = {0x02, 0x08, 0x01};
+    assertEquals(asString(exp), asString(out.toByteArray()));
+
+    byte[] exp2 = {0x02, 0x08, 0x01, '\n'};
+    ByteArrayInputStream in = new ByteArrayInputStream(exp2);
+    ThingWithEnum other = e.decodeWithSize(in);
+    assertEquals('\n', in.read());
+    assertEquals(-1, in.read());
+    assertNotNull(other.type);
+    assertSame(thing.type, other.type);
+  }
+
   private static String asString(byte[] bin)
       throws UnsupportedEncodingException {
     return new String(bin, "ISO-8859-1");