@@ -511,6 +511,13 @@ impl<'a> CodedInputStream<'a> {
511511 }
512512
513513 fn skip_group ( & mut self ) -> crate :: Result < ( ) > {
514+ self . incr_recursion ( ) ?;
515+ let ret = self . skip_group_no_depth_check ( ) ;
516+ self . decr_recursion ( ) ;
517+ ret
518+ }
519+
520+ fn skip_group_no_depth_check ( & mut self ) -> crate :: Result < ( ) > {
514521 while !self . eof ( ) ? {
515522 let wire_type = self . read_tag_unpack ( ) ?. 1 ;
516523 if wire_type == WireType :: EndGroup {
@@ -631,19 +638,16 @@ impl<'a> CodedInputStream<'a> {
631638 /// Read message, do not check if message is initialized
632639 pub fn merge_message < M : Message > ( & mut self , message : & mut M ) -> crate :: Result < ( ) > {
633640 self . incr_recursion ( ) ?;
634- struct DecrRecursion < ' a , ' b > ( & ' a mut CodedInputStream < ' b > ) ;
635- impl < ' a , ' b > Drop for DecrRecursion < ' a , ' b > {
636- fn drop ( & mut self ) {
637- self . 0 . decr_recursion ( ) ;
638- }
639- }
640-
641- let mut decr = DecrRecursion ( self ) ;
641+ let ret = self . merge_message_no_depth_check ( message) ;
642+ self . decr_recursion ( ) ;
643+ ret
644+ }
642645
643- let len = decr. 0 . read_raw_varint64 ( ) ?;
644- let old_limit = decr. 0 . push_limit ( len) ?;
645- message. merge_from ( & mut decr. 0 ) ?;
646- decr. 0 . pop_limit ( old_limit) ;
646+ fn merge_message_no_depth_check < M : Message > ( & mut self , message : & mut M ) -> crate :: Result < ( ) > {
647+ let len = self . read_raw_varint64 ( ) ?;
648+ let old_limit = self . push_limit ( len) ?;
649+ message. merge_from ( self ) ?;
650+ self . pop_limit ( old_limit) ;
647651 Ok ( ( ) )
648652 }
649653
@@ -982,4 +986,47 @@ mod test {
982986 ) ;
983987 assert_eq ! ( "field 3" , input. read_string( ) . unwrap( ) ) ;
984988 }
989+
990+ #[ test]
991+ fn test_shallow_nested_unknown_groups ( ) {
992+ // Test skip_group() succeeds on a start group tag 50 times
993+ // followed by end group tag 50 times. We should be able to
994+ // successfully skip the outermost group.
995+ let mut vec = Vec :: new ( ) ;
996+ let mut os = CodedOutputStream :: new ( & mut vec) ;
997+ for _ in 0 ..50 {
998+ os. write_tag ( 1 , WireType :: StartGroup ) . unwrap ( ) ;
999+ }
1000+ for _ in 0 ..50 {
1001+ os. write_tag ( 1 , WireType :: EndGroup ) . unwrap ( ) ;
1002+ }
1003+ drop ( os) ;
1004+
1005+ let mut input = CodedInputStream :: from_bytes ( & vec) ;
1006+ assert ! ( input. skip_group( ) . is_ok( ) ) ;
1007+ }
1008+
1009+ #[ test]
1010+ fn test_deeply_nested_unknown_groups ( ) {
1011+ // Create an output stream that has groups nested recursively 1000
1012+ // deep, and try to skip the group.
1013+ // This should fail the default depth limit of 100 which ensures we
1014+ // don't blow the stack on adversial input.
1015+ let mut vec = Vec :: new ( ) ;
1016+ let mut os = CodedOutputStream :: new ( & mut vec) ;
1017+ for _ in 0 ..1000 {
1018+ os. write_tag ( 1 , WireType :: StartGroup ) . unwrap ( ) ;
1019+ }
1020+ for _ in 0 ..1000 {
1021+ os. write_tag ( 1 , WireType :: EndGroup ) . unwrap ( ) ;
1022+ }
1023+ drop ( os) ;
1024+
1025+ let mut input = CodedInputStream :: from_bytes ( & vec) ;
1026+ assert ! ( input
1027+ . skip_group( )
1028+ . unwrap_err( )
1029+ . to_string( )
1030+ . contains( "Over recursion limit" ) ) ;
1031+ }
9851032}
0 commit comments