From dac2e91e36408087d769be89a72fbafe1ea5039c Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Tue, 4 Mar 2025 13:16:32 -0800 Subject: [PATCH 1/2] Internal pure python fixes PiperOrigin-RevId: 733441339 --- python/google/protobuf/internal/decoder.py | 98 ++++++++++++++----- .../google/protobuf/internal/message_test.py | 1 + .../protobuf/internal/python_message.py | 7 +- .../protobuf/internal/self_recursive.proto | 9 +- 4 files changed, 86 insertions(+), 29 deletions(-) diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index dcde1d942..89d829142 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -184,7 +184,10 @@ def _SimpleDecoder(wire_type, decode_value): clear_if_default=False): if is_packed: local_DecodeVarint = _DecodeVarint - def DecodePackedField(buffer, pos, end, message, field_dict): + def DecodePackedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -199,11 +202,15 @@ def _SimpleDecoder(wire_type, decode_value): del value[-1] # Discard corrupt value. raise _DecodeError('Packed element was truncated.') return pos + return DecodePackedField elif is_repeated: tag_bytes = encoder.TagBytes(field_number, wire_type) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -218,9 +225,12 @@ def _SimpleDecoder(wire_type, decode_value): if new_pos > end: raise _DecodeError('Truncated message.') return new_pos + return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): + del current_depth # unused (new_value, pos) = decode_value(buffer, pos) if pos > end: raise _DecodeError('Truncated message.') @@ -229,6 +239,7 @@ def _SimpleDecoder(wire_type, decode_value): else: field_dict[key] = new_value return pos + return DecodeField return SpecificDecoder @@ -364,7 +375,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, enum_type = key.enum_type if is_packed: local_DecodeVarint = _DecodeVarint - def DecodePackedField(buffer, pos, end, message, field_dict): + def DecodePackedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): """Decode serialized packed enum to its value and a new position. Args: @@ -377,6 +390,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, Returns: int, new position in serialized data. """ + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -407,11 +421,14 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, # pylint: enable=protected-access raise _DecodeError('Packed element was truncated.') return pos + return DecodePackedField elif is_repeated: tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): """Decode serialized repeated enum to its value and a new position. Args: @@ -424,6 +441,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, Returns: int, new position in serialized data. """ + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -446,9 +464,11 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, if new_pos > end: raise _DecodeError('Truncated message.') return new_pos + return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): """Decode serialized repeated enum to its value and a new position. Args: @@ -461,6 +481,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, Returns: int, new position in serialized data. """ + del current_depth # unused value_start_pos = pos (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) if pos > end: @@ -480,6 +501,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default, (tag_bytes, buffer[value_start_pos:pos].tobytes())) # pylint: enable=protected-access return pos + return DecodeField @@ -538,7 +560,10 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default, tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -553,9 +578,12 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default, if buffer[new_pos:pos] != tag_bytes or new_pos == end: # Prediction failed. Return. return new_pos + return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): + del current_depth # unused (size, pos) = local_DecodeVarint(buffer, pos) new_pos = pos + size if new_pos > end: @@ -565,6 +593,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default, else: field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) return new_pos + return DecodeField @@ -579,7 +608,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default, tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): + del current_depth # unused value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -594,9 +626,12 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default, if buffer[new_pos:pos] != tag_bytes or new_pos == end: # Prediction failed. Return. return new_pos + return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): + del current_depth # unused (size, pos) = local_DecodeVarint(buffer, pos) new_pos = pos + size if new_pos > end: @@ -606,6 +641,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default, else: field_dict[key] = buffer[pos:new_pos].tobytes() return new_pos + return DecodeField @@ -621,7 +657,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_START_GROUP) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -630,7 +668,7 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): if value is None: value = field_dict.setdefault(key, new_default(message)) # Read sub-message. - pos = value.add()._InternalParse(buffer, pos, end) + pos = value.add()._InternalParse(buffer, pos, end, current_depth) # Read end tag. new_pos = pos+end_tag_len if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: @@ -640,19 +678,22 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): if buffer[new_pos:pos] != tag_bytes or new_pos == end: # Prediction failed. Return. return new_pos + return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) # Read sub-message. - pos = value._InternalParse(buffer, pos, end) + pos = value._InternalParse(buffer, pos, end, current_depth) # Read end tag. new_pos = pos+end_tag_len if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: raise _DecodeError('Missing group end tag.') return new_pos + return DecodeField @@ -666,7 +707,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) tag_len = len(tag_bytes) - def DecodeRepeatedField(buffer, pos, end, message, field_dict): + def DecodeRepeatedField( + buffer, pos, end, message, field_dict, current_depth=0 + ): value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -677,7 +720,10 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): if new_pos > end: raise _DecodeError('Truncated message.') # Read sub-message. - if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: + if ( + value.add()._InternalParse(buffer, pos, new_pos, current_depth) + != new_pos + ): # The only reason _InternalParse would return early is if it # encountered an end-group tag. raise _DecodeError('Unexpected end-group tag.') @@ -686,9 +732,11 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): if buffer[new_pos:pos] != tag_bytes or new_pos == end: # Prediction failed. Return. return new_pos + return DecodeRepeatedField else: - def DecodeField(buffer, pos, end, message, field_dict): + + def DecodeField(buffer, pos, end, message, field_dict, current_depth=0): value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -698,11 +746,12 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): if new_pos > end: raise _DecodeError('Truncated message.') # Read sub-message. - if value._InternalParse(buffer, pos, new_pos) != new_pos: + if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos: # The only reason _InternalParse would return early is if it encountered # an end-group tag. raise _DecodeError('Unexpected end-group tag.') return new_pos + return DecodeField @@ -851,7 +900,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map): # Can't read _concrete_class yet; might not be initialized. message_type = field_descriptor.message_type - def DecodeMap(buffer, pos, end, message, field_dict): + def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0): + del current_depth # Unused. submsg = message_type._concrete_class() value = field_dict.get(key) if value is None: @@ -934,7 +984,7 @@ def _SkipGroup(buffer, pos, end): pos = new_pos -def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): +def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0): """Decode UnknownFieldSet. Returns the UnknownFieldSet and new position.""" unknown_field_set = containers.UnknownFieldSet() @@ -944,14 +994,16 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): field_number, wire_type = wire_format.UnpackTag(tag) if wire_type == wire_format.WIRETYPE_END_GROUP: break - (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) + (data, pos) = _DecodeUnknownField(buffer, pos, wire_type, current_depth) # pylint: disable=protected-access unknown_field_set._add(field_number, wire_type, data) return (unknown_field_set, pos) -def _DecodeUnknownField(buffer, pos, wire_type): +def _DecodeUnknownField( + buffer, pos, wire_type, current_depth=0 +): """Decode a unknown field. Returns the UnknownField and new position.""" if wire_type == wire_format.WIRETYPE_VARINT: @@ -965,7 +1017,7 @@ def _DecodeUnknownField(buffer, pos, wire_type): data = buffer[pos:pos+size].tobytes() pos += size elif wire_type == wire_format.WIRETYPE_START_GROUP: - (data, pos) = _DecodeUnknownFieldSet(buffer, pos) + (data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth) elif wire_type == wire_format.WIRETYPE_END_GROUP: return (0, -1) else: diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 2a723eabb..48e6df806 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -30,6 +30,7 @@ import warnings cmp = lambda x, y: (x > y) - (x < y) from google.protobuf.internal import api_implementation # pylint: disable=g-import-not-at-top +from google.protobuf.internal import decoder from google.protobuf.internal import encoder from google.protobuf.internal import enum_type_wrapper from google.protobuf.internal import more_extensions_pb2 diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index fabc6aa07..62c059cd2 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -1194,7 +1194,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls): fields_by_tag = cls._fields_by_tag message_set_decoders_by_tag = cls._message_set_decoders_by_tag - def InternalParse(self, buffer, pos, end): + def InternalParse(self, buffer, pos, end, current_depth=0): """Create a message from serialized bytes. Args: @@ -1244,10 +1244,13 @@ def _AddMergeFromStringMethod(message_descriptor, cls): else: _MaybeAddDecoder(cls, field_des) field_decoder = field_des._decoders[is_packed] - pos = field_decoder(buffer, new_pos, end, self, field_dict) + pos = field_decoder( + buffer, new_pos, end, self, field_dict, current_depth + ) if field_des.containing_oneof: self._UpdateOneofState(field_des) return pos + cls._InternalParse = InternalParse diff --git a/python/google/protobuf/internal/self_recursive.proto b/python/google/protobuf/internal/self_recursive.proto index dbfcaf971..20bc2b4d3 100644 --- a/python/google/protobuf/internal/self_recursive.proto +++ b/python/google/protobuf/internal/self_recursive.proto @@ -5,18 +5,19 @@ // license that can be found in the LICENSE file or at // https://developers.google.com/open-source/licenses/bsd -syntax = "proto2"; +edition = "2023"; package google.protobuf.python.internal; message SelfRecursive { - optional SelfRecursive sub = 1; + SelfRecursive sub = 1; + int32 i = 2; } message IndirectRecursive { - optional IntermediateRecursive intermediate = 1; + IntermediateRecursive intermediate = 1; } message IntermediateRecursive { - optional IndirectRecursive indirect = 1; + IndirectRecursive indirect = 1; } -- 2.51.1