persolijn

an efficient router for busses
Log | Files | Refs

ProtobufReader.java (15469B)


      1 package protobuf;
      2 
      3 import java.io.IOException;
      4 import java.io.InputStream;
      5 import java.nio.charset.StandardCharsets;
      6 import java.util.Iterator;
      7 import java.util.function.BinaryOperator;
      8 import java.util.function.Consumer;
      9 import java.util.function.Function;
     10 import java.util.function.Supplier;
     11 import java.util.function.UnaryOperator;
     12 
     13 import protobuf.exception.InputException;
     14 import protobuf.exception.OverflowException;
     15 import protobuf.exception.UnexpectedTagException;
     16 import protobuf.exception.WireTypeException;
     17 
     18 /**
     19  * Represents an interface for parsing Protobuf wire elements, providing methods
     20  * for reading various Protobuf data types.
     21  */
     22 public class ProtobufReader {
     23     private final MessageIterator message;
     24     private final WireType type;
     25     private final int tag;
     26     private boolean resetType = false;
     27 
     28     public ProtobufReader(MessageIterator message, WireType type, int tag) {
     29         this.message = message;
     30         this.type = type;
     31         this.tag = tag;
     32     }
     33 
     34     /**
     35      * Gets the underlying input stream.
     36      *
     37      * @return The input stream.
     38      */
     39     public InputStream getInputStream() {
     40         return message.input;
     41     }
     42 
     43     /**
     44      * Gets the Protobuf wire type.
     45      *
     46      * @return The wire type.
     47      */
     48     public WireType getType() {
     49         return resetType ? null : type;
     50     }
     51 
     52     /**
     53      * Resets the Protobuf wire type.
     54      * Useful for parsing packed streams.
     55      */
     56     public void resetType() {
     57         resetType = true;
     58     }
     59 
     60     /**
     61      * Gets the tag associated with the wire element.
     62      *
     63      * @return The wire tag.
     64      */
     65     public int tag() {
     66         return tag;
     67     }
     68 
     69     /**
     70      * Reads a signed variable-length integer as a 64-bit integer.
     71      *
     72      * @return The parsed signed 64-bit integer.
     73      */
     74     public long svarint64() {
     75         long n = varint64();
     76         return (n & 0x01) == 0
     77                 ? (n >> 1)
     78                 : -(n >> 1) - 1;
     79     }
     80 
     81     /**
     82      * Parses a variable-length integer as a signed 64-bit integer.
     83      *
     84      * @return The parsed signed 64-bit integer.
     85      */
     86     public long varint64() {
     87         return varint64(false);
     88     }
     89 
     90     /**
     91      * Reads a signed variable-length integer as a 32-bit integer.
     92      *
     93      * @return The parsed signed 32-bit integer.
     94      */
     95     public long varint64(boolean ignoreType) {
     96         if (!ignoreType && getType() != null && getType() != WireType.VARINT)
     97             throw new WireTypeException(WireType.VARINT, getType());
     98 
     99         long result = 0;
    100         long b = 0;
    101         int shift = 0;
    102         while (shift < 64 && message.length > 0) {
    103             try {
    104                 b = message.input.read();
    105             } catch (IOException exc) {
    106                 throw new InputException(exc);
    107             }
    108             if (b == -1)
    109                 break;
    110 
    111             message.length--;
    112 
    113             result |= (b & 0x7f) << shift;
    114             shift += 7;
    115             if ((b & 0x80) == 0)
    116                 return result;
    117         }
    118 
    119         throw new OverflowException("input exceed");
    120     }
    121 
    122     public long svarint32() {
    123         int n = varint32();
    124         return (n & 0x01) == 0
    125                 ? (n >> 1)
    126                 : -(n >> 1);
    127 
    128     }
    129 
    130     /**
    131      * Parses a variable-length integer as a signed 32-bit integer.
    132      *
    133      * @return The parsed signed 32-bit integer.
    134      */
    135     public int varint32() {
    136         return varint32(false);
    137     }
    138 
    139     private int varint32(boolean ignoreType) {
    140         if (!ignoreType && getType() != null && getType() != WireType.VARINT)
    141             throw new WireTypeException(WireType.VARINT, getType());
    142 
    143         int result = 0;
    144         int b = 0;
    145         int shift = 0;
    146         while (shift < 32 && message.length > 0) {
    147             try {
    148                 b = message.input.read();
    149             } catch (IOException exc) {
    150                 throw new InputException(exc);
    151             }
    152             if (b == -1)
    153                 break;
    154 
    155             message.length--;
    156 
    157             result |= (b & 0x7f) << shift;
    158             shift += 7;
    159             if ((b & 0x80) == 0)
    160                 return result;
    161         }
    162         throw new OverflowException("input exceed");
    163     }
    164 
    165     /**
    166      * Skips the variable-length integer.
    167      */
    168     public void skipVarint() {
    169         if (getType() != null && getType() != WireType.VARINT)
    170             throw new WireTypeException(WireType.VARINT, getType());
    171 
    172         int b = 0;
    173         while (message.length > 0) {
    174             try {
    175                 b = message.input.read();
    176             } catch (IOException exc) {
    177                 throw new InputException(exc);
    178             }
    179             if (b == -1)
    180                 break;
    181 
    182             message.length--;
    183             if ((b & 0x80) == 0)
    184                 return;
    185         }
    186         throw new OverflowException("input exceed");
    187     }
    188 
    189     /**
    190      * Reads a fixed 64-bit integer.
    191      *
    192      * @return The parsed 64-bit integer.
    193      */
    194     public long fixed64() {
    195         if (getType() != null && getType() != WireType.I64)
    196             throw new WireTypeException(WireType.I64, getType());
    197 
    198         if (message.length < 8)
    199             throw new OverflowException("input exceed");
    200 
    201         byte[] bytes;
    202         try {
    203             bytes = message.input.readNBytes(8);
    204         } catch (IOException exc) {
    205             throw new InputException(exc);
    206         }
    207         long result = 0;
    208 
    209         for (int i = bytes.length - 1; i >= 0; i--) {
    210             result <<= 8;
    211             result |= bytes[i];
    212         }
    213 
    214         return result;
    215     }
    216 
    217     /**
    218      * Skips a fixed 64-bit integer.
    219      */
    220     public void skip64() {
    221         if (getType() != null && getType() != WireType.I64)
    222             throw new WireTypeException(WireType.I64, getType());
    223 
    224         if (message.length < 8)
    225             throw new OverflowException("input exceed");
    226 
    227         message.length -= 8;
    228         try {
    229             message.input.skipNBytes(8);
    230         } catch (IOException exc) {
    231             throw new InputException(exc);
    232         }
    233     }
    234 
    235     /**
    236      * Reads a fixed 32-bit integer.
    237      *
    238      * @return The parsed 32-bit integer.
    239      */
    240     public int fixed32() {
    241         if (getType() != null && getType() != WireType.I32)
    242             throw new WireTypeException(WireType.I32, getType());
    243 
    244         if (message.length < 4)
    245             throw new OverflowException("input exceed");
    246 
    247         byte[] bytes;
    248         try {
    249             bytes = message.input.readNBytes(4);
    250         } catch (IOException exc) {
    251             throw new InputException(exc);
    252         }
    253         int result = 0;
    254 
    255         for (int i = bytes.length - 1; i >= 0; i--) {
    256             result <<= 8;
    257             result |= bytes[i];
    258         }
    259 
    260         return result;
    261     }
    262 
    263     /**
    264      * Skips a fixed 32-bit integer.
    265      */
    266     public void skip32() {
    267         if (getType() != null && getType() != WireType.I32)
    268             throw new WireTypeException(WireType.I32, getType());
    269 
    270         if (message.length < 4)
    271             throw new OverflowException("input exceed");
    272 
    273         message.length -= 4;
    274         try {
    275             message.input.skipNBytes(4);
    276         } catch (IOException exc) {
    277             throw new InputException(exc);
    278         }
    279     }
    280 
    281     /**
    282      * Reads a byte array.
    283      *
    284      * @return The read byte array.
    285      */
    286     public byte[] bytes() {
    287         if (getType() != null && getType() != WireType.LEN)
    288             throw new WireTypeException(WireType.LEN, getType());
    289 
    290         int len = varint32(true);
    291         if (message.length < len)
    292             throw new OverflowException("input exceed");
    293 
    294         message.length -= len;
    295         try {
    296             return message.input.readNBytes(len);
    297         } catch (IOException exc) {
    298             throw new InputException(exc);
    299         }
    300     }
    301 
    302     /**
    303      * Skips a byte array.
    304      */
    305     public void skipBytes() {
    306         if (getType() != null && getType() != WireType.LEN)
    307             throw new WireTypeException(WireType.LEN, getType());
    308 
    309         int len = varint32(true);
    310         if (message.length < len)
    311             throw new OverflowException("input exceed");
    312 
    313         message.length -= len;
    314         try {
    315             message.input.skipNBytes(len);
    316         } catch (IOException exc) {
    317             throw new InputException(exc);
    318         }
    319     }
    320 
    321     /**
    322      * Reads a string.
    323      *
    324      * @return The read string.
    325      */
    326     public String string() {
    327         return new String(bytes(), StandardCharsets.UTF_8);
    328     }
    329 
    330     /**
    331      * Reads a message using the provided handler.
    332      *
    333      * @param handler The message handler.
    334      * @param <T>     The type of the parsed message.
    335      * @return The parsed message.
    336      */
    337     public <T> T message(Message<T> handler) {
    338         if (getType() != null && getType() != WireType.LEN)
    339             throw new WireTypeException(WireType.LEN, getType());
    340 
    341         int len = varint32(true);
    342         if (message.length < len)
    343             throw new OverflowException("input exceed");
    344 
    345         message.length -= len;
    346 
    347         return handler.parse(message.input, len);
    348     }
    349 
    350     /**
    351      * Reads a message using the provided handler and applies a mapping function to
    352      * the byte array.
    353      *
    354      * @param handler The message handler.
    355      * @param map     The mapping function for the byte array.
    356      * @param <T>     The type of the parsed message.
    357      * @return The parsed message.
    358      */
    359     public <T> T message(Message<T> handler, UnaryOperator<byte[]> map) {
    360         byte[] buffer = bytes();
    361         return handler.parse(map.apply(buffer));
    362     }
    363 
    364     /**
    365      * Creates an iterator for packed values.
    366      *
    367      * @param scalar The scalar supplier.
    368      * @param <T>    The type of the scalar values.
    369      * @return The iterator for packed values.
    370      */
    371     public <T> Iterator<T> packed(Supplier<T> scalar) {
    372         return packed(scalar, v -> v);
    373     }
    374 
    375     /**
    376      * Creates an iterator for packed values, applying a mapping function to each
    377      * scalar value.
    378      *
    379      * @param scalar The scalar supplier.
    380      * @param map    The mapping function for scalar values.
    381      * @param <T>    The type of the scalar values.
    382      * @param <M>    The type of the mapped values.
    383      * @return The iterator for packed values.
    384      */
    385     public <T, M> Iterator<M> packed(Supplier<T> scalar, Function<T, M> map) {
    386         if (getType() != null && getType() != WireType.LEN)
    387             throw new WireTypeException(WireType.LEN, getType());
    388 
    389         int len = varint32(true);
    390         if (message.length < len)
    391             throw new OverflowException("input exceed");
    392 
    393         int end = message.length - len;
    394 
    395         resetType();
    396 
    397         return new Iterator<>() {
    398             public boolean hasNext() {
    399                 if (message.length < end)
    400                     throw new OverflowException("packed string overused");
    401                 return message.length > end;
    402             }
    403 
    404             public M next() {
    405                 return map.apply(scalar.get());
    406             }
    407         };
    408     }
    409 
    410     /**
    411      * Creates an iterator for packed values with an initial value and a binary
    412      * operator.
    413      *
    414      * @param scalar   The scalar supplier.
    415      * @param init     The initial value.
    416      * @param operator The binary operator.
    417      * @param <T>      The type of the scalar values.
    418      * @return The iterator for packed values.
    419      */
    420     public <T> Iterator<T> packed(Supplier<T> scalar, T init, BinaryOperator<T> operator) {
    421         return packed(scalar, v -> v, init, operator);
    422     }
    423 
    424     /**
    425      * Creates an iterator for packed values with an initial value, a binary
    426      * operator, and a mapping function.
    427      *
    428      * @param scalar   The scalar supplier.
    429      * @param map      The mapping function for scalar values.
    430      * @param init     The initial value.
    431      * @param operator The binary operator.
    432      * @param <T>      The type of the scalar values.
    433      * @param <M>      The type of the mapped values.
    434      * @return The iterator for packed values.
    435      */
    436     public <T, M> Iterator<M> packed(Supplier<T> scalar, Function<T, M> map, M init, BinaryOperator<M> operator) {
    437         if (getType() != null && getType() != WireType.LEN)
    438             throw new WireTypeException(WireType.LEN, getType());
    439 
    440         int len = varint32(true);
    441         if (message.length < len)
    442             throw new OverflowException("input exceed");
    443 
    444         int end = message.length - len;
    445 
    446         resetType();
    447 
    448         return new Iterator<>() {
    449             M value = init;
    450 
    451             public boolean hasNext() {
    452                 if (message.length < end)
    453                     throw new OverflowException("packed string overused");
    454                 return message.length > end;
    455             }
    456 
    457             public M next() {
    458                 return value = operator.apply(value, map.apply(scalar.get()));
    459             }
    460         };
    461     }
    462 
    463     /**
    464      * Defers the execution of a consumer with a supplied value.
    465      *
    466      * @param supplier The value supplier.
    467      * @param defer    The consumer to be deferred.
    468      * @param <T>      The type of the supplied value.
    469      */
    470     public <T> void delayed(Supplier<T> supplier, Consumer<T> defer) {
    471         T buffer = supplier.get();
    472         message.delayed.add(() -> defer.accept(buffer));
    473     }
    474 
    475     /**
    476      * Defers the execution of a consumer with a mapped supplied value.
    477      *
    478      * @param supplier The value supplier.
    479      * @param map      The mapping function for the supplied value.
    480      * @param defer    The consumer to be deferred.
    481      * @param <T>      The type of the supplied value.
    482      */
    483     public <T> void delayed(Supplier<T> supplier, UnaryOperator<T> map, Consumer<T> defer) {
    484         T buffer = supplier.get();
    485         message.delayed.add(() -> defer.accept(map.apply(buffer)));
    486     }
    487 
    488     /**
    489      * Defers the execution of a consumer with a parsed message using the provided
    490      * handler.
    491      *
    492      * @param handler The message handler.
    493      * @param defer   The consumer to be deferred.
    494      * @param <T>     The type of the parsed message.
    495      */
    496     public <T> void delayed(Message<T> handler, Consumer<T> defer) {
    497         byte[] buffer = bytes();
    498         message.delayed.add(() -> defer.accept(handler.parse(buffer)));
    499     }
    500 
    501     /**
    502      * Defers the execution of a consumer with a mapped parsed message using the
    503      * provided handler.
    504      *
    505      * @param handler The message handler.
    506      * @param map     The mapping function for the parsed message.
    507      * @param defer   The consumer to be deferred.
    508      * @param <T>     The type of the parsed message.
    509      */
    510     public <T> void delayed(Message<T> handler, UnaryOperator<byte[]> map, Consumer<T> defer) {
    511         byte[] buffer = bytes();
    512         message.delayed.add(() -> defer.accept(handler.parse(map.apply(buffer))));
    513     }
    514 
    515     /**
    516      * Skips the current wire element based on its type.
    517      */
    518     public void skip() {
    519         switch (getType()) {
    520             case VARINT:
    521                 skipVarint();
    522                 break;
    523             case I64:
    524                 skip64();
    525                 break;
    526             case LEN:
    527                 skipBytes();
    528                 break;
    529             case I32:
    530                 skip32();
    531                 break;
    532             case SGROUP:
    533             case EGROUP:
    534                 throw new UnsupportedOperationException("cannot skip sgroup of egroup");
    535         }
    536     }
    537 
    538     /**
    539      * Throws an {@link UnexpectedTagException} for the current tag.
    540      */
    541     public void throwUnexpected() {
    542         throw new UnexpectedTagException(tag);
    543     }
    544 }