Files
protobuf/fix-options-parsing-bug.patch

429 lines
18 KiB
Diff

From 3b0c1b873793da29ff46a42d448cda3b5f937c8a Mon Sep 17 00:00:00 2001
From: Jie Luo <jieluo@google.com>
Date: Mon, 28 Aug 2023 09:58:39 -0700
Subject: [PATCH 1/2] Fix a bug that strips options from descriptor.proto in
Pure Python.
GetOptions on fields (which parses the _serialized_options) will now be called on the first parse or serialize instead of at Build time.
Note: GetOptions on messages is still called at Build time because of message_set_wire_format. If message options are needed in descriptor.proto, a parse error will be raised in GetOptions(). We could check the file and skip invoking GetOptions() for descriptor.proto, as long as message_set_wire_format is not needed in descriptor.proto.
Options other than message options do not invoke GetOptions() at Build time.
PiperOrigin-RevId: 560741182
---
python/google/protobuf/descriptor.py | 18 ++--
python/google/protobuf/descriptor_pool.py | 1 +
.../protobuf/internal/python_message.py | 102 +++++++++++++-----
.../protobuf/internal/reflection_test.py | 43 ++++++++
.../protobuf/internal/unknown_fields_test.py | 33 ------
5 files changed, 128 insertions(+), 69 deletions(-)
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index fcb87cab5..4c03a4669 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -176,14 +176,15 @@ class DescriptorBase(metaclass=DescriptorMetaclass):
raise RuntimeError('Unknown options class name %s!' %
(self._options_class_name))
- with _lock:
- if self._serialized_options is None:
+ if self._serialized_options is None:
+ with _lock:
self._options = options_class()
- else:
- self._options = _ParseOptions(options_class(),
- self._serialized_options)
+ else:
+ options = _ParseOptions(options_class(), self._serialized_options)
+ with _lock:
+ self._options = options
- return self._options
+ return self._options
class _NestedDescriptorBase(DescriptorBase):
@@ -285,6 +286,7 @@ class Descriptor(_NestedDescriptorBase):
oneofs_by_name (dict(str, OneofDescriptor)): Same objects as in
:attr:`oneofs`, but indexed by "name" attribute.
file (FileDescriptor): Reference to file descriptor.
+ is_map_entry: If the message type is a map entry.
"""
@@ -310,6 +312,7 @@ class Descriptor(_NestedDescriptorBase):
serialized_start=None,
serialized_end=None,
syntax=None,
+ is_map_entry=False,
create_key=None):
_message.Message._CheckCalledFromGeneratedFile()
return _message.default_pool.FindMessageTypeByName(full_name)
@@ -322,7 +325,7 @@ class Descriptor(_NestedDescriptorBase):
serialized_options=None,
is_extendable=True, extension_ranges=None, oneofs=None,
file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin
- syntax=None, create_key=None):
+ syntax=None, is_map_entry=False, create_key=None):
"""Arguments to __init__() are as described in the description
of Descriptor fields above.
@@ -372,6 +375,7 @@ class Descriptor(_NestedDescriptorBase):
for oneof in self.oneofs:
oneof.containing_type = self
self.syntax = syntax or "proto2"
+ self._is_map_entry = is_map_entry
@property
def fields_by_camelcase_name(self):
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 1ebf11834..04bd7a3e9 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -897,6 +897,7 @@ class DescriptorPool(object):
serialized_start=None,
serialized_end=None,
syntax=syntax,
+ is_map_entry=desc_proto.options.map_entry,
# pylint: disable=protected-access
create_key=descriptor._internal_create_key)
for nested in desc.nested_types:
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index bf9acefd2..a72276d1a 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -182,11 +182,14 @@ class GeneratedProtocolMessageType(type):
% (descriptor.full_name))
return
- cls._decoders_by_tag = {}
+ cls._message_set_decoders_by_tag = {}
+ cls._fields_by_tag = {}
if (descriptor.has_options and
descriptor.GetOptions().message_set_wire_format):
- cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
- decoder.MessageSetItemDecoder(descriptor), None)
+ cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
+ decoder.MessageSetItemDecoder(descriptor),
+ None,
+ )
# Attach stuff to each FieldDescriptor for quick lookup later on.
for field in descriptor.fields:
@@ -272,16 +275,36 @@ def _IsMessageSetExtension(field):
def _IsMapField(field):
return (field.type == _FieldDescriptor.TYPE_MESSAGE and
- field.message_type.has_options and
- field.message_type.GetOptions().map_entry)
+ field.message_type._is_map_entry)
def _IsMessageMapField(field):
value_type = field.message_type.fields_by_name['value']
return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
-
def _AttachFieldHelpers(cls, field_descriptor):
+ is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
+ field_descriptor._default_constructor = _DefaultValueConstructorForField(
+ field_descriptor
+ )
+
+ def AddFieldByTag(wiretype, is_packed):
+ tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
+ cls._fields_by_tag[tag_bytes] = (field_descriptor, is_packed)
+
+ AddFieldByTag(
+ type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False
+ )
+
+ if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
+ # To support wire compatibility of adding packed = true, add a decoder for
+ # packed values regardless of the field's options.
+ AddFieldByTag(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
+
+
+def _MaybeAddEncoder(cls, field_descriptor):
+ if hasattr(field_descriptor, '_encoder'):
+ return
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
is_map_entry = _IsMapField(field_descriptor)
is_packed = field_descriptor.is_packed
@@ -301,11 +324,17 @@ def _AttachFieldHelpers(cls, field_descriptor):
field_descriptor._encoder = field_encoder
field_descriptor._sizer = sizer
- field_descriptor._default_constructor = _DefaultValueConstructorForField(
- field_descriptor)
- def AddDecoder(wiretype, is_packed):
- tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
+
+def _MaybeAddDecoder(cls, field_descriptor):
+ if hasattr(field_descriptor, '_decoders'):
+ return
+
+ is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
+ is_map_entry = _IsMapField(field_descriptor)
+ field_descriptor._decoders = {}
+
+ def AddDecoder(is_packed):
decode_type = field_descriptor.type
if (decode_type == _FieldDescriptor.TYPE_ENUM and
not field_descriptor.enum_type.is_closed):
@@ -337,15 +366,14 @@ def _AttachFieldHelpers(cls, field_descriptor):
field_descriptor, field_descriptor._default_constructor,
not field_descriptor.has_presence)
- cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
+ field_descriptor._decoders[is_packed] = field_decoder
- AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
- False)
+ AddDecoder(False)
if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
# To support wire compatibility of adding packed = true, add a decoder for
# packed values regardless of the field's options.
- AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
+ AddDecoder(True)
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
@@ -1031,12 +1059,17 @@ def _AddByteSizeMethod(message_descriptor, cls):
size = 0
descriptor = self.DESCRIPTOR
- if descriptor.GetOptions().map_entry:
+ if descriptor._is_map_entry:
# Fields of map entry should always be serialized.
- size = descriptor.fields_by_name['key']._sizer(self.key)
- size += descriptor.fields_by_name['value']._sizer(self.value)
+ key_field = descriptor.fields_by_name['key']
+ _MaybeAddEncoder(cls, key_field)
+ size = key_field._sizer(self.key)
+ value_field = descriptor.fields_by_name['value']
+ _MaybeAddEncoder(cls, value_field)
+ size += value_field._sizer(self.value)
else:
for field_descriptor, field_value in self.ListFields():
+ _MaybeAddEncoder(cls, field_descriptor)
size += field_descriptor._sizer(field_value)
for tag_bytes, value_bytes in self._unknown_fields:
size += len(tag_bytes) + len(value_bytes)
@@ -1079,14 +1112,17 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
deterministic = bool(deterministic)
descriptor = self.DESCRIPTOR
- if descriptor.GetOptions().map_entry:
+ if descriptor._is_map_entry:
# Fields of map entry should always be serialized.
- descriptor.fields_by_name['key']._encoder(
- write_bytes, self.key, deterministic)
- descriptor.fields_by_name['value']._encoder(
- write_bytes, self.value, deterministic)
+ key_field = descriptor.fields_by_name['key']
+ _MaybeAddEncoder(cls, key_field)
+ key_field._encoder(write_bytes, self.key, deterministic)
+ value_field = descriptor.fields_by_name['value']
+ _MaybeAddEncoder(cls, value_field)
+ value_field._encoder(write_bytes, self.value, deterministic)
else:
for field_descriptor, field_value in self.ListFields():
+ _MaybeAddEncoder(cls, field_descriptor)
field_descriptor._encoder(write_bytes, field_value, deterministic)
for tag_bytes, value_bytes in self._unknown_fields:
write_bytes(tag_bytes)
@@ -1114,7 +1150,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
- decoders_by_tag = cls._decoders_by_tag
+ fields_by_tag = cls._fields_by_tag
+ message_set_decoders_by_tag = cls._message_set_decoders_by_tag
def InternalParse(self, buffer, pos, end):
"""Create a message from serialized bytes.
@@ -1137,8 +1174,14 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
unknown_field_set = self._unknown_field_set
while pos != end:
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
- field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
- if field_decoder is None:
+ field_decoder, field_des = message_set_decoders_by_tag.get(
+ tag_bytes, (None, None)
+ )
+ if field_decoder:
+ pos = field_decoder(buffer, new_pos, end, self, field_dict)
+ continue
+ field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
+ if field_des is None:
if not self._unknown_fields: # pylint: disable=protected-access
self._unknown_fields = [] # pylint: disable=protected-access
if unknown_field_set is None:
@@ -1167,9 +1210,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
(tag_bytes, buffer[old_pos:new_pos].tobytes()))
pos = new_pos
else:
+ _MaybeAddDecoder(cls, field_des)
+ field_decoder = field_des._decoders[is_packed]
pos = field_decoder(buffer, new_pos, end, self, field_dict)
- if field_desc:
- self._UpdateOneofState(field_desc)
+ if field_des.containing_oneof:
+ self._UpdateOneofState(field_des)
return pos
cls._InternalParse = InternalParse
@@ -1205,8 +1250,7 @@ def _AddIsInitializedMethod(message_descriptor, cls):
for field, value in list(self._fields.items()): # dict can change size!
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
- if (field.message_type.has_options and
- field.message_type.GetOptions().map_entry):
+ if (field.message_type._is_map_entry):
continue
for element in value:
if not element.IsInitialized():
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 0708d51e6..c5a600fa9 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -2053,6 +2053,49 @@ class Proto2ReflectionTest(unittest.TestCase):
# dependency on the C++ logging code.
self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
+ def testDescriptorProtoHasFileOptions(self):
+ self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
+ self.assertEqual(
+ descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
+ 'com.google.protobuf',
+ )
+
+ def testDescriptorProtoHasFieldOptions(self):
+ self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
+ self.assertEqual(
+ descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
+ 'com.google.protobuf',
+ )
+ packed_desc = (
+ descriptor_pb2.SourceCodeInfo.DESCRIPTOR.nested_types_by_name.get(
+ 'Location'
+ ).fields_by_name.get('path')
+ )
+ self.assertTrue(packed_desc.has_options)
+ self.assertTrue(packed_desc.GetOptions().packed)
+
+ def testDescriptorProtoHasFeatureOptions(self):
+ self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
+ self.assertEqual(
+ descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
+ 'com.google.protobuf',
+ )
+ presence_desc = descriptor_pb2.FeatureSet.DESCRIPTOR.fields_by_name.get(
+ 'field_presence'
+ )
+ self.assertTrue(presence_desc.has_options)
+ self.assertEqual(
+ presence_desc.GetOptions().retention,
+ descriptor_pb2.FieldOptions.OptionRetention.RETENTION_RUNTIME,
+ )
+ self.assertListsEqual(
+ presence_desc.GetOptions().targets,
+ [
+ descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FIELD,
+ descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FILE,
+ ],
+ )
+
def testStringUTF8Serialization(self):
proto = message_set_extensions_pb2.TestMessageSet()
extension_message = message_set_extensions_pb2.TestMessageSetExtension2
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index ec1aa1b45..9a8d7d751 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -177,25 +177,6 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
- # InternalCheckUnknownField() is an additional Pure Python check which checks
- # a detail of unknown fields. It cannot be used by the C++
- # implementation because some protect members are called.
- # The test is added for historical reasons. It is not necessary as
- # serialized string is checked.
- # TODO(jieluo): Remove message._unknown_fields.
- def InternalCheckUnknownField(self, name, expected_value):
- if api_implementation.Type() != 'python':
- return
- field_descriptor = self.descriptor.fields_by_name[name]
- wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
- field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
- result_dict = {}
- for tag_bytes, value in self.empty_message._unknown_fields:
- if tag_bytes == field_tag:
- decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
- decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
- self.assertEqual(expected_value, result_dict[field_descriptor])
-
def CheckUnknownField(self, name, unknown_field_set, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
@@ -223,50 +204,36 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
self.CheckUnknownField('optional_nested_enum',
unknown_field_set,
self.all_fields.optional_nested_enum)
- self.InternalCheckUnknownField('optional_nested_enum',
- self.all_fields.optional_nested_enum)
# Test repeated enum.
self.CheckUnknownField('repeated_nested_enum',
unknown_field_set,
self.all_fields.repeated_nested_enum)
- self.InternalCheckUnknownField('repeated_nested_enum',
- self.all_fields.repeated_nested_enum)
# Test varint.
self.CheckUnknownField('optional_int32',
unknown_field_set,
self.all_fields.optional_int32)
- self.InternalCheckUnknownField('optional_int32',
- self.all_fields.optional_int32)
# Test fixed32.
self.CheckUnknownField('optional_fixed32',
unknown_field_set,
self.all_fields.optional_fixed32)
- self.InternalCheckUnknownField('optional_fixed32',
- self.all_fields.optional_fixed32)
# Test fixed64.
self.CheckUnknownField('optional_fixed64',
unknown_field_set,
self.all_fields.optional_fixed64)
- self.InternalCheckUnknownField('optional_fixed64',
- self.all_fields.optional_fixed64)
# Test length delimited.
self.CheckUnknownField('optional_string',
unknown_field_set,
self.all_fields.optional_string.encode('utf-8'))
- self.InternalCheckUnknownField('optional_string',
- self.all_fields.optional_string)
# Test group.
self.CheckUnknownField('optionalgroup',
unknown_field_set,
(17, 0, 117))
- self.InternalCheckUnknownField('optionalgroup',
- self.all_fields.optionalgroup)
self.assertEqual(98, len(unknown_field_set))
--
2.49.0