From 3b0c1b873793da29ff46a42d448cda3b5f937c8a Mon Sep 17 00:00:00 2001 From: Jie Luo Date: Mon, 28 Aug 2023 09:58:39 -0700 Subject: [PATCH 1/2] Fix a bug that strips options from descriptor.proto in Pure Python. GetOptions on fields (which parses the _serialized_options) is now called the first time a message is parsed or serialized instead of at build time. Note: GetOptions on messages is still called at build time because of message_set_wire_format. If message options are needed in descriptor.proto, GetOptions() will raise a parse error. We could check the file and skip invoking GetOptions() for descriptor.proto, as long as message_set_wire_format is not needed in descriptor.proto. Options other than message options do not invoke GetOptions() at build time. PiperOrigin-RevId: 560741182 --- python/google/protobuf/descriptor.py | 18 ++-- python/google/protobuf/descriptor_pool.py | 1 + .../protobuf/internal/python_message.py | 102 +++++++++++++----- .../protobuf/internal/reflection_test.py | 43 ++++++++ .../protobuf/internal/unknown_fields_test.py | 33 ------ 5 files changed, 128 insertions(+), 69 deletions(-) diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index fcb87cab5..4c03a4669 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -176,14 +176,15 @@ class DescriptorBase(metaclass=DescriptorMetaclass): raise RuntimeError('Unknown options class name %s!' 
% (self._options_class_name)) - with _lock: - if self._serialized_options is None: + if self._serialized_options is None: + with _lock: self._options = options_class() - else: - self._options = _ParseOptions(options_class(), - self._serialized_options) + else: + options = _ParseOptions(options_class(), self._serialized_options) + with _lock: + self._options = options - return self._options + return self._options class _NestedDescriptorBase(DescriptorBase): @@ -285,6 +286,7 @@ class Descriptor(_NestedDescriptorBase): oneofs_by_name (dict(str, OneofDescriptor)): Same objects as in :attr:`oneofs`, but indexed by "name" attribute. file (FileDescriptor): Reference to file descriptor. + is_map_entry: If the message type is a map entry. """ @@ -310,6 +312,7 @@ class Descriptor(_NestedDescriptorBase): serialized_start=None, serialized_end=None, syntax=None, + is_map_entry=False, create_key=None): _message.Message._CheckCalledFromGeneratedFile() return _message.default_pool.FindMessageTypeByName(full_name) @@ -322,7 +325,7 @@ class Descriptor(_NestedDescriptorBase): serialized_options=None, is_extendable=True, extension_ranges=None, oneofs=None, file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin - syntax=None, create_key=None): + syntax=None, is_map_entry=False, create_key=None): """Arguments to __init__() are as described in the description of Descriptor fields above. 
@@ -372,6 +375,7 @@ class Descriptor(_NestedDescriptorBase): for oneof in self.oneofs: oneof.containing_type = self self.syntax = syntax or "proto2" + self._is_map_entry = is_map_entry @property def fields_by_camelcase_name(self): diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 1ebf11834..04bd7a3e9 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -897,6 +897,7 @@ class DescriptorPool(object): serialized_start=None, serialized_end=None, syntax=syntax, + is_map_entry=desc_proto.options.map_entry, # pylint: disable=protected-access create_key=descriptor._internal_create_key) for nested in desc.nested_types: diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index bf9acefd2..a72276d1a 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -182,11 +182,14 @@ class GeneratedProtocolMessageType(type): % (descriptor.full_name)) return - cls._decoders_by_tag = {} + cls._message_set_decoders_by_tag = {} + cls._fields_by_tag = {} if (descriptor.has_options and descriptor.GetOptions().message_set_wire_format): - cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( - decoder.MessageSetItemDecoder(descriptor), None) + cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(descriptor), + None, + ) # Attach stuff to each FieldDescriptor for quick lookup later on. 
for field in descriptor.fields: @@ -272,16 +275,36 @@ def _IsMessageSetExtension(field): def _IsMapField(field): return (field.type == _FieldDescriptor.TYPE_MESSAGE and - field.message_type.has_options and - field.message_type.GetOptions().map_entry) + field.message_type._is_map_entry) def _IsMessageMapField(field): value_type = field.message_type.fields_by_name['value'] return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE - def _AttachFieldHelpers(cls, field_descriptor): + is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED + field_descriptor._default_constructor = _DefaultValueConstructorForField( + field_descriptor + ) + + def AddFieldByTag(wiretype, is_packed): + tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) + cls._fields_by_tag[tag_bytes] = (field_descriptor, is_packed) + + AddFieldByTag( + type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False + ) + + if is_repeated and wire_format.IsTypePackable(field_descriptor.type): + # To support wire compatibility of adding packed = true, add a decoder for + # packed values regardless of the field's options. 
+ AddFieldByTag(wire_format.WIRETYPE_LENGTH_DELIMITED, True) + + +def _MaybeAddEncoder(cls, field_descriptor): + if hasattr(field_descriptor, '_encoder'): + return is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) is_map_entry = _IsMapField(field_descriptor) is_packed = field_descriptor.is_packed @@ -301,11 +324,17 @@ def _AttachFieldHelpers(cls, field_descriptor): field_descriptor._encoder = field_encoder field_descriptor._sizer = sizer - field_descriptor._default_constructor = _DefaultValueConstructorForField( - field_descriptor) - def AddDecoder(wiretype, is_packed): - tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) + +def _MaybeAddDecoder(cls, field_descriptor): + if hasattr(field_descriptor, '_decoders'): + return + + is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED + is_map_entry = _IsMapField(field_descriptor) + field_descriptor._decoders = {} + + def AddDecoder(is_packed): decode_type = field_descriptor.type if (decode_type == _FieldDescriptor.TYPE_ENUM and not field_descriptor.enum_type.is_closed): @@ -337,15 +366,14 @@ def _AttachFieldHelpers(cls, field_descriptor): field_descriptor, field_descriptor._default_constructor, not field_descriptor.has_presence) - cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) + field_descriptor._decoders[is_packed] = field_decoder - AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], - False) + AddDecoder(False) if is_repeated and wire_format.IsTypePackable(field_descriptor.type): # To support wire compatibility of adding packed = true, add a decoder for # packed values regardless of the field's options. 
- AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) + AddDecoder(True) def _AddClassAttributesForNestedExtensions(descriptor, dictionary): @@ -1031,12 +1059,17 @@ def _AddByteSizeMethod(message_descriptor, cls): size = 0 descriptor = self.DESCRIPTOR - if descriptor.GetOptions().map_entry: + if descriptor._is_map_entry: # Fields of map entry should always be serialized. - size = descriptor.fields_by_name['key']._sizer(self.key) - size += descriptor.fields_by_name['value']._sizer(self.value) + key_field = descriptor.fields_by_name['key'] + _MaybeAddEncoder(cls, key_field) + size = key_field._sizer(self.key) + value_field = descriptor.fields_by_name['value'] + _MaybeAddEncoder(cls, value_field) + size += value_field._sizer(self.value) else: for field_descriptor, field_value in self.ListFields(): + _MaybeAddEncoder(cls, field_descriptor) size += field_descriptor._sizer(field_value) for tag_bytes, value_bytes in self._unknown_fields: size += len(tag_bytes) + len(value_bytes) @@ -1079,14 +1112,17 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): deterministic = bool(deterministic) descriptor = self.DESCRIPTOR - if descriptor.GetOptions().map_entry: + if descriptor._is_map_entry: # Fields of map entry should always be serialized. 
- descriptor.fields_by_name['key']._encoder( - write_bytes, self.key, deterministic) - descriptor.fields_by_name['value']._encoder( - write_bytes, self.value, deterministic) + key_field = descriptor.fields_by_name['key'] + _MaybeAddEncoder(cls, key_field) + key_field._encoder(write_bytes, self.key, deterministic) + value_field = descriptor.fields_by_name['value'] + _MaybeAddEncoder(cls, value_field) + value_field._encoder(write_bytes, self.value, deterministic) else: for field_descriptor, field_value in self.ListFields(): + _MaybeAddEncoder(cls, field_descriptor) field_descriptor._encoder(write_bytes, field_value, deterministic) for tag_bytes, value_bytes in self._unknown_fields: write_bytes(tag_bytes) @@ -1114,7 +1150,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls): local_ReadTag = decoder.ReadTag local_SkipField = decoder.SkipField - decoders_by_tag = cls._decoders_by_tag + fields_by_tag = cls._fields_by_tag + message_set_decoders_by_tag = cls._message_set_decoders_by_tag def InternalParse(self, buffer, pos, end): """Create a message from serialized bytes. 
@@ -1137,8 +1174,14 @@ def _AddMergeFromStringMethod(message_descriptor, cls): unknown_field_set = self._unknown_field_set while pos != end: (tag_bytes, new_pos) = local_ReadTag(buffer, pos) - field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) - if field_decoder is None: + field_decoder, field_des = message_set_decoders_by_tag.get( + tag_bytes, (None, None) + ) + if field_decoder: + pos = field_decoder(buffer, new_pos, end, self, field_dict) + continue + field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None)) + if field_des is None: if not self._unknown_fields: # pylint: disable=protected-access self._unknown_fields = [] # pylint: disable=protected-access if unknown_field_set is None: @@ -1167,9 +1210,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls): (tag_bytes, buffer[old_pos:new_pos].tobytes())) pos = new_pos else: + _MaybeAddDecoder(cls, field_des) + field_decoder = field_des._decoders[is_packed] pos = field_decoder(buffer, new_pos, end, self, field_dict) - if field_desc: - self._UpdateOneofState(field_desc) + if field_des.containing_oneof: + self._UpdateOneofState(field_des) return pos cls._InternalParse = InternalParse @@ -1205,8 +1250,7 @@ def _AddIsInitializedMethod(message_descriptor, cls): for field, value in list(self._fields.items()): # dict can change size! 
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: - if (field.message_type.has_options and - field.message_type.GetOptions().map_entry): + if (field.message_type._is_map_entry): continue for element in value: if not element.IsInitialized(): diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 0708d51e6..c5a600fa9 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -2053,6 +2053,49 @@ class Proto2ReflectionTest(unittest.TestCase): # dependency on the C++ logging code. self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) + def testDescriptorProtoHasFileOptions(self): + self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) + self.assertEqual( + descriptor_pb2.DESCRIPTOR.GetOptions().java_package, + 'com.google.protobuf', + ) + + def testDescriptorProtoHasFieldOptions(self): + self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) + self.assertEqual( + descriptor_pb2.DESCRIPTOR.GetOptions().java_package, + 'com.google.protobuf', + ) + packed_desc = ( + descriptor_pb2.SourceCodeInfo.DESCRIPTOR.nested_types_by_name.get( + 'Location' + ).fields_by_name.get('path') + ) + self.assertTrue(packed_desc.has_options) + self.assertTrue(packed_desc.GetOptions().packed) + + def testDescriptorProtoHasFeatureOptions(self): + self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) + self.assertEqual( + descriptor_pb2.DESCRIPTOR.GetOptions().java_package, + 'com.google.protobuf', + ) + presence_desc = descriptor_pb2.FeatureSet.DESCRIPTOR.fields_by_name.get( + 'field_presence' + ) + self.assertTrue(presence_desc.has_options) + self.assertEqual( + presence_desc.GetOptions().retention, + descriptor_pb2.FieldOptions.OptionRetention.RETENTION_RUNTIME, + ) + self.assertListsEqual( + presence_desc.GetOptions().targets, + [ + 
descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FIELD, + descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FILE, + ], + ) + def testStringUTF8Serialization(self): proto = message_set_extensions_pb2.TestMessageSet() extension_message = message_set_extensions_pb2.TestMessageSetExtension2 diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index ec1aa1b45..9a8d7d751 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -177,25 +177,6 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - # InternalCheckUnknownField() is an additional Pure Python check which checks - # a detail of unknown fields. It cannot be used by the C++ - # implementation because some protect members are called. - # The test is added for historical reasons. It is not necessary as - # serialized string is checked. - # TODO(jieluo): Remove message._unknown_fields. 
- def InternalCheckUnknownField(self, name, expected_value): - if api_implementation.Type() != 'python': - return - field_descriptor = self.descriptor.fields_by_name[name] - wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] - field_tag = encoder.TagBytes(field_descriptor.number, wire_type) - result_dict = {} - for tag_bytes, value in self.empty_message._unknown_fields: - if tag_bytes == field_tag: - decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] - decoder(memoryview(value), 0, len(value), self.all_fields, result_dict) - self.assertEqual(expected_value, result_dict[field_descriptor]) - def CheckUnknownField(self, name, unknown_field_set, expected_value): field_descriptor = self.descriptor.fields_by_name[name] expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[ @@ -223,50 +204,36 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): self.CheckUnknownField('optional_nested_enum', unknown_field_set, self.all_fields.optional_nested_enum) - self.InternalCheckUnknownField('optional_nested_enum', - self.all_fields.optional_nested_enum) # Test repeated enum. self.CheckUnknownField('repeated_nested_enum', unknown_field_set, self.all_fields.repeated_nested_enum) - self.InternalCheckUnknownField('repeated_nested_enum', - self.all_fields.repeated_nested_enum) # Test varint. self.CheckUnknownField('optional_int32', unknown_field_set, self.all_fields.optional_int32) - self.InternalCheckUnknownField('optional_int32', - self.all_fields.optional_int32) # Test fixed32. self.CheckUnknownField('optional_fixed32', unknown_field_set, self.all_fields.optional_fixed32) - self.InternalCheckUnknownField('optional_fixed32', - self.all_fields.optional_fixed32) # Test fixed64. self.CheckUnknownField('optional_fixed64', unknown_field_set, self.all_fields.optional_fixed64) - self.InternalCheckUnknownField('optional_fixed64', - self.all_fields.optional_fixed64) # Test length delimited. 
self.CheckUnknownField('optional_string', unknown_field_set, self.all_fields.optional_string.encode('utf-8')) - self.InternalCheckUnknownField('optional_string', - self.all_fields.optional_string) # Test group. self.CheckUnknownField('optionalgroup', unknown_field_set, (17, 0, 117)) - self.InternalCheckUnknownField('optionalgroup', - self.all_fields.optionalgroup) self.assertEqual(98, len(unknown_field_set)) -- 2.49.0