ProtobufWriter.java (15470B)
1 package protobuf; 2 3 import java.io.IOException; 4 import java.io.InputStream; 5 import java.nio.charset.StandardCharsets; 6 import java.util.Iterator; 7 import java.util.function.BinaryOperator; 8 import java.util.function.Consumer; 9 import java.util.function.Function; 10 import java.util.function.Supplier; 11 import java.util.function.UnaryOperator; 12 13 import protobuf.exception.InputException; 14 import protobuf.exception.OverflowException; 15 import protobuf.exception.UnexpectedTagException; 16 import protobuf.exception.WireTypeException; 17 18 /** 19 * Represents an interface for parsing Protobuf wire elements, providing methods 20 * for reading various Protobuf data types. 21 */ 22 public class ProtobufWriter { 23 private final MessageIterator message; 24 private final WireType type; 25 private final int tag; 26 private boolean resetType = false; 27 28 public ProtobufReader(MessageIterator message, WireType type, int tag) { 29 this.message = message; 30 this.type = type; 31 this.tag = tag; 32 } 33 34 /** 35 * Gets the underlying input stream. 36 * 37 * @return The input stream. 38 */ 39 public InputStream getInputStream() { 40 return message.input; 41 } 42 43 /** 44 * Gets the Protobuf wire type. 45 * 46 * @return The wire type. 47 */ 48 public WireType getType() { 49 return resetType ? null : type; 50 } 51 52 /** 53 * Resets the Protobuf wire type. 54 * Useful for parsing packed streams. 55 */ 56 public void resetType() { 57 resetType = true; 58 } 59 60 /** 61 * Gets the tag associated with the wire element. 62 * 63 * @return The wire tag. 64 */ 65 public int tag() { 66 return tag; 67 } 68 69 /** 70 * Reads a signed variable-length integer as a 64-bit integer. 71 * 72 * @return The parsed signed 64-bit integer. 73 */ 74 public long svarint64() { 75 long n = varint64(); 76 return (n & 0x01) == 0 77 ? (n >> 1) 78 : -(n >> 1) - 1; 79 } 80 81 /** 82 * Parses a variable-length integer as a signed 64-bit integer. 83 * 84 * @return The parsed signed 64-bit integer. 85 */ 86 public long varint64() { 87 return varint64(false); 88 } 89 90 /** 91 * Reads a signed variable-length integer as a 32-bit integer. 92 * 93 * @return The parsed signed 32-bit integer. 94 */ 95 public long varint64(boolean ignoreType) { 96 if (!ignoreType && getType() != null && getType() != WireType.VARINT) 97 throw new WireTypeException(WireType.VARINT, getType()); 98 99 long result = 0; 100 long b = 0; 101 int shift = 0; 102 while (shift < 64 && message.length > 0) { 103 try { 104 b = message.input.read(); 105 } catch (IOException exc) { 106 throw new InputException(exc); 107 } 108 if (b == -1) 109 break; 110 111 message.length--; 112 113 result |= (b & 0x7f) << shift; 114 shift += 7; 115 if ((b & 0x80) == 0) 116 return result; 117 } 118 119 throw new OverflowException("input exceed"); 120 } 121 122 public long svarint32() { 123 int n = varint32(); 124 return (n & 0x01) == 0 125 ? (n >> 1) 126 : -(n >> 1); 127 128 } 129 130 /** 131 * Parses a variable-length integer as a signed 32-bit integer. 132 * 133 * @return The parsed signed 32-bit integer. 134 */ 135 public int varint32() { 136 return varint32(false); 137 } 138 139 private int varint32(boolean ignoreType) { 140 if (!ignoreType && getType() != null && getType() != WireType.VARINT) 141 throw new WireTypeException(WireType.VARINT, getType()); 142 143 int result = 0; 144 int b = 0; 145 int shift = 0; 146 while (shift < 32 && message.length > 0) { 147 try { 148 b = message.input.read(); 149 } catch (IOException exc) { 150 throw new InputException(exc); 151 } 152 if (b == -1) 153 break; 154 155 message.length--; 156 157 result |= (b & 0x7f) << shift; 158 shift += 7; 159 if ((b & 0x80) == 0) 160 return result; 161 } 162 throw new OverflowException("input exceed"); 163 } 164 165 /** 166 * Skips the variable-length integer. 167 */ 168 public void skipVarint() { 169 if (getType() != null && getType() != WireType.VARINT) 170 throw new WireTypeException(WireType.VARINT, getType()); 171 172 int b = 0; 173 while (message.length > 0) { 174 try { 175 b = message.input.read(); 176 } catch (IOException exc) { 177 throw new InputException(exc); 178 } 179 if (b == -1) 180 break; 181 182 message.length--; 183 if ((b & 0x80) == 0) 184 return; 185 } 186 throw new OverflowException("input exceed"); 187 } 188 189 /** 190 * Reads a fixed 64-bit integer. 191 * 192 * @return The parsed 64-bit integer. 193 */ 194 public long fixed64() { 195 if (getType() != null && getType() != WireType.I64) 196 throw new WireTypeException(WireType.I64, getType()); 197 198 if (message.length < 8) 199 throw new OverflowException("input exceed"); 200 201 byte[] bytes; 202 try { 203 bytes = message.input.readNBytes(8); 204 } catch (IOException exc) { 205 throw new InputException(exc); 206 } 207 long result = 0; 208 209 for (int i = bytes.length - 1; i >= 0; i--) { 210 result <<= 8; 211 result |= bytes[i]; 212 } 213 214 return result; 215 } 216 217 /** 218 * Skips a fixed 64-bit integer. 219 */ 220 public void skip64() { 221 if (getType() != null && getType() != WireType.I64) 222 throw new WireTypeException(WireType.I64, getType()); 223 224 if (message.length < 8) 225 throw new OverflowException("input exceed"); 226 227 message.length -= 8; 228 try { 229 message.input.skipNBytes(8); 230 } catch (IOException exc) { 231 throw new InputException(exc); 232 } 233 } 234 235 /** 236 * Reads a fixed 32-bit integer. 237 * 238 * @return The parsed 32-bit integer. 239 */ 240 public int fixed32() { 241 if (getType() != null && getType() != WireType.I32) 242 throw new WireTypeException(WireType.I32, getType()); 243 244 if (message.length < 4) 245 throw new OverflowException("input exceed"); 246 247 byte[] bytes; 248 try { 249 bytes = message.input.readNBytes(4); 250 } catch (IOException exc) { 251 throw new InputException(exc); 252 } 253 int result = 0; 254 255 for (int i = bytes.length - 1; i >= 0; i--) { 256 result <<= 8; 257 result |= bytes[i]; 258 } 259 260 return result; 261 } 262 263 /** 264 * Skips a fixed 32-bit integer. 265 */ 266 public void skip32() { 267 if (getType() != null && getType() != WireType.I32) 268 throw new WireTypeException(WireType.I32, getType()); 269 270 if (message.length < 4) 271 throw new OverflowException("input exceed"); 272 273 message.length -= 4; 274 try { 275 message.input.skipNBytes(4); 276 } catch (IOException exc) { 277 throw new InputException(exc); 278 } 279 } 280 281 /** 282 * Reads a byte array. 283 * 284 * @return The read byte array. 285 */ 286 public byte[] bytes() { 287 if (getType() != null && getType() != WireType.LEN) 288 throw new WireTypeException(WireType.LEN, getType()); 289 290 int len = varint32(true); 291 if (message.length < len) 292 throw new OverflowException("input exceed"); 293 294 message.length -= len; 295 try { 296 return message.input.readNBytes(len); 297 } catch (IOException exc) { 298 throw new InputException(exc); 299 } 300 } 301 302 /** 303 * Skips a byte array. 304 */ 305 public void skipBytes() { 306 if (getType() != null && getType() != WireType.LEN) 307 throw new WireTypeException(WireType.LEN, getType()); 308 309 int len = varint32(true); 310 if (message.length < len) 311 throw new OverflowException("input exceed"); 312 313 message.length -= len; 314 try { 315 message.input.skipNBytes(len); 316 } catch (IOException exc) { 317 throw new InputException(exc); 318 } 319 } 320 321 /** 322 * Reads a string. 323 * 324 * @return The read string. 325 */ 326 public String string() { 327 return new String(bytes(), StandardCharsets.UTF_8); 328 } 329 330 /** 331 * Reads a message using the provided handler. 332 * 333 * @param handler The message handler. 334 * @param <T> The type of the parsed message. 335 * @return The parsed message. 336 */ 337 public <T> T message(Message<T> handler) { 338 if (getType() != null && getType() != WireType.LEN) 339 throw new WireTypeException(WireType.LEN, getType()); 340 341 int len = varint32(true); 342 if (message.length < len) 343 throw new OverflowException("input exceed"); 344 345 message.length -= len; 346 347 return handler.parse(message.input, len); 348 } 349 350 /** 351 * Reads a message using the provided handler and applies a mapping function to 352 * the byte array. 353 * 354 * @param handler The message handler. 355 * @param map The mapping function for the byte array. 356 * @param <T> The type of the parsed message. 357 * @return The parsed message. 358 */ 359 public <T> T message(Message<T> handler, UnaryOperator<byte[]> map) { 360 byte[] buffer = bytes(); 361 return handler.parse(map.apply(buffer)); 362 } 363 364 /** 365 * Creates an iterator for packed values. 366 * 367 * @param scalar The scalar supplier. 368 * @param <T> The type of the scalar values. 369 * @return The iterator for packed values. 370 */ 371 public <T> Iterator<T> packed(Supplier<T> scalar) { 372 return packed(scalar, v -> v); 373 } 374 375 /** 376 * Creates an iterator for packed values, applying a mapping function to each 377 * scalar value. 378 * 379 * @param scalar The scalar supplier. 380 * @param map The mapping function for scalar values. 381 * @param <T> The type of the scalar values. 382 * @param <M> The type of the mapped values. 383 * @return The iterator for packed values. 384 */ 385 public <T, M> Iterator<M> packed(Supplier<T> scalar, Function<T, M> map) { 386 if (getType() != null && getType() != WireType.LEN) 387 throw new WireTypeException(WireType.LEN, getType()); 388 389 int len = varint32(true); 390 if (message.length < len) 391 throw new OverflowException("input exceed"); 392 393 int end = message.length - len; 394 395 resetType(); 396 397 return new Iterator<>() { 398 public boolean hasNext() { 399 if (message.length < end) 400 throw new OverflowException("packed string overused"); 401 return message.length > end; 402 } 403 404 public M next() { 405 return map.apply(scalar.get()); 406 } 407 }; 408 } 409 410 /** 411 * Creates an iterator for packed values with an initial value and a binary 412 * operator. 413 * 414 * @param scalar The scalar supplier. 415 * @param init The initial value. 416 * @param operator The binary operator. 417 * @param <T> The type of the scalar values. 418 * @return The iterator for packed values. 419 */ 420 public <T> Iterator<T> packed(Supplier<T> scalar, T init, BinaryOperator<T> operator) { 421 return packed(scalar, v -> v, init, operator); 422 } 423 424 /** 425 * Creates an iterator for packed values with an initial value, a binary 426 * operator, and a mapping function. 427 * 428 * @param scalar The scalar supplier. 429 * @param map The mapping function for scalar values. 430 * @param init The initial value. 431 * @param operator The binary operator. 432 * @param <T> The type of the scalar values. 433 * @param <M> The type of the mapped values. 434 * @return The iterator for packed values. 435 */ 436 public <T, M> Iterator<M> packed(Supplier<T> scalar, Function<T, M> map, M init, BinaryOperator<M> operator) { 437 if (getType() != null && getType() != WireType.LEN) 438 throw new WireTypeException(WireType.LEN, getType()); 439 440 int len = varint32(true); 441 if (message.length < len) 442 throw new OverflowException("input exceed"); 443 444 int end = message.length - len; 445 446 resetType(); 447 448 return new Iterator<>() { 449 M value = init; 450 451 public boolean hasNext() { 452 if (message.length < end) 453 throw new OverflowException("packed string overused"); 454 return message.length > end; 455 } 456 457 public M next() { 458 return value = operator.apply(value, map.apply(scalar.get())); 459 } 460 }; 461 } 462 463 /** 464 * Defers the execution of a consumer with a supplied value. 465 * 466 * @param supplier The value supplier. 467 * @param defer The consumer to be deferred. 468 * @param <T> The type of the supplied value. 469 */ 470 public <T> void delayed(Supplier<T> supplier, Consumer<T> defer) { 471 T buffer = supplier.get(); 472 message.delayed.add(() -> defer.accept(buffer)); 473 } 474 475 /** 476 * Defers the execution of a consumer with a mapped supplied value. 477 * 478 * @param supplier The value supplier. 479 * @param map The mapping function for the supplied value. 480 * @param defer The consumer to be deferred. 481 * @param <T> The type of the supplied value. 482 */ 483 public <T> void delayed(Supplier<T> supplier, UnaryOperator<T> map, Consumer<T> defer) { 484 T buffer = supplier.get(); 485 message.delayed.add(() -> defer.accept(map.apply(buffer))); 486 } 487 488 /** 489 * Defers the execution of a consumer with a parsed message using the provided 490 * handler. 491 * 492 * @param handler The message handler. 493 * @param defer The consumer to be deferred. 494 * @param <T> The type of the parsed message. 495 */ 496 public <T> void delayed(Message<T> handler, Consumer<T> defer) { 497 byte[] buffer = bytes(); 498 message.delayed.add(() -> defer.accept(handler.parse(buffer))); 499 } 500 501 /** 502 * Defers the execution of a consumer with a mapped parsed message using the 503 * provided handler. 504 * 505 * @param handler The message handler. 506 * @param map The mapping function for the parsed message. 507 * @param defer The consumer to be deferred. 508 * @param <T> The type of the parsed message. 509 */ 510 public <T> void delayed(Message<T> handler, UnaryOperator<byte[]> map, Consumer<T> defer) { 511 byte[] buffer = bytes(); 512 message.delayed.add(() -> defer.accept(handler.parse(map.apply(buffer)))); 513 } 514 515 /** 516 * Skips the current wire element based on its type. 517 */ 518 public void skip() { 519 switch (getType()) { 520 case VARINT: 521 skipVarint(); 522 break; 523 case I64: 524 skip64(); 525 break; 526 case LEN: 527 skipBytes(); 528 break; 529 case I32: 530 skip32(); 531 break; 532 case SGROUP: 533 case EGROUP: 534 throw new UnsupportedOperationException("cannot skip sgroup of egroup"); 535 } 536 } 537 538 /** 539 * Throws an {@link UnexpectedTagException} for the current tag. 540 */ 541 public void throwUnexpected() { 542 throw new UnexpectedTagException(tag); 543 } 544 } 545 }