From e33f3fcac962603654b855f6cfc1916b9d5c0a97 Mon Sep 17 00:00:00 2001 From: Guillaume Nodet Date: Mon, 26 Oct 2020 22:37:31 +0100 Subject: [PATCH] Fix serialization of long (> 64k) strings, fixes #155 --- .../org/jboss/fuse/mvnd/common/Message.java | 146 +++++++++++++++--- .../jboss/fuse/mvnd/common/MessageTest.java | 51 ++++++ 2 files changed, 173 insertions(+), 24 deletions(-) create mode 100644 common/src/test/java/org/jboss/fuse/mvnd/common/MessageTest.java diff --git a/common/src/main/java/org/jboss/fuse/mvnd/common/Message.java b/common/src/main/java/org/jboss/fuse/mvnd/common/Message.java index 9c208abd..a494997f 100644 --- a/common/src/main/java/org/jboss/fuse/mvnd/common/Message.java +++ b/common/src/main/java/org/jboss/fuse/mvnd/common/Message.java @@ -17,9 +17,11 @@ package org.jboss.fuse.mvnd.common; import java.io.DataInputStream; import java.io.DataOutputStream; +import java.io.EOFException; import java.io.IOException; import java.io.PrintWriter; import java.io.StringWriter; +import java.io.UTFDataFormatException; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; @@ -66,15 +68,15 @@ public abstract class Message { static void writeStringList(DataOutputStream output, List value) throws IOException { output.writeInt(value.size()); for (String v : value) { - output.writeUTF(v); + writeUTF(output, v); } } static void writeStringMap(DataOutputStream output, Map value) throws IOException { output.writeInt(value.size()); for (Map.Entry e : value.entrySet()) { - output.writeUTF(e.getKey()); - output.writeUTF(e.getValue()); + writeUTF(output, e.getKey()); + writeUTF(output, e.getValue()); } } @@ -82,7 +84,7 @@ public abstract class Message { ArrayList l = new ArrayList<>(); int nb = input.readInt(); for (int i = 0; i < nb; i++) { - l.add(input.readUTF()); + l.add(readUTF(input)); } return l; } @@ -91,13 +93,109 @@ public abstract class Message { LinkedHashMap m = new LinkedHashMap<>(); int nb = input.readInt(); for (int i = 0; i < nb; i++) { - String k = input.readUTF(); - String v = input.readUTF(); + String k = readUTF(input); + String v = readUTF(input); m.put(k, v); } return m; } + private static final String INVALID_BYTE = "Invalid byte"; + private static final int UTF_BUFS_CHAR_CNT = 256; + private static final int UTF_BUFS_BYTE_CNT = UTF_BUFS_CHAR_CNT * 3; + private static final ThreadLocal BUF_TLS = ThreadLocal.withInitial(() -> new byte[UTF_BUFS_BYTE_CNT]); + + static String readUTF(DataInputStream input) throws IOException { + byte[] byteBuf = BUF_TLS.get(); + int len = input.readInt(); + final char[] chars = new char[len]; + int i = 0, cnt = 0, charIdx = 0; + while (charIdx < len) { + if (i == cnt) { + cnt = input.read(byteBuf, 0, Math.min(UTF_BUFS_BYTE_CNT, len - charIdx)); + if (cnt < 0) { + throw new EOFException(); + } + i = 0; + } + final int a = byteBuf[i++] & 0xff; + if (a < 0x80) { + // low bit clear + chars[charIdx++] = (char) a; + } else if (a < 0xc0) { + throw new UTFDataFormatException(INVALID_BYTE); + } else if (a < 0xe0) { + if (i == cnt) { + cnt = input.read(byteBuf, 0, Math.min(UTF_BUFS_BYTE_CNT, len - charIdx)); + if (cnt < 0) { + throw new EOFException(); + } + i = 0; + } + final int b = byteBuf[i++] & 0xff; + if ((b & 0xc0) != 0x80) { + throw new UTFDataFormatException(INVALID_BYTE); + } + chars[charIdx++] = (char) ((a & 0x1f) << 6 | b & 0x3f); + } else if (a < 0xf0) { + if (i == cnt) { + cnt = input.read(byteBuf, 0, Math.min(UTF_BUFS_BYTE_CNT, len - charIdx)); + if (cnt < 0) { + throw new EOFException(); + } + i = 0; + } + final int b = byteBuf[i++] & 0xff; + if ((b & 0xc0) != 0x80) { + throw new UTFDataFormatException(INVALID_BYTE); + } + if (i == cnt) { + cnt = input.read(byteBuf, 0, Math.min(UTF_BUFS_BYTE_CNT, len - charIdx)); + if (cnt < 0) { + throw new EOFException(); + } + i = 0; + } + final int c = byteBuf[i++] & 0xff; + if ((c & 0xc0) != 0x80) { + throw new UTFDataFormatException(INVALID_BYTE); + } + chars[charIdx++] = (char) ((a & 0x0f) << 12 | (b & 0x3f) << 6 | c & 0x3f); + } else { + throw new UTFDataFormatException(INVALID_BYTE); + } + } + return String.valueOf(chars); + } + + static void writeUTF(DataOutputStream output, String s) throws IOException { + byte[] byteBuf = BUF_TLS.get(); + final int length = s.length(); + output.writeInt(length); + int strIdx = 0; + int byteIdx = 0; + while (strIdx < length) { + final char c = s.charAt(strIdx++); + if (c > 0 && c <= 0x7f) { + byteBuf[byteIdx++] = (byte) c; + } else if (c <= 0x07ff) { + byteBuf[byteIdx++] = (byte) (0xc0 | 0x1f & c >> 6); + byteBuf[byteIdx++] = (byte) (0x80 | 0x3f & c); + } else { + byteBuf[byteIdx++] = (byte) (0xe0 | 0x0f & c >> 12); + byteBuf[byteIdx++] = (byte) (0x80 | 0x3f & c >> 6); + byteBuf[byteIdx++] = (byte) (0x80 | 0x3f & c); + } + if (byteIdx > UTF_BUFS_BYTE_CNT - 4) { + output.write(byteBuf, 0, byteIdx); + byteIdx = 0; + } + } + if (byteIdx > 0) { + output.write(byteBuf, 0, byteIdx); + } + } + public static class BuildRequest extends Message { final List args; final String workingDir; @@ -106,8 +204,8 @@ public abstract class Message { public static Message read(DataInputStream input) throws IOException { List args = readStringList(input); - String workingDir = input.readUTF(); - String projectDir = input.readUTF(); + String workingDir = readUTF(input); + String projectDir = readUTF(input); Map env = readStringMap(input); return new BuildRequest(args, workingDir, projectDir, env); } @@ -148,8 +246,8 @@ public abstract class Message { public void write(DataOutputStream output) throws IOException { output.write(BUILD_REQUEST); writeStringList(output, args); - output.writeUTF(workingDir); - output.writeUTF(projectDir); + writeUTF(output, workingDir); + writeUTF(output, projectDir); writeStringMap(output, env); } } @@ -160,9 +258,9 @@ public abstract class Message { final String stackTrace; public static Message read(DataInputStream input) throws IOException { - String message = input.readUTF(); - String className = input.readUTF(); - String stackTrace = input.readUTF(); + String message = readUTF(input); + String className = readUTF(input); + String stackTrace = readUTF(input); return new BuildException(message, className, stackTrace); } @@ -206,9 +304,9 @@ public abstract class Message { @Override public void write(DataOutputStream output) throws IOException { output.write(BUILD_EXCEPTION); - output.writeUTF(message); - output.writeUTF(className); - output.writeUTF(stackTrace); + writeUTF(output, message); + writeUTF(output, className); + writeUTF(output, stackTrace); } } @@ -223,8 +321,8 @@ public abstract class Message { public static Message read(DataInputStream input) throws IOException { BuildEvent.Type type = BuildEvent.Type.values()[input.read()]; - String projectId = input.readUTF(); - String display = input.readUTF(); + String projectId = readUTF(input); + String display = readUTF(input); return new BuildEvent(type, projectId, display); } @@ -259,8 +357,8 @@ public abstract class Message { public void write(DataOutputStream output) throws IOException { output.write(BUILD_EVENT); output.write(type.ordinal()); - output.writeUTF(projectId); - output.writeUTF(display); + writeUTF(output, projectId); + writeUTF(output, display); } } @@ -269,8 +367,8 @@ public abstract class Message { final String message; public static Message read(DataInputStream input) throws IOException { - String projectId = input.readUTF(); - String message = input.readUTF(); + String projectId = readUTF(input); + String message = readUTF(input); return new BuildMessage(projectId.isEmpty() ? null : projectId, message); } @@ -298,8 +396,8 @@ public abstract class Message { @Override public void write(DataOutputStream output) throws IOException { output.write(BUILD_MESSAGE); - output.writeUTF(projectId != null ? projectId : ""); - output.writeUTF(message); + writeUTF(output, projectId != null ? projectId : ""); + writeUTF(output, message); } } diff --git a/common/src/test/java/org/jboss/fuse/mvnd/common/MessageTest.java b/common/src/test/java/org/jboss/fuse/mvnd/common/MessageTest.java new file mode 100644 index 00000000..0462b11e --- /dev/null +++ b/common/src/test/java/org/jboss/fuse/mvnd/common/MessageTest.java @@ -0,0 +1,51 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jboss.fuse.mvnd.common; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MessageTest { + + @Test + public void testBigMessage() throws IOException { + StringBuilder stringToWrite = new StringBuilder(); + for (int i = 0; i < 66000; ++i) { + stringToWrite.append("a"); + } + Message msg = new Message.BuildMessage("project", stringToWrite.toString()); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (DataOutputStream daos = new DataOutputStream(baos)) { + msg.write(daos); + } + + Message msg2; + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + try (DataInputStream dis = new DataInputStream(bais)) { + msg2 = Message.read(dis); + } + + assertTrue(msg2 instanceof Message.BuildMessage); + assertEquals(stringToWrite, ((Message.BuildMessage) msg2).getMessage()); + } +}