Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0d5055a

Browse files
haiyizxxmergify[bot]
authored andcommittedJan 6, 2025·
refactor: improve edge case handling for recursion limits (#22988)
Co-authored-by: Alex | Skip <alex@skip.money> (cherry picked from commit 93282e1) # Conflicts: # CHANGELOG.md # x/tx/decode/unknown.go
1 parent 8e710b7 commit 0d5055a

File tree

4 files changed

+213
-3
lines changed

4 files changed

+213
-3
lines changed
 

‎CHANGELOG.md

+13
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,19 @@ Ref: https://keepachangelog.com/en/1.0.0/
4040

4141
Every module contains its own CHANGELOG.md. Please refer to the module you are interested in.
4242

43+
<<<<<<< HEAD
44+
=======
45+
### Features
46+
47+
* (baseapp) [#20291](https://github.com/cosmos/cosmos-sdk/pull/20291) Simulate nested messages.
48+
* (client/keys) [#21829](https://github.com/cosmos/cosmos-sdk/pull/21829) Add support for importing hex key using standard input.
49+
* (x/auth/ante) [#23128](https://github.com/cosmos/cosmos-sdk/pull/23128) Allow custom verifyIsOnCurve when validate tx for public key like ethsecp256k1.
50+
51+
### Improvements
52+
53+
* (codec) [#22988](https://github.com/cosmos/cosmos-sdk/pull/22988) Improve edge case handling for recursion limits.
54+
55+
>>>>>>> 93282e101 (refactor: improve edge case handling for recursion limits (#22988))
4356
### Bug Fixes
4457

4558
* (x/auth/tx) [#23148](https://github.com/cosmos/cosmos-sdk/pull/23148) Avoid panic from intoAnyV2 when v1.PublicKey is optional.

‎codec/types/interface_registry.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,10 @@ func (r statefulUnpacker) cloneForRecursion() *statefulUnpacker {
274274
// UnpackAny deserializes a protobuf Any message into the provided interface, ensuring the interface is a pointer.
275275
// It applies stateful constraints such as max depth and call limits, and unpacks interfaces if required.
276276
func (r *statefulUnpacker) UnpackAny(any *Any, iface interface{}) error {
277-
if r.maxDepth == 0 {
277+
if r.maxDepth <= 0 {
278278
return errors.New("max depth exceeded")
279279
}
280-
if r.maxCalls.count == 0 {
280+
if r.maxCalls.count <= 0 {
281281
return errors.New("call limit exceeded")
282282
}
283283
// here we gracefully handle the case in which `any` itself is `nil`, which may occur in message decoding

‎codec/unknownproto/unknown_fields.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func doRejectUnknownFields(
5454
if len(bz) == 0 {
5555
return hasUnknownNonCriticals, nil
5656
}
57-
if recursionLimit == 0 {
57+
if recursionLimit <= 0 {
5858
return false, errors.New("recursion limit reached")
5959
}
6060

‎x/tx/decode/unknown.go

+197
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
package decode
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"strings"
7+
8+
"google.golang.org/protobuf/encoding/protowire"
9+
"google.golang.org/protobuf/proto"
10+
"google.golang.org/protobuf/reflect/protodesc"
11+
"google.golang.org/protobuf/reflect/protoreflect"
12+
"google.golang.org/protobuf/types/known/anypb"
13+
)
14+
15+
const bit11NonCritical = 1 << 10
16+
17+
var (
18+
anyDesc = (&anypb.Any{}).ProtoReflect().Descriptor()
19+
anyFullName = anyDesc.FullName()
20+
)
21+
22+
// RejectUnknownFieldsStrict operates by the same rules as RejectUnknownFields, but returns an error if any unknown
23+
// non-critical fields are encountered.
24+
func RejectUnknownFieldsStrict(bz []byte, msg protoreflect.MessageDescriptor, resolver protodesc.Resolver) error {
25+
_, err := RejectUnknownFields(bz, msg, false, resolver)
26+
return err
27+
}
28+
29+
// RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an
30+
// option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the
31+
// hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be
32+
// used to treat a message with non-critical field different in different security contexts (such as transaction signing).
33+
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
34+
// An AnyResolver must be provided for traversing inside google.protobuf.Any's.
35+
func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver) (hasUnknownNonCriticals bool, err error) {
36+
// recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28
37+
return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, 10_000)
38+
}
39+
40+
func doRejectUnknownFields(
41+
bz []byte,
42+
desc protoreflect.MessageDescriptor,
43+
allowUnknownNonCriticals bool,
44+
resolver protodesc.Resolver,
45+
recursionLimit int,
46+
) (hasUnknownNonCriticals bool, err error) {
47+
if len(bz) == 0 {
48+
return hasUnknownNonCriticals, nil
49+
}
50+
if recursionLimit <= 0 {
51+
return false, errors.New("recursion limit reached")
52+
}
53+
54+
fields := desc.Fields()
55+
56+
for len(bz) > 0 {
57+
tagNum, wireType, m := protowire.ConsumeTag(bz)
58+
if m < 0 {
59+
return hasUnknownNonCriticals, errors.New("invalid length")
60+
}
61+
62+
fieldDesc := fields.ByNumber(tagNum)
63+
if fieldDesc == nil {
64+
isCriticalField := tagNum&bit11NonCritical == 0
65+
66+
if !isCriticalField {
67+
hasUnknownNonCriticals = true
68+
}
69+
70+
if isCriticalField || !allowUnknownNonCriticals {
71+
// The tag is critical, so report it.
72+
return hasUnknownNonCriticals, ErrUnknownField.Wrapf(
73+
"%s: {TagNum: %d, WireType:%q}",
74+
desc.FullName(), tagNum, WireTypeToString(wireType))
75+
}
76+
}
77+
78+
// Skip over the bytes that store fieldNumber and wireType bytes.
79+
bz = bz[m:]
80+
n := protowire.ConsumeFieldValue(tagNum, wireType, bz)
81+
if n < 0 {
82+
err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w",
83+
tagNum, WireTypeToString(wireType), protowire.ParseError(n))
84+
return hasUnknownNonCriticals, err
85+
}
86+
fieldBytes := bz[:n]
87+
bz = bz[n:]
88+
89+
// An unknown but non-critical field
90+
if fieldDesc == nil {
91+
continue
92+
}
93+
94+
fieldMessage := fieldDesc.Message()
95+
// not message or group kind
96+
if fieldMessage == nil {
97+
continue
98+
}
99+
// if a message descriptor is a placeholder resolve it using the injected resolver.
100+
// this can happen when a descriptor has been registered in the
101+
// "google.golang.org/protobuf" registry but not in "github.com/cosmos/gogoproto".
102+
// fixes: https://github.com/cosmos/cosmos-sdk/issues/22574
103+
if fieldMessage.IsPlaceholder() {
104+
gogoDesc, err := resolver.FindDescriptorByName(fieldMessage.FullName())
105+
if err != nil {
106+
return hasUnknownNonCriticals, fmt.Errorf("could not resolve placeholder descriptor: %v: %w", fieldMessage, err)
107+
}
108+
fieldMessage = gogoDesc.(protoreflect.MessageDescriptor)
109+
}
110+
111+
// consume length prefix of nested message
112+
_, o := protowire.ConsumeVarint(fieldBytes)
113+
if o < 0 {
114+
err = fmt.Errorf("could not consume length prefix fieldBytes for nested message: %v: %w",
115+
fieldMessage, protowire.ParseError(o))
116+
return hasUnknownNonCriticals, err
117+
} else if o > len(fieldBytes) {
118+
err = fmt.Errorf("length prefix > len(fieldBytes) for nested message: %v", fieldMessage)
119+
return hasUnknownNonCriticals, err
120+
}
121+
122+
fieldBytes = fieldBytes[o:]
123+
124+
var err error
125+
126+
if fieldMessage.FullName() == anyFullName {
127+
// Firstly typecheck types.Any to ensure nothing snuck in.
128+
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver, recursionLimit-1)
129+
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
130+
if err != nil {
131+
return hasUnknownNonCriticals, err
132+
}
133+
var a anypb.Any
134+
if err = proto.Unmarshal(fieldBytes, &a); err != nil {
135+
return hasUnknownNonCriticals, err
136+
}
137+
138+
msgName := protoreflect.FullName(strings.TrimPrefix(a.TypeUrl, "/"))
139+
msgDesc, err := resolver.FindDescriptorByName(msgName)
140+
if err != nil {
141+
return hasUnknownNonCriticals, err
142+
}
143+
144+
fieldMessage = msgDesc.(protoreflect.MessageDescriptor)
145+
fieldBytes = a.Value
146+
}
147+
148+
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver, recursionLimit-1)
149+
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
150+
if err != nil {
151+
return hasUnknownNonCriticals, err
152+
}
153+
}
154+
155+
return hasUnknownNonCriticals, nil
156+
}
157+
158+
// errUnknownField represents an error indicating that we encountered
159+
// a field that isn't available in the target proto.Message.
160+
type errUnknownField struct {
161+
Desc protoreflect.MessageDescriptor
162+
TagNum protowire.Number
163+
WireType protowire.Type
164+
}
165+
166+
// String implements fmt.Stringer.
167+
func (twt *errUnknownField) String() string {
168+
return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}",
169+
twt.Desc.FullName(), twt.TagNum, WireTypeToString(twt.WireType))
170+
}
171+
172+
// Error implements the error interface.
173+
func (twt *errUnknownField) Error() string {
174+
return twt.String()
175+
}
176+
177+
var _ error = (*errUnknownField)(nil)
178+
179+
// WireTypeToString returns a string representation of the given protowire.Type.
180+
func WireTypeToString(wt protowire.Type) string {
181+
switch wt {
182+
case 0:
183+
return "varint"
184+
case 1:
185+
return "fixed64"
186+
case 2:
187+
return "bytes"
188+
case 3:
189+
return "start_group"
190+
case 4:
191+
return "end_group"
192+
case 5:
193+
return "fixed32"
194+
default:
195+
return fmt.Sprintf("unknown type: %d", wt)
196+
}
197+
}

0 commit comments

Comments
 (0)
Please sign in to comment.