diff --git a/cmake/installed_include_golden.txt b/cmake/installed_include_golden.txt index 28c819e4b4976..ace9fc6198248 100644 --- a/cmake/installed_include_golden.txt +++ b/cmake/installed_include_golden.txt @@ -155,6 +155,7 @@ google/protobuf/wire_format_lite.h google/protobuf/wrappers.pb.h google/protobuf/wrappers.proto upb/base/descriptor_constants.h +upb/base/error_handler.h upb/base/status.h upb/base/status.hpp upb/base/string_view.h diff --git a/rust/upb/sys/wire/wire.rs b/rust/upb/sys/wire/wire.rs index a90bc3ece77e7..5f9d9b5374429 100644 --- a/rust/upb/sys/wire/wire.rs +++ b/rust/upb/sys/wire/wire.rs @@ -31,8 +31,8 @@ pub enum EncodeStatus { #[derive(PartialEq, Eq, Copy, Clone, Debug)] pub enum DecodeStatus { Ok = 0, - Malformed = 1, - OutOfMemory = 2, + OutOfMemory = 1, + Malformed = 2, BadUtf8 = 3, MaxDepthExceeded = 4, MissingRequired = 5, diff --git a/upb/base/BUILD b/upb/base/BUILD index 62ef2547bc4c3..f6688451f9e83 100644 --- a/upb/base/BUILD +++ b/upb/base/BUILD @@ -18,6 +18,7 @@ cc_library( ], hdrs = [ "descriptor_constants.h", + "error_handler.h", "status.h", "status.hpp", "string_view.h", diff --git a/upb/base/error_handler.h b/upb/base/error_handler.h new file mode 100644 index 0000000000000..24c95d0c2a12e --- /dev/null +++ b/upb/base/error_handler.h @@ -0,0 +1,79 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2025 Google LLC. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +#ifndef GOOGLE_UPB_UPB_BASE_ERROR_HANDLER_H__ +#define GOOGLE_UPB_UPB_BASE_ERROR_HANDLER_H__ + +#include + +// Must be last. +#include "upb/port/def.inc" + +// upb_ErrorHandler is a standard longjmp()-based exception handler for UPB. +// It is used for efficient error handling in cases where longjmp() is safe to +// use, such as in highly performance-sensitive C parsing code. +// +// This structure contains both a jmp_buf and an error code; the error code is +// stored in the structure prior to calling longjmp(). This is necessary because +// per the C standard, it is not possible to store the result of setjmp(), so +// the error code must be passed out-of-band. +// +// upb_ErrorHandler is generally not C++-compatible, because longjmp() does not +// run C++ destructors. So any library that supports upb_ErrorHandler should +// also support a regular return-based error handling mechanism. (Note: we +// could conceivably extend this to take a callback, which could either call +// longjmp() or throw a C++ exception. But since C++ exceptions are forbidden +// by the C++ style guide, there's not likely to be a demand for this.) +// +// To support both cases (longjmp() or return-based status) efficiently, code +// can be written like this: +// +// UPB_ATTR_CONST bool upb_Arena_HasErrHandler(const upb_Arena* a); +// +// INLINE void* upb_Arena_Malloc(upb_Arena* a, size_t size) { +// if (UPB_UNLIKELY(a->end - a->ptr < size)) { +// void* ret = upb_Arena_MallocFallback(a, size); +// UPB_MAYBE_ASSUME(upb_Arena_HasErrHandler(a), ret != NULL); +// return ret; +// } +// void* ret = a->ptr; +// a->ptr += size; +// UPB_ASSUME(ret != NULL); +// return ret; +// } +// +// If the optimizer can prove that an error handler is present, it can assume +// that upb_Arena_Malloc() will not return NULL. + +// We need to standardize on any error code that might be thrown by an error +// handler. + +typedef enum { + kUpb_ErrorCode_Ok = 0, + kUpb_ErrorCode_OutOfMemory = 1, + kUpb_ErrorCode_Malformed = 2, +} upb_ErrorCode; + +typedef struct { + int code; + jmp_buf buf; +} upb_ErrorHandler; + +UPB_INLINE void upb_ErrorHandler_Init(upb_ErrorHandler* e) { + e->code = kUpb_ErrorCode_Ok; +} + +UPB_INLINE UPB_NORETURN void upb_ErrorHandler_ThrowError(upb_ErrorHandler* e, + int code) { + UPB_ASSERT(code != kUpb_ErrorCode_Ok); + e->code = code; + UPB_LONGJMP(e->buf, 1); +} + +#include "upb/port/undef.inc" + +#endif // GOOGLE_UPB_UPB_BASE_ERROR_HANDLER_H__ diff --git a/upb/port/def.inc b/upb/port/def.inc index e1adac6f4a15e..c7327fcbfc17e 100644 --- a/upb/port/def.inc +++ b/upb/port/def.inc @@ -317,6 +317,15 @@ Error, UINTPTR_MAX is undefined #define UPB_ASSUME(expr) assert(expr) #endif +#if UPB_HAS_BUILTIN(__builtin_constant_p) && UPB_HAS_ATTRIBUTE(const) +#define UPB_MAYBE_ASSUME(pred, x) \ + if (__builtin_constant_p(pred) && pred) UPB_ASSUME(x) +#define UPB_ATTR_CONST __attribute__((const)) +#else +#define UPB_MAYBE_ASSUME(pred, x) +#define UPB_ATTR_CONST +#endif + /* UPB_ASSERT(): in release mode, we use the expression without letting it be * evaluated. This prevents "unused variable" warnings. */ #ifdef NDEBUG diff --git a/upb/port/undef.inc b/upb/port/undef.inc index ebbec28410c82..3481b46acee14 100644 --- a/upb/port/undef.inc +++ b/upb/port/undef.inc @@ -88,3 +88,5 @@ #undef UPB_ARM64_ASM #undef UPB_ARM64_BTI_DEFAULT #undef UPB_DEPRECATE_AND_INLINE +#undef UPB_MAYBE_ASSUME +#undef UPB_ATTR_CONST diff --git a/upb/wire/decode.c b/upb/wire/decode.c index 449514a897595..d5997043104ea 100644 --- a/upb/wire/decode.c +++ b/upb/wire/decode.c @@ -14,6 +14,7 @@ #include #include "upb/base/descriptor_constants.h" +#include "upb/base/error_handler.h" #include "upb/base/internal/endian.h" #include "upb/base/string_view.h" #include "upb/hash/common.h" @@ -95,9 +96,15 @@ typedef union { uint32_t size; } wireval; +static void _upb_Decoder_AssumeEpsHasErrorHandler(upb_Decoder* d) { + UPB_ASSUME(upb_EpsCopyInputStream_HasErrorHandler(&d->input)); +} + +#define EPS(d) (_upb_Decoder_AssumeEpsHasErrorHandler(d), &(d)->input) + static void _upb_Decoder_VerifyUtf8(upb_Decoder* d, const char* buf, int len) { if (!_upb_Decoder_VerifyUtf8Inline(buf, len)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_BadUtf8); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_BadUtf8); } } @@ -106,7 +113,7 @@ static bool _upb_Decoder_Reserve(upb_Decoder* d, upb_Array* arr, size_t elem) { arr->UPB_PRIVATE(capacity) - arr->UPB_PRIVATE(size) < elem; if (need_realloc && !UPB_PRIVATE(_upb_Array_Realloc)( arr, arr->UPB_PRIVATE(size) + elem, &d->arena)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); } return need_realloc; } @@ -116,79 +123,15 @@ typedef struct { uint64_t val; } _upb_DecodeLongVarintReturn; -UPB_NOINLINE -static _upb_DecodeLongVarintReturn _upb_Decoder_DecodeLongVarint( - const char* ptr, uint64_t val, upb_Decoder* d) { - uint64_t byte; - for (int i = 1; i < 10; i++) { - byte = (uint8_t)ptr[i]; - val += (byte - 1) << (i * 7); - if (!(byte & 0x80)) { - return (_upb_DecodeLongVarintReturn){.ptr = ptr + i + 1, .val = val}; - } - } - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); -} - -UPB_NOINLINE -static _upb_DecodeLongVarintReturn _upb_Decoder_DecodeLongTag(const char* ptr, - uint64_t val, - upb_Decoder* d) { - uint64_t byte; - for (int i = 1; i < 5; i++) { - byte = (uint8_t)ptr[i]; - val += (byte - 1) << (i * 7); - if (!(byte & 0x80)) { - if (val > UINT32_MAX) { - break; - } - return (_upb_DecodeLongVarintReturn){.ptr = ptr + i + 1, .val = val}; - } - } - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); -} - -UPB_FORCEINLINE -const char* _upb_Decoder_DecodeVarint(upb_Decoder* d, const char* ptr, - uint64_t* val) { - UPB_PRIVATE(upb_EpsCopyInputStream_ConsumeBytes)(&d->input, 10); - uint64_t byte = (uint8_t)*ptr; - if (UPB_LIKELY((byte & 0x80) == 0)) { - *val = byte; - return ptr + 1; - } else { - _upb_DecodeLongVarintReturn res = - _upb_Decoder_DecodeLongVarint(ptr, byte, d); - *val = res.val; - return res.ptr; - } -} - -UPB_FORCEINLINE -const char* _upb_Decoder_DecodeTag(upb_Decoder* d, const char* ptr, - uint32_t* val) { - UPB_PRIVATE(upb_EpsCopyInputStream_ConsumeBytes)(&d->input, 5); - uint64_t byte = (uint8_t)*ptr; - if (UPB_LIKELY((byte & 0x80) == 0)) { - *val = byte; - return ptr + 1; - } else { - _upb_DecodeLongVarintReturn res = _upb_Decoder_DecodeLongTag(ptr, byte, d); - *val = res.val; - return res.ptr; - } -} - UPB_FORCEINLINE const char* upb_Decoder_DecodeSize(upb_Decoder* d, const char* ptr, uint32_t* size) { - uint64_t size64; - ptr = _upb_Decoder_DecodeVarint(d, ptr, &size64); - if (size64 >= INT32_MAX || - !upb_EpsCopyInputStream_CheckSize(&d->input, ptr, (int)size64)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); + int sz; + ptr = upb_WireReader_ReadSize(ptr, &sz, EPS(d)); + if (!upb_EpsCopyInputStream_CheckSize(&d->input, ptr, sz)) { + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_Malformed); } - *size = size64; + *size = sz; return ptr; } @@ -229,7 +172,7 @@ static upb_Message* _upb_Decoder_NewSubMessage2(upb_Decoder* d, upb_Message** target) { UPB_ASSERT(subl); upb_Message* msg = _upb_Message_New(subl, &d->arena); - if (!msg) _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + if (!msg) upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); *target = msg; return msg; @@ -245,7 +188,7 @@ static upb_Message* _upb_Decoder_NewSubMessage(upb_Decoder* d, static const char* _upb_Decoder_ReadString2(upb_Decoder* d, const char* ptr, int size, upb_StringView* str) { if (!_upb_Decoder_ReadString(d, &ptr, size, str)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); } return ptr; } @@ -256,12 +199,12 @@ const char* _upb_Decoder_RecurseSubMessage(upb_Decoder* d, const char* ptr, const upb_MiniTable* subl, uint32_t expected_end_group) { if (--d->depth < 0) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_MaxDepthExceeded); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_MaxDepthExceeded); } ptr = _upb_Decoder_DecodeMessage(d, ptr, submsg, subl); d->depth++; if (d->end_group != expected_end_group) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_Malformed); } return ptr; } @@ -272,9 +215,6 @@ const char* _upb_Decoder_DecodeSubMessage(upb_Decoder* d, const char* ptr, const upb_MiniTableField* field, size_t size) { ptrdiff_t delta = upb_EpsCopyInputStream_PushLimit(&d->input, ptr, size); - if (UPB_UNLIKELY(delta < 0)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); - } const upb_MiniTable* subl = upb_MiniTable_GetSubMessageTable(field); UPB_ASSERT(subl); ptr = _upb_Decoder_RecurseSubMessage(d, ptr, submsg, subl, DECODE_NOGROUP); @@ -287,8 +227,8 @@ const char* _upb_Decoder_DecodeGroup(upb_Decoder* d, const char* ptr, upb_Message* submsg, const upb_MiniTable* subl, uint32_t number) { - if (_upb_Decoder_IsDone(d, &ptr)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); + if (upb_EpsCopyInputStream_IsDone(EPS(d), &ptr)) { + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_Malformed); } ptr = _upb_Decoder_RecurseSubMessage(d, ptr, submsg, subl, number); d->end_group = DECODE_NOGROUP; @@ -335,7 +275,7 @@ void _upb_Decoder_AddEnumValueToUnknown(upb_Decoder* d, upb_Message* msg, if (!UPB_PRIVATE(_upb_Message_AddUnknown)(unknown_msg, buf, end - buf, &d->arena, kUpb_AddUnknown_Copy)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); } } @@ -347,9 +287,10 @@ const char* _upb_Decoder_DecodeFixedPacked(upb_Decoder* d, const char* ptr, upb_StringView sv; ptr = upb_EpsCopyInputStream_ReadStringEphemeral(&d->input, ptr, val->size, &sv); + if (!ptr) upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_Malformed); int mask = (1 << lg2) - 1; if (UPB_UNLIKELY((val->size & mask) != 0 || ptr == NULL)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_Malformed); } size_t count = val->size >> lg2; if (count == 0) return ptr; @@ -391,14 +332,11 @@ const char* _upb_Decoder_DecodeVarintPacked(upb_Decoder* d, const char* ptr, int lg2) { int scale = 1 << lg2; ptrdiff_t delta = upb_EpsCopyInputStream_PushLimit(&d->input, ptr, val->size); - if (UPB_UNLIKELY(delta < 0)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); - } char* out = UPB_PTR_AT(upb_Array_MutableDataPtr(arr), arr->UPB_PRIVATE(size) << lg2, void); - while (!_upb_Decoder_IsDone(d, &ptr)) { + while (!upb_EpsCopyInputStream_IsDone(EPS(d), &ptr)) { wireval elem; - ptr = _upb_Decoder_DecodeVarint(d, ptr, &elem.uint64_val); + ptr = upb_WireReader_ReadVarint(ptr, &elem.uint64_val, EPS(d)); _upb_Decoder_Munge(field, &elem); if (_upb_Decoder_Reserve(d, arr, 1)) { out = UPB_PTR_AT(upb_Array_MutableDataPtr(arr), @@ -418,14 +356,11 @@ static const char* _upb_Decoder_DecodeEnumPacked( const upb_MiniTableField* field, wireval* val) { const upb_MiniTableEnum* e = upb_MiniTable_GetSubEnumTable(field); ptrdiff_t delta = upb_EpsCopyInputStream_PushLimit(&d->input, ptr, val->size); - if (UPB_UNLIKELY(delta < 0)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); - } char* out = UPB_PTR_AT(upb_Array_MutableDataPtr(arr), arr->UPB_PRIVATE(size) * 4, void); - while (!_upb_Decoder_IsDone(d, &ptr)) { + while (!upb_EpsCopyInputStream_IsDone(EPS(d), &ptr)) { wireval elem; - ptr = _upb_Decoder_DecodeVarint(d, ptr, &elem.uint64_val); + ptr = upb_WireReader_ReadVarint(ptr, &elem.uint64_val, EPS(d)); if (!upb_MiniTableEnum_CheckValue(e, elem.uint64_val)) { _upb_Decoder_AddEnumValueToUnknown(d, msg, field, &elem); continue; @@ -447,7 +382,7 @@ static upb_Array* _upb_Decoder_CreateArray(upb_Decoder* d, const upb_FieldType field_type = field->UPB_PRIVATE(descriptortype); const size_t lg2 = UPB_PRIVATE(_upb_FieldType_SizeLg2)(field_type); upb_Array* ret = UPB_PRIVATE(_upb_Array_New)(&d->arena, 4, lg2); - if (!ret) _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + if (!ret) upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); return ret; } @@ -548,7 +483,7 @@ static upb_Map* _upb_Decoder_CreateMap(upb_Decoder* d, UPB_ASSERT(key_field->UPB_PRIVATE(offset) == offsetof(upb_MapEntry, k)); UPB_ASSERT(val_field->UPB_PRIVATE(offset) == offsetof(upb_MapEntry, v)); upb_Map* ret = _upb_Map_New(&d->arena, key_size, val_size); - if (!ret) _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + if (!ret) upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); return ret; } @@ -560,7 +495,7 @@ UPB_NOINLINE static void _upb_Decoder_AddMapEntryUnknown( upb_EncodeStatus status = upb_Encode(ent_msg, entry, 0, &d->arena, &buf, &size); if (status != kUpb_EncodeStatus_Ok) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); } char delim_buf[2 * kUpb_Decoder_EncodeVarint32MaxSize]; char* delim_end = delim_buf; @@ -574,7 +509,7 @@ UPB_NOINLINE static void _upb_Decoder_AddMapEntryUnknown( }; if (!UPB_PRIVATE(_upb_Message_AddUnknownV)(msg, &d->arena, unknown, 2)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); } } @@ -617,7 +552,7 @@ static const char* _upb_Decoder_DecodeToMap(upb_Decoder* d, const char* ptr, } else { if (_upb_Map_Insert(map, &ent.k, map->key_size, &ent.v, map->val_size, &d->arena) == kUpb_MapInsertStatus_OutOfMemory) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); } } return ptr; @@ -676,13 +611,6 @@ static const char* _upb_Decoder_DecodeToSubMessage( return ptr; } -static const char* upb_Decoder_SkipField(upb_Decoder* d, const char* ptr, - uint32_t tag) { - ptr = _upb_WireReader_SkipValue(ptr, tag, d->depth, &d->input); - if (!ptr) _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); - return ptr; -} - enum { kStartItemTag = ((kUpb_MsgSet_Item << 3) | kUpb_WireType_StartGroup), kEndItemTag = ((kUpb_MsgSet_Item << 3) | kUpb_WireType_EndGroup), @@ -696,7 +624,7 @@ static void upb_Decoder_AddKnownMessageSetItem( upb_Extension* ext = UPB_PRIVATE(_upb_Message_GetOrCreateExtension)(msg, item_mt, &d->arena); if (UPB_UNLIKELY(!ext)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); } upb_Message** submsgp = (upb_Message**)&ext->data.msg_val; upb_Message* submsg = _upb_Decoder_NewSubMessage2( @@ -705,7 +633,9 @@ static void upb_Decoder_AddKnownMessageSetItem( upb_DecodeStatus status = upb_Decode( data, size, submsg, upb_MiniTableExtension_GetSubMessage(item_mt), d->extreg, d->options, &d->arena); - if (status != kUpb_DecodeStatus_Ok) _upb_Decoder_ErrorJmp(d, status); + if (status != kUpb_DecodeStatus_Ok) { + upb_ErrorHandler_ThrowError(&d->err, status); + } } static void upb_Decoder_AddUnknownMessageSetItem(upb_Decoder* d, @@ -730,7 +660,7 @@ static void upb_Decoder_AddUnknownMessageSetItem(upb_Decoder* d, {split, end - split}, }; if (!UPB_PRIVATE(_upb_Message_AddUnknownV)(msg, &d->arena, unknown, 3)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); } } @@ -757,15 +687,15 @@ static const char* upb_Decoder_DecodeMessageSetItem( kUpb_HavePayload = 1 << 1, } StateMask; StateMask state_mask = 0; - while (!_upb_Decoder_IsDone(d, &ptr)) { + while (!upb_EpsCopyInputStream_IsDone(EPS(d), &ptr)) { uint32_t tag; - ptr = _upb_Decoder_DecodeTag(d, ptr, &tag); + ptr = upb_WireReader_ReadTag(ptr, &tag, EPS(d)); switch (tag) { case kEndItemTag: return ptr; case kTypeIdTag: { uint64_t tmp; - ptr = _upb_Decoder_DecodeVarint(d, ptr, &tmp); + ptr = upb_WireReader_ReadVarint(ptr, &tmp, EPS(d)); if (state_mask & kUpb_HaveId) break; // Ignore dup. state_mask |= kUpb_HaveId; type_id = tmp; @@ -781,7 +711,9 @@ static const char* upb_Decoder_DecodeMessageSetItem( ptr = upb_Decoder_DecodeSize(d, ptr, &size); ptr = upb_EpsCopyInputStream_ReadStringAlwaysAlias(&d->input, ptr, size, &sv); - if (!ptr) _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); + if (!ptr) { + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_Malformed); + } if (state_mask & kUpb_HavePayload) break; // Ignore dup. state_mask |= kUpb_HavePayload; if (state_mask & kUpb_HaveId) { @@ -795,11 +727,11 @@ static const char* upb_Decoder_DecodeMessageSetItem( } default: // We do not preserve unexpected fields inside a message set item. - ptr = upb_Decoder_SkipField(d, ptr, tag); + ptr = _upb_WireReader_SkipValue(ptr, tag, d->depth, &d->input); break; } } - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_Malformed); } static upb_MiniTableField upb_Decoder_FieldNotFoundField = { @@ -981,7 +913,7 @@ const char* _upb_Decoder_DecodeWireValue(upb_Decoder* d, const char* ptr, switch (wire_type) { case kUpb_WireType_Varint: - ptr = _upb_Decoder_DecodeVarint(d, ptr, &val->uint64_val); + ptr = upb_WireReader_ReadVarint(ptr, &val->uint64_val, EPS(d)); if (upb_MiniTableField_IsClosedEnum(field)) { const upb_MiniTableEnum* e = upb_MiniTable_GetSubEnumTable(field); if (!upb_MiniTableEnum_CheckValue(e, val->uint64_val)) { @@ -1024,7 +956,7 @@ const char* _upb_Decoder_DecodeWireValue(upb_Decoder* d, const char* ptr, default: break; } - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_Malformed); } UPB_FORCEINLINE @@ -1041,7 +973,7 @@ const char* _upb_Decoder_DecodeKnownField(upb_Decoder* d, const char* ptr, upb_Extension* ext = UPB_PRIVATE(_upb_Message_GetOrCreateExtension)( msg, ext_layout, &d->arena); if (UPB_UNLIKELY(!ext)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); } d->original_msg = msg; msg = &ext->data.UPB_PRIVATE(ext_msg_val); @@ -1113,7 +1045,9 @@ static const char* _upb_Decoder_FindFieldStart(upb_Decoder* d, const char* ptr, static const char* _upb_Decoder_DecodeUnknownField( upb_Decoder* d, const char* ptr, upb_Message* msg, uint32_t field_number, uint32_t wire_type, wireval val) { - if (field_number == 0) _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); + if (field_number == 0) { + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_Malformed); + } const char* start = _upb_Decoder_FindFieldStart(d, ptr, field_number, wire_type); @@ -1124,18 +1058,14 @@ static const char* _upb_Decoder_DecodeUnknownField( upb_StringView sv; ptr = upb_EpsCopyInputStream_ReadStringEphemeral(&d->input, ptr, val.size, &sv); - if (UPB_UNLIKELY(ptr == NULL)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); - } + if (!ptr) upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_Malformed); } else if (wire_type == kUpb_WireType_StartGroup) { ptr = UPB_PRIVATE(_upb_WireReader_SkipGroup)(ptr, field_number << 3, d->depth, &d->input); } upb_StringView sv; - if (ptr == NULL || !upb_EpsCopyInputStream_EndCapture(&d->input, ptr, &sv)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); - } + upb_EpsCopyInputStream_EndCapture(&d->input, ptr, &sv); upb_AddUnknownMode mode = kUpb_AddUnknown_Copy; if (d->options & kUpb_DecodeOption_AliasString) { @@ -1150,7 +1080,7 @@ static const char* _upb_Decoder_DecodeUnknownField( if (!UPB_PRIVATE(_upb_Message_AddUnknown)(msg, sv.data, sv.size, &d->arena, mode)) { - _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_OutOfMemory); + upb_ErrorHandler_ThrowError(&d->err, kUpb_DecodeStatus_OutOfMemory); } return ptr; @@ -1166,7 +1096,7 @@ const char* _upb_Decoder_DecodeFieldTag(upb_Decoder* d, const char* ptr, uint32_t tag; UPB_ASSERT(ptr < d->input.limit_ptr); - ptr = _upb_Decoder_DecodeTag(d, ptr, &tag); + ptr = upb_WireReader_ReadTag(ptr, &tag, EPS(d)); *field_number = tag >> 3; *wire_type = tag & 7; return ptr; @@ -1273,7 +1203,7 @@ const char* _upb_Decoder_DecodeField(upb_Decoder* d, const char* ptr, if (_upb_Decoder_TryDecodeMessageFast(d, &ptr, msg, mt, last_field_index, data)) { return ptr; - } else if (_upb_Decoder_IsDone(d, &ptr)) { + } else if (upb_EpsCopyInputStream_IsDone(EPS(d), &ptr)) { return _upb_Decoder_EndMessage(d, ptr); } @@ -1312,10 +1242,10 @@ static upb_DecodeStatus upb_Decoder_Decode(upb_Decoder* const decoder, upb_Message* const msg, const upb_MiniTable* const m, upb_Arena* const arena) { - if (UPB_SETJMP(decoder->err) == 0) { - decoder->status = _upb_Decoder_DecodeTop(decoder, buf, msg, m); + if (UPB_SETJMP(decoder->err.buf) == 0) { + decoder->err.code = _upb_Decoder_DecodeTop(decoder, buf, msg, m); } else { - UPB_ASSERT(decoder->status != kUpb_DecodeStatus_Ok); + UPB_ASSERT(decoder->err.code != kUpb_DecodeStatus_Ok); } return upb_Decoder_Destroy(decoder, arena); diff --git a/upb/wire/decode.h b/upb/wire/decode.h index 647e8cfa8a127..5b199642bd79d 100644 --- a/upb/wire/decode.h +++ b/upb/wire/decode.h @@ -83,8 +83,8 @@ UPB_INLINE int upb_Decode_LimitDepth(uint32_t decode_options, uint32_t limit) { // LINT.IfChange typedef enum { kUpb_DecodeStatus_Ok = 0, - kUpb_DecodeStatus_Malformed = 1, // Wire format was corrupt - kUpb_DecodeStatus_OutOfMemory = 2, // Arena alloc failed + kUpb_DecodeStatus_OutOfMemory = 1, // Arena alloc failed + kUpb_DecodeStatus_Malformed = 2, // Wire format was corrupt kUpb_DecodeStatus_BadUtf8 = 3, // String field had bad UTF-8 kUpb_DecodeStatus_MaxDepthExceeded = 4, // Exceeded upb_DecodeOptions_MaxDepth diff --git a/upb/wire/decode_fast/cardinality.c b/upb/wire/decode_fast/cardinality.c index 6c00bba8c49f5..06b86c91c1f44 100644 --- a/upb/wire/decode_fast/cardinality.c +++ b/upb/wire/decode_fast/cardinality.c @@ -10,6 +10,6 @@ const char* upb_DecodeFast_IsDoneFallback(upb_Decoder* d, const char* ptr) { upb_IsDoneStatus status = UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneStatus)( &d->input, ptr, &overrun); UPB_ASSERT(status == kUpb_IsDoneStatus_NeedFallback); - return UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallbackInline)( - &d->input, ptr, overrun, _upb_Decoder_BufferFlipCallback); + return UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallback)(&d->input, ptr, + overrun); } diff --git a/upb/wire/decode_fast/cardinality.h b/upb/wire/decode_fast/cardinality.h index 86bc0d0d5d2b5..f279b9c8da641 100644 --- a/upb/wire/decode_fast/cardinality.h +++ b/upb/wire/decode_fast/cardinality.h @@ -94,7 +94,7 @@ fastdecode_nextret fastdecode_nextrepeated(upb_Decoder* d, void* dst, fastdecode_nextret ret; dst = (char*)dst + valbytes; - if (UPB_LIKELY(!_upb_Decoder_IsDone(d, ptr))) { + if (UPB_LIKELY(!upb_EpsCopyInputStream_IsDone(&d->input, ptr))) { ret.tag = _upb_FastDecoder_LoadTag(*ptr); if (fastdecode_tagmatch(ret.tag, data, tagbytes)) { ret.next = FD_NEXT_SAMEFIELD; diff --git a/upb/wire/decode_fast/dispatch.c b/upb/wire/decode_fast/dispatch.c index 35b6264f144ff..d0b7a9f8bcc02 100644 --- a/upb/wire/decode_fast/dispatch.c +++ b/upb/wire/decode_fast/dispatch.c @@ -34,8 +34,8 @@ UPB_NOINLINE UPB_PRESERVE_NONE const char* upb_DecodeFast_MessageIsDoneFallback( } case kUpb_IsDoneStatus_NeedFallback: // We've reached end-of-buffer. Refresh the buffer. - ptr = UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallbackInline)( - &d->input, ptr, overrun, _upb_Decoder_BufferFlipCallback); + ptr = UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallback)(&d->input, ptr, + overrun); // We successfully refreshed the buffer (otherwise the function above // would have thrown an error with longjmp()). So continue with the @@ -49,7 +49,7 @@ UPB_NOINLINE UPB_PRESERVE_NONE const char* upb_DecodeFast_MessageIsDoneFallback( } const char* _upb_FastDecoder_ErrorJmp2(upb_Decoder* d) { - UPB_LONGJMP(d->err, 1); + UPB_LONGJMP(d->err.buf, 1); return NULL; } diff --git a/upb/wire/decode_fast/dispatch.h b/upb/wire/decode_fast/dispatch.h index 8a939635fc547..e09cd05976ea4 100644 --- a/upb/wire/decode_fast/dispatch.h +++ b/upb/wire/decode_fast/dispatch.h @@ -202,7 +202,7 @@ const char* _upb_FastDecoder_ErrorJmp2(upb_Decoder* d); UPB_INLINE const char* _upb_FastDecoder_ErrorJmp(upb_Decoder* d, upb_DecodeStatus status) { - d->status = status; + d->err.code = status; return _upb_FastDecoder_ErrorJmp2(d); } @@ -213,7 +213,7 @@ const char* _upb_FastDecoder_ErrorJmp(upb_Decoder* d, upb_DecodeStatus status) { case kUpb_DecodeFastNext_FallbackToMiniTable: \ UPB_MUSTTAIL return _upb_FastDecoder_DecodeGeneric(UPB_PARSE_ARGS); \ case kUpb_DecodeFastNext_Error: \ - UPB_ASSERT(d->status != kUpb_DecodeStatus_Ok); \ + UPB_ASSERT(d->err.code != kUpb_DecodeStatus_Ok); \ return _upb_FastDecoder_ErrorJmp2(d); \ case kUpb_DecodeFastNext_MessageIsDoneFallback: \ UPB_MUSTTAIL return upb_DecodeFast_MessageIsDoneFallback( \ @@ -253,7 +253,7 @@ UPB_INLINE bool upb_DecodeFast_SetError(upb_Decoder* d, #ifdef UPB_TRACE_FASTDECODER fprintf(stderr, "Fasttable error @ %s:%d -> %s (%d)\n", file, line, sym, val); #endif - d->status = val; + d->err.code = val; *next = kUpb_DecodeFastNext_Error; return false; } diff --git a/upb/wire/eps_copy_input_stream.c b/upb/wire/eps_copy_input_stream.c index 57411a6ff329f..49c97351db553 100644 --- a/upb/wire/eps_copy_input_stream.c +++ b/upb/wire/eps_copy_input_stream.c @@ -7,16 +7,42 @@ #include "upb/wire/eps_copy_input_stream.h" +#include +#include + +#include "upb/base/error_handler.h" +#include "upb/wire/internal/eps_copy_input_stream.h" + // Must be last. #include "upb/port/def.inc" -static const char* _upb_EpsCopyInputStream_NoOpCallback( - upb_EpsCopyInputStream* e, const char* old_end, const char* new_start) { - return new_start; +const char* UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)( + upb_EpsCopyInputStream* e) { + e->error = true; + if (e->err) upb_ErrorHandler_ThrowError(e->err, kUpb_ErrorCode_Malformed); + return NULL; } -const char* UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallbackNoCallback)( - upb_EpsCopyInputStream* e, const char* ptr, int overrun) { - return UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallbackInline)( - e, ptr, overrun, _upb_EpsCopyInputStream_NoOpCallback); +const char* UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallback)( + struct upb_EpsCopyInputStream* e, const char* ptr, int overrun) { + if (overrun < e->limit) { + // Need to copy remaining data into patch buffer. + UPB_ASSERT(overrun < kUpb_EpsCopyInputStream_SlopBytes); + const char* old_end = ptr; + const char* new_start = &e->patch[overrun]; + memset(&e->patch[kUpb_EpsCopyInputStream_SlopBytes], 0, + kUpb_EpsCopyInputStream_SlopBytes); + memcpy(e->patch, e->end, kUpb_EpsCopyInputStream_SlopBytes); + ptr = new_start; + e->end = &e->patch[kUpb_EpsCopyInputStream_SlopBytes]; + e->limit -= kUpb_EpsCopyInputStream_SlopBytes; + e->limit_ptr = e->end + e->limit; + UPB_ASSERT(ptr < e->limit_ptr); + e->input_delta = (uintptr_t)old_end - (uintptr_t)new_start; + UPB_PRIVATE(upb_EpsCopyInputStream_BoundsChecked)(e); + return new_start; + } else { + UPB_ASSERT(overrun > e->limit); + return UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(e); + } } diff --git a/upb/wire/eps_copy_input_stream.h b/upb/wire/eps_copy_input_stream.h index 62fdd8726b19e..96d1cfd3ca294 100644 --- a/upb/wire/eps_copy_input_stream.h +++ b/upb/wire/eps_copy_input_stream.h @@ -11,6 +11,7 @@ #include #include +#include "upb/base/error_handler.h" #include "upb/base/string_view.h" #include "upb/wire/internal/eps_copy_input_stream.h" @@ -29,6 +30,24 @@ typedef struct upb_EpsCopyInputStream upb_EpsCopyInputStream; UPB_INLINE void upb_EpsCopyInputStream_Init(upb_EpsCopyInputStream* e, const char** ptr, size_t size); +// Like the previous function, but registers an error handler that will be +// called for any errors encountered. +UPB_INLINE void upb_EpsCopyInputStream_InitWithErrorHandler( + upb_EpsCopyInputStream* e, const char** ptr, size_t size, + upb_ErrorHandler* err); + +// Returns true if the stream has an error handler. +// +// This function is marked const, which indicates to the compiler that the +// return value is solely a function of the pointer value. This is not +// entirely true if the stream is reinitialized with +// upb_EpsCopyInputStream_Init*(), so users must not call this function in +// any context where the stream may be reinitialized between calls to this +// function, and the presence of an error handler changes when reinitialized. +UPB_ATTR_CONST +UPB_INLINE bool upb_EpsCopyInputStream_HasErrorHandler( + const upb_EpsCopyInputStream* e); + // Returns true if the stream is in the error state. A stream enters the error // state when the user reads past a limit (caught in IsDone()) or the // ZeroCopyInputStream returns an error. diff --git a/upb/wire/internal/decoder.c b/upb/wire/internal/decoder.c index 0ba231a48a244..0f132a0648898 100644 --- a/upb/wire/internal/decoder.c +++ b/upb/wire/internal/decoder.c @@ -28,17 +28,3 @@ const char* _upb_Decoder_CheckRequired(upb_Decoder* d, const char* ptr, } return ptr; } - -UPB_NORETURN void* _upb_Decoder_ErrorJmp(upb_Decoder* d, - upb_DecodeStatus status) { - UPB_ASSERT(status != kUpb_DecodeStatus_Ok); - d->status = status; - UPB_LONGJMP(d->err, 1); -} - -UPB_NOINLINE -const char* _upb_Decoder_IsDoneFallback(upb_EpsCopyInputStream* e, - const char* ptr, int overrun) { - return UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallbackInline)( - e, ptr, overrun, _upb_Decoder_BufferFlipCallback); -} diff --git a/upb/wire/internal/decoder.h b/upb/wire/internal/decoder.h index c180e4fbbabe1..819cdd8670e3d 100644 --- a/upb/wire/internal/decoder.h +++ b/upb/wire/internal/decoder.h @@ -19,6 +19,7 @@ #include #include "upb/base/descriptor_constants.h" +#include "upb/base/error_handler.h" #include "upb/base/string_view.h" #include "upb/mem/arena.h" #include "upb/mem/internal/arena.h" @@ -50,8 +51,7 @@ typedef struct upb_Decoder { upb_Arena arena; void* foo[UPB_ARENA_SIZE_HACK]; }; - upb_DecodeStatus status; - jmp_buf err; + upb_ErrorHandler err; #ifndef NDEBUG const char* debug_tagstart; @@ -67,7 +67,17 @@ UPB_INLINE const char* upb_Decoder_Init(upb_Decoder* d, const char* buf, const upb_ExtensionRegistry* extreg, int options, upb_Arena* arena, char* trace_buf, size_t trace_size) { - upb_EpsCopyInputStream_Init(&d->input, &buf, size); + upb_ErrorHandler_Init(&d->err); + upb_EpsCopyInputStream_InitWithErrorHandler(&d->input, &buf, size, &d->err); + + UPB_STATIC_ASSERT((int)kUpb_DecodeStatus_Ok == (int)kUpb_ErrorCode_Ok, + "mismatched error codes"); + UPB_STATIC_ASSERT( + (int)kUpb_DecodeStatus_OutOfMemory == (int)kUpb_ErrorCode_OutOfMemory, + "mismatched error codes"); + UPB_STATIC_ASSERT( + (int)kUpb_DecodeStatus_Malformed == (int)kUpb_ErrorCode_Malformed, + "mismatched error codes"); if (options & kUpb_DecodeOption_AlwaysValidateUtf8) { // Fasttable decoder does not support this option. @@ -79,7 +89,6 @@ UPB_INLINE const char* upb_Decoder_Init(upb_Decoder* d, const char* buf, d->end_group = DECODE_NOGROUP; d->options = (uint16_t)options; d->missing_required = false; - d->status = kUpb_DecodeStatus_Ok; d->message_is_done = false; #ifndef NDEBUG d->trace_buf = trace_buf; @@ -100,7 +109,7 @@ UPB_INLINE const char* upb_Decoder_Init(upb_Decoder* d, const char* buf, UPB_INLINE upb_DecodeStatus upb_Decoder_Destroy(upb_Decoder* d, upb_Arena* arena) { UPB_PRIVATE(_upb_Arena_SwapOut)(arena, &d->arena); - return d->status; + return (upb_DecodeStatus)d->err.code; } #ifndef NDEBUG @@ -168,28 +177,10 @@ UPB_INLINE const upb_MiniTable* decode_totablep(intptr_t table) { return (const upb_MiniTable*)(table >> 8); } -const char* _upb_Decoder_IsDoneFallback(upb_EpsCopyInputStream* e, - const char* ptr, int overrun); - const char* _upb_Decoder_DecodeMessage(upb_Decoder* d, const char* ptr, upb_Message* msg, const upb_MiniTable* layout); -UPB_INLINE bool _upb_Decoder_IsDone(upb_Decoder* d, const char** ptr) { - return UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneWithCallback)( - &d->input, ptr, &_upb_Decoder_IsDoneFallback); -} - -UPB_NORETURN void* _upb_Decoder_ErrorJmp(upb_Decoder* d, - upb_DecodeStatus status); - -UPB_INLINE const char* _upb_Decoder_BufferFlipCallback( - upb_EpsCopyInputStream* e, const char* old_end, const char* new_start) { - upb_Decoder* d = (upb_Decoder*)e; - if (!old_end) _upb_Decoder_ErrorJmp(d, kUpb_DecodeStatus_Malformed); - return new_start; -} - UPB_INLINE bool _upb_Decoder_FieldRequiresUtf8Validation( const upb_Decoder* d, const upb_MiniTableField* field) { if (field->UPB_PRIVATE(descriptortype) == kUpb_FieldType_String) return true; diff --git a/upb/wire/internal/eps_copy_input_stream.h b/upb/wire/internal/eps_copy_input_stream.h index 9e48c01b0621e..f8bc2401f0fee 100644 --- a/upb/wire/internal/eps_copy_input_stream.h +++ b/upb/wire/internal/eps_copy_input_stream.h @@ -12,6 +12,7 @@ #include #include +#include "upb/base/error_handler.h" #include "upb/base/string_view.h" // Must be last. @@ -38,6 +39,7 @@ struct upb_EpsCopyInputStream { const char* buffer_start; // Pointer to the original input buffer const char* capture_start; // If non-NULL, the start of the captured region. ptrdiff_t limit; // Submessage limit relative to end + upb_ErrorHandler* err; // Error handler to use when things go wrong. bool error; // To distinguish between EOF and error. #ifndef NDEBUG int guaranteed_bytes; @@ -56,10 +58,12 @@ UPB_INLINE bool upb_EpsCopyInputStream_IsError( return e->error; } -UPB_INLINE void upb_EpsCopyInputStream_Init(struct upb_EpsCopyInputStream* e, - const char** ptr, size_t size) { +UPB_INLINE void upb_EpsCopyInputStream_InitWithErrorHandler( + struct upb_EpsCopyInputStream* e, const char** ptr, size_t size, + upb_ErrorHandler* err) { e->buffer_start = *ptr; e->capture_start = NULL; + e->err = err; if (size <= kUpb_EpsCopyInputStream_SlopBytes) { memset(&e->patch, 0, 32); if (size) memcpy(&e->patch, *ptr, size); @@ -77,12 +81,28 @@ UPB_INLINE void upb_EpsCopyInputStream_Init(struct upb_EpsCopyInputStream* e, UPB_PRIVATE(upb_EpsCopyInputStream_BoundsChecked)(e); } -typedef const char* upb_EpsCopyInputStream_BufferFlipCallback( - struct upb_EpsCopyInputStream* e, const char* old_end, - const char* new_start); +UPB_INLINE void upb_EpsCopyInputStream_Init(struct upb_EpsCopyInputStream* e, + const char** ptr, size_t size) { + upb_EpsCopyInputStream_InitWithErrorHandler(e, ptr, size, NULL); +} -typedef const char* upb_EpsCopyInputStream_IsDoneFallbackFunc( - struct upb_EpsCopyInputStream* e, const char* ptr, int overrun); +UPB_ATTR_CONST +UPB_INLINE bool upb_EpsCopyInputStream_HasErrorHandler( + const struct upb_EpsCopyInputStream* e) { + return e && e->err != NULL; +} + +// Call this function to signal an error. If an error handler is set, it will be +// called and the function will never return. Otherwise, returns NULL to +// indicate an error. +const char* UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)( + struct upb_EpsCopyInputStream* e); + +UPB_INLINE const char* UPB_PRIVATE(upb_EpsCopyInputStream_AssumeResult)( + struct upb_EpsCopyInputStream* e, const char* ptr) { + UPB_MAYBE_ASSUME(upb_EpsCopyInputStream_HasErrorHandler(e), ptr != NULL); + return ptr; +} //////////////////////////////////////////////////////////////////////////////// @@ -150,9 +170,11 @@ UPB_INLINE upb_IsDoneStatus UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneStatus)( } } -UPB_INLINE bool UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneWithCallback)( - struct upb_EpsCopyInputStream* e, const char** ptr, - upb_EpsCopyInputStream_IsDoneFallbackFunc* func) { +const char* UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallback)( + struct upb_EpsCopyInputStream* e, const char* ptr, int overrun); + +UPB_INLINE bool upb_EpsCopyInputStream_IsDone(struct upb_EpsCopyInputStream* e, + const char** ptr) { int overrun; switch (UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneStatus)(e, *ptr, &overrun)) { case kUpb_IsDoneStatus_Done: @@ -162,7 +184,8 @@ UPB_INLINE bool UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneWithCallback)( UPB_PRIVATE(upb_EpsCopyInputStream_BoundsChecked)(e); return false; case kUpb_IsDoneStatus_NeedFallback: - *ptr = func(e, *ptr, overrun); + *ptr = + UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallback)(e, *ptr, overrun); if (*ptr) { UPB_PRIVATE(upb_EpsCopyInputStream_BoundsChecked)(e); } else { @@ -173,15 +196,6 @@ UPB_INLINE bool UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneWithCallback)( UPB_UNREACHABLE(); } -const char* UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallbackNoCallback)( - struct upb_EpsCopyInputStream* e, const char* ptr, int overrun); - -UPB_INLINE bool upb_EpsCopyInputStream_IsDone(struct upb_EpsCopyInputStream* e, - const char** ptr) { - return UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneWithCallback)( - e, ptr, UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallbackNoCallback)); -} - UPB_INLINE bool upb_EpsCopyInputStream_CheckSize( const struct upb_EpsCopyInputStream* e, const char* ptr, int size) { UPB_ASSERT(size >= 0); @@ -211,7 +225,9 @@ UPB_INLINE void upb_EpsCopyInputStream_StartCapture( UPB_INLINE bool upb_EpsCopyInputStream_EndCapture( struct upb_EpsCopyInputStream* e, const char* ptr, upb_StringView* sv) { UPB_ASSERT(e->capture_start != NULL); - if (ptr - e->end > e->limit) return false; + if (ptr - e->end > e->limit) { + return UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(e); + } const char* end = UPB_PRIVATE(upb_EpsCopyInputStream_GetInputPtr)(e, ptr); sv->data = e->capture_start; sv->size = end - sv->data; @@ -230,7 +246,12 @@ UPB_INLINE const char* upb_EpsCopyInputStream_ReadStringAlwaysAlias( // buffer, so we must fail if the size extends into the slop bytes. const char* limit = e->end + (e->input_delta == 0) * kUpb_EpsCopyInputStream_SlopBytes; - if ((ptrdiff_t)size > limit - ptr) return NULL; + if ((ptrdiff_t)size > limit - ptr) { + // For the moment, we consider this an error. In a multi-buffer world, + // it could be that the requested string extends into the next buffer, which + // is not an error and should be recoverable. + return UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(e); + } const char* input = UPB_PRIVATE(upb_EpsCopyInputStream_GetInputPtr)(e, ptr); *sv = upb_StringView_FromDataAndSize(input, size); return ptr + size; @@ -242,7 +263,12 @@ UPB_INLINE const char* upb_EpsCopyInputStream_ReadStringEphemeral( UPB_ASSERT(size <= PTRDIFF_MAX); // Size must be within the current buffer (including slop bytes). const char* limit = e->end + kUpb_EpsCopyInputStream_SlopBytes; - if ((ptrdiff_t)size > limit - ptr) return NULL; + if ((ptrdiff_t)size > limit - ptr) { + // For the moment, we consider this an error. In a multi-buffer world, + // it could be that the requested string extends into the next buffer, which + // is not an error and should be recoverable. + return UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(e); + } *sv = upb_StringView_FromDataAndSize(ptr, size); return ptr + size; } @@ -261,7 +287,9 @@ UPB_INLINE ptrdiff_t upb_EpsCopyInputStream_PushLimit( e->limit = limit; e->limit_ptr = e->end + UPB_MIN(0, limit); UPB_PRIVATE(upb_EpsCopyInputStream_CheckLimit)(e); - if (UPB_UNLIKELY(delta < 0)) e->error = true; + if (UPB_UNLIKELY(delta < 0)) { + UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(e); + } return delta; } @@ -277,35 +305,6 @@ UPB_INLINE void upb_EpsCopyInputStream_PopLimit( UPB_PRIVATE(upb_EpsCopyInputStream_CheckLimit)(e); } -UPB_INLINE const char* UPB_PRIVATE(upb_EpsCopyInputStream_IsDoneFallbackInline)( - struct upb_EpsCopyInputStream* e, const char* ptr, int overrun, - upb_EpsCopyInputStream_BufferFlipCallback* callback) { - if (overrun < e->limit) { - // Need to copy remaining data into patch buffer. - UPB_ASSERT(overrun < kUpb_EpsCopyInputStream_SlopBytes); - const char* old_end = ptr; - const char* new_start = &e->patch[0] + overrun; - memset(e->patch + kUpb_EpsCopyInputStream_SlopBytes, 0, - kUpb_EpsCopyInputStream_SlopBytes); - memcpy(e->patch, e->end, kUpb_EpsCopyInputStream_SlopBytes); - ptr = new_start; - e->end = &e->patch[kUpb_EpsCopyInputStream_SlopBytes]; - e->limit -= kUpb_EpsCopyInputStream_SlopBytes; - e->limit_ptr = e->end + e->limit; - UPB_ASSERT(ptr < e->limit_ptr); - e->input_delta = (uintptr_t)old_end - (uintptr_t)new_start; - const char* ret = callback(e, old_end, new_start); - if (ret) { - UPB_PRIVATE(upb_EpsCopyInputStream_BoundsChecked)(e); - } - return ret; - } else { - UPB_ASSERT(overrun > e->limit); - e->error = true; - return callback(e, NULL, NULL); - } -} - typedef const char* upb_EpsCopyInputStream_ParseDelimitedFunc( struct upb_EpsCopyInputStream* e, const char* ptr, int size, void* ctx); diff --git a/upb/wire/internal/reader.h b/upb/wire/internal/reader.h index 9b3c34c7534b4..5770e04a87013 100644 --- a/upb/wire/internal/reader.h +++ b/upb/wire/internal/reader.h @@ -29,11 +29,16 @@ extern "C" { #endif UPB_PRIVATE(_upb_WireReader_LongVarint) -UPB_PRIVATE(_upb_WireReader_ReadLongVarint32)(const char* ptr, uint32_t val); +UPB_PRIVATE(_upb_WireReader_ReadLongVarint)(const char* ptr, uint64_t val, + upb_EpsCopyInputStream* stream); UPB_PRIVATE(_upb_WireReader_LongVarint) -UPB_PRIVATE(_upb_WireReader_ReadLongVarint64)(const char* ptr, uint64_t val); +UPB_PRIVATE(_upb_WireReader_ReadLongTag)(const char* ptr, uint64_t val, + upb_EpsCopyInputStream* stream); +UPB_PRIVATE(_upb_WireReader_LongVarint) +UPB_PRIVATE(_upb_WireReader_ReadLongSize)(const char* ptr, uint64_t val, + upb_EpsCopyInputStream* stream); -UPB_FORCEINLINE const char* UPB_PRIVATE(_upb_WireReader_ReadVarint)( +UPB_FORCEINLINE const char* upb_WireReader_ReadVarint( const char* ptr, uint64_t* val, upb_EpsCopyInputStream* stream) { UPB_PRIVATE(upb_EpsCopyInputStream_ConsumeBytes)(stream, 10); uint8_t byte = *ptr; @@ -41,14 +46,13 @@ UPB_FORCEINLINE const char* UPB_PRIVATE(_upb_WireReader_ReadVarint)( *val = byte; return ptr + 1; } - UPB_PRIVATE(_upb_WireReader_LongVarint) - res = UPB_PRIVATE(_upb_WireReader_ReadLongVarint64)(ptr, byte); - if (UPB_UNLIKELY(!res.ptr)) return NULL; + UPB_PRIVATE(_upb_WireReader_LongVarint) res; + res = UPB_PRIVATE(_upb_WireReader_ReadLongVarint)(ptr, byte, stream); *val = res.val; - return res.ptr; + return UPB_PRIVATE(upb_EpsCopyInputStream_AssumeResult)(stream, res.ptr); } -UPB_FORCEINLINE const char* UPB_PRIVATE(_upb_WireReader_ReadTag)( +UPB_FORCEINLINE const char* upb_WireReader_ReadTag( const char* ptr, uint32_t* val, upb_EpsCopyInputStream* stream) { UPB_PRIVATE(upb_EpsCopyInputStream_ConsumeBytes)(stream, 5); uint8_t byte = *ptr; @@ -56,11 +60,24 @@ UPB_FORCEINLINE const char* UPB_PRIVATE(_upb_WireReader_ReadTag)( *val = byte; return ptr + 1; } - UPB_PRIVATE(_upb_WireReader_LongVarint) - res = UPB_PRIVATE(_upb_WireReader_ReadLongVarint32)(ptr, byte); - if (UPB_UNLIKELY(!res.ptr)) return NULL; + UPB_PRIVATE(_upb_WireReader_LongVarint) res; + res = UPB_PRIVATE(_upb_WireReader_ReadLongTag)(ptr, byte, stream); + *val = res.val; + return UPB_PRIVATE(upb_EpsCopyInputStream_AssumeResult)(stream, res.ptr); +} + +UPB_FORCEINLINE const char* upb_WireReader_ReadSize( + const char* ptr, int* val, upb_EpsCopyInputStream* stream) { + UPB_PRIVATE(upb_EpsCopyInputStream_ConsumeBytes)(stream, 5); + uint8_t byte = *ptr; + if (UPB_LIKELY((byte & 0x80) == 0)) { + *val = byte; + return ptr + 1; + } + UPB_PRIVATE(_upb_WireReader_LongVarint) res; + res = UPB_PRIVATE(_upb_WireReader_ReadLongSize)(ptr, byte, stream); *val = res.val; - return res.ptr; + return UPB_PRIVATE(upb_EpsCopyInputStream_AssumeResult)(stream, res.ptr); } UPB_API_INLINE uint32_t upb_WireReader_GetFieldNumber(uint32_t tag) { diff --git a/upb/wire/reader.c b/upb/wire/reader.c index 7b92dcfd64a2b..85cd3868f1a25 100644 --- a/upb/wire/reader.c +++ b/upb/wire/reader.c @@ -17,51 +17,64 @@ #include "upb/port/def.inc" UPB_NOINLINE UPB_PRIVATE(_upb_WireReader_LongVarint) - UPB_PRIVATE(_upb_WireReader_ReadLongVarint64)(const char* ptr, - uint64_t val) { - UPB_PRIVATE(_upb_WireReader_LongVarint) ret = {NULL, 0}; - uint64_t byte; + UPB_PRIVATE(_upb_WireReader_ReadLongVarint)( + const char* ptr, uint64_t val, upb_EpsCopyInputStream* stream) { for (int i = 1; i < 10; i++) { - byte = (uint8_t)ptr[i]; + uint64_t byte = (uint8_t)ptr[i]; val += (byte - 1) << (i * 7); if (!(byte & 0x80)) { - ret.ptr = ptr + i + 1; - ret.val = val; - return ret; + return (UPB_PRIVATE(_upb_WireReader_LongVarint)){ptr + i + 1, val}; } } - return ret; + return (UPB_PRIVATE(_upb_WireReader_LongVarint)){ + UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(stream), 0}; } UPB_NOINLINE UPB_PRIVATE(_upb_WireReader_LongVarint) - UPB_PRIVATE(_upb_WireReader_ReadLongVarint32)(const char* ptr, - uint32_t val) { - UPB_PRIVATE(_upb_WireReader_LongVarint) ret = {NULL, 0}; - uint64_t byte; + UPB_PRIVATE(_upb_WireReader_ReadLongTag)(const char* ptr, uint64_t val, + upb_EpsCopyInputStream* stream) { for (int i = 1; i < 5; i++) { - byte = (uint8_t)ptr[i]; + uint64_t byte = (uint8_t)ptr[i]; val += (byte - 1) << (i * 7); if (!(byte & 0x80)) { - ret.ptr = ptr + i + 1; - ret.val = val; - return ret; + if (val > UINT32_MAX) break; + return (UPB_PRIVATE(_upb_WireReader_LongVarint)){ptr + i + 1, val}; } } - return ret; + return (UPB_PRIVATE(_upb_WireReader_LongVarint)){ + UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(stream), 0}; +} + +UPB_NOINLINE UPB_PRIVATE(_upb_WireReader_LongVarint) + UPB_PRIVATE(_upb_WireReader_ReadLongSize)(const char* ptr, uint64_t val, + upb_EpsCopyInputStream* stream) { + for (int i = 1; i < 5; i++) { + uint64_t byte = (uint8_t)ptr[i]; + val += (byte - 1) << (i * 7); + if (!(byte & 0x80)) { + if (val > INT32_MAX) break; + return (UPB_PRIVATE(_upb_WireReader_LongVarint)){ptr + i + 1, val}; + } + } + return (UPB_PRIVATE(_upb_WireReader_LongVarint)){ + UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(stream), 0}; } const char* UPB_PRIVATE(_upb_WireReader_SkipGroup)( const char* ptr, uint32_t tag, int depth_limit, upb_EpsCopyInputStream* stream) { - if (--depth_limit < 0) return NULL; + if (--depth_limit < 0) { + return UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(stream); + } uint32_t end_group_tag = (tag & ~7ULL) | kUpb_WireType_EndGroup; while (!upb_EpsCopyInputStream_IsDone(stream, &ptr)) { uint32_t tag; ptr = upb_WireReader_ReadTag(ptr, &tag, stream); - if (!ptr) return NULL; + if (!ptr) break; if (tag == end_group_tag) return ptr; ptr = _upb_WireReader_SkipValue(ptr, tag, depth_limit, stream); - if (!ptr) return NULL; + if (!ptr) break; } - return NULL; // Encountered limit end before end group tag. + // Encountered limit end before end group tag. + return UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(stream); } diff --git a/upb/wire/reader.h b/upb/wire/reader.h index 063ad3895ba72..c69e777fd9a92 100644 --- a/upb/wire/reader.h +++ b/upb/wire/reader.h @@ -37,9 +37,7 @@ extern "C" { // Bounds checks must be performed before calling this function, preferably // by calling upb_EpsCopyInputStream_IsDone(). UPB_FORCEINLINE const char* upb_WireReader_ReadTag( - const char* ptr, uint32_t* tag, upb_EpsCopyInputStream* stream) { - return UPB_PRIVATE(_upb_WireReader_ReadTag)(ptr, tag, stream); -} + const char* ptr, uint32_t* tag, upb_EpsCopyInputStream* stream); // Given a tag, returns the field number. UPB_API_INLINE uint32_t upb_WireReader_GetFieldNumber(uint32_t tag); @@ -47,10 +45,8 @@ UPB_API_INLINE uint32_t upb_WireReader_GetFieldNumber(uint32_t tag); // Given a tag, returns the wire type. UPB_API_INLINE uint8_t upb_WireReader_GetWireType(uint32_t tag); -UPB_INLINE const char* upb_WireReader_ReadVarint( - const char* ptr, uint64_t* val, upb_EpsCopyInputStream* stream) { - return UPB_PRIVATE(_upb_WireReader_ReadVarint)(ptr, val, stream); -} +UPB_FORCEINLINE const char* upb_WireReader_ReadVarint( + const char* ptr, uint64_t* val, upb_EpsCopyInputStream* stream); // Skips data for a varint, returning a pointer past the end of the varint, or // NULL if there was an error in the varint data. @@ -71,13 +67,7 @@ UPB_INLINE const char* upb_WireReader_SkipVarint( // Bounds checks must be performed before calling this function, preferably // by calling upb_EpsCopyInputStream_IsDone(). UPB_INLINE const char* upb_WireReader_ReadSize(const char* ptr, int* size, - upb_EpsCopyInputStream* stream) { - uint64_t size64; - ptr = upb_WireReader_ReadVarint(ptr, &size64, stream); - if (!ptr || size64 >= INT32_MAX) return NULL; - *size = size64; - return ptr; -} + upb_EpsCopyInputStream* stream); // Reads a fixed32 field, performing byte swapping if necessary. // @@ -124,7 +114,9 @@ const char* UPB_PRIVATE(_upb_WireReader_SkipGroup)( // control over this? UPB_INLINE const char* upb_WireReader_SkipGroup( const char* ptr, uint32_t tag, upb_EpsCopyInputStream* stream) { - return UPB_PRIVATE(_upb_WireReader_SkipGroup)(ptr, tag, 100, stream); + const char* ret = + UPB_PRIVATE(_upb_WireReader_SkipGroup)(ptr, tag, 100, stream); + return UPB_PRIVATE(upb_EpsCopyInputStream_AssumeResult)(stream, ret); } UPB_INLINE const char* _upb_WireReader_SkipValue( @@ -143,7 +135,7 @@ UPB_INLINE const char* _upb_WireReader_SkipValue( int size; ptr = upb_WireReader_ReadSize(ptr, &size, stream); if (!ptr || !upb_EpsCopyInputStream_CheckSize(stream, ptr, size)) { - return NULL; + return UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(stream); } ptr += size; return ptr; @@ -152,9 +144,10 @@ UPB_INLINE const char* _upb_WireReader_SkipValue( return UPB_PRIVATE(_upb_WireReader_SkipGroup)(ptr, tag, depth_limit, stream); case kUpb_WireType_EndGroup: - return NULL; // Should be handled before now. + // Should be handled before now. default: - return NULL; // Unknown wire type. + // Unknown wire type. + return UPB_PRIVATE(upb_EpsCopyInputStream_ReturnError)(stream); } }