429 lines
18 KiB
Diff
429 lines
18 KiB
Diff
From 3b0c1b873793da29ff46a42d448cda3b5f937c8a Mon Sep 17 00:00:00 2001
|
|
From: Jie Luo <jieluo@google.com>
|
|
Date: Mon, 28 Aug 2023 09:58:39 -0700
|
|
Subject: [PATCH 1/2] Fix a bug that strips options from descriptor.proto in
|
|
Pure Python.
|
|
|
|
GetOptions() on fields (which parses _serialized_options) is now called on the first parse or serialize instead of at Build time.
|
|
|
|
Note: GetOptions() on messages is still called at Build time because of message_set_wire_format. If message options are needed in descriptor.proto, a parse error will be raised in GetOptions(). We can check the file to avoid invoking GetOptions() for descriptor.proto, as long as message_set_wire_format is not needed in descriptor.proto.
|
|
|
|
Options other than message options do not invoke GetOptions() at Build time.
|
|
|
|
PiperOrigin-RevId: 560741182
|
|
---
|
|
python/google/protobuf/descriptor.py | 18 ++--
|
|
python/google/protobuf/descriptor_pool.py | 1 +
|
|
.../protobuf/internal/python_message.py | 102 +++++++++++++-----
|
|
.../protobuf/internal/reflection_test.py | 43 ++++++++
|
|
.../protobuf/internal/unknown_fields_test.py | 33 ------
|
|
5 files changed, 128 insertions(+), 69 deletions(-)
|
|
|
|
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
|
|
index fcb87cab5..4c03a4669 100755
|
|
--- a/python/google/protobuf/descriptor.py
|
|
+++ b/python/google/protobuf/descriptor.py
|
|
@@ -176,14 +176,15 @@ class DescriptorBase(metaclass=DescriptorMetaclass):
|
|
raise RuntimeError('Unknown options class name %s!' %
|
|
(self._options_class_name))
|
|
|
|
- with _lock:
|
|
- if self._serialized_options is None:
|
|
+ if self._serialized_options is None:
|
|
+ with _lock:
|
|
self._options = options_class()
|
|
- else:
|
|
- self._options = _ParseOptions(options_class(),
|
|
- self._serialized_options)
|
|
+ else:
|
|
+ options = _ParseOptions(options_class(), self._serialized_options)
|
|
+ with _lock:
|
|
+ self._options = options
|
|
|
|
- return self._options
|
|
+ return self._options
|
|
|
|
|
|
class _NestedDescriptorBase(DescriptorBase):
|
|
@@ -285,6 +286,7 @@ class Descriptor(_NestedDescriptorBase):
|
|
oneofs_by_name (dict(str, OneofDescriptor)): Same objects as in
|
|
:attr:`oneofs`, but indexed by "name" attribute.
|
|
file (FileDescriptor): Reference to file descriptor.
|
|
+ is_map_entry: If the message type is a map entry.
|
|
|
|
"""
|
|
|
|
@@ -310,6 +312,7 @@ class Descriptor(_NestedDescriptorBase):
|
|
serialized_start=None,
|
|
serialized_end=None,
|
|
syntax=None,
|
|
+ is_map_entry=False,
|
|
create_key=None):
|
|
_message.Message._CheckCalledFromGeneratedFile()
|
|
return _message.default_pool.FindMessageTypeByName(full_name)
|
|
@@ -322,7 +325,7 @@ class Descriptor(_NestedDescriptorBase):
|
|
serialized_options=None,
|
|
is_extendable=True, extension_ranges=None, oneofs=None,
|
|
file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin
|
|
- syntax=None, create_key=None):
|
|
+ syntax=None, is_map_entry=False, create_key=None):
|
|
"""Arguments to __init__() are as described in the description
|
|
of Descriptor fields above.
|
|
|
|
@@ -372,6 +375,7 @@ class Descriptor(_NestedDescriptorBase):
|
|
for oneof in self.oneofs:
|
|
oneof.containing_type = self
|
|
self.syntax = syntax or "proto2"
|
|
+ self._is_map_entry = is_map_entry
|
|
|
|
@property
|
|
def fields_by_camelcase_name(self):
|
|
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
|
|
index 1ebf11834..04bd7a3e9 100644
|
|
--- a/python/google/protobuf/descriptor_pool.py
|
|
+++ b/python/google/protobuf/descriptor_pool.py
|
|
@@ -897,6 +897,7 @@ class DescriptorPool(object):
|
|
serialized_start=None,
|
|
serialized_end=None,
|
|
syntax=syntax,
|
|
+ is_map_entry=desc_proto.options.map_entry,
|
|
# pylint: disable=protected-access
|
|
create_key=descriptor._internal_create_key)
|
|
for nested in desc.nested_types:
|
|
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
|
|
index bf9acefd2..a72276d1a 100755
|
|
--- a/python/google/protobuf/internal/python_message.py
|
|
+++ b/python/google/protobuf/internal/python_message.py
|
|
@@ -182,11 +182,14 @@ class GeneratedProtocolMessageType(type):
|
|
% (descriptor.full_name))
|
|
return
|
|
|
|
- cls._decoders_by_tag = {}
|
|
+ cls._message_set_decoders_by_tag = {}
|
|
+ cls._fields_by_tag = {}
|
|
if (descriptor.has_options and
|
|
descriptor.GetOptions().message_set_wire_format):
|
|
- cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
|
|
- decoder.MessageSetItemDecoder(descriptor), None)
|
|
+ cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
|
|
+ decoder.MessageSetItemDecoder(descriptor),
|
|
+ None,
|
|
+ )
|
|
|
|
# Attach stuff to each FieldDescriptor for quick lookup later on.
|
|
for field in descriptor.fields:
|
|
@@ -272,16 +275,36 @@ def _IsMessageSetExtension(field):
|
|
|
|
def _IsMapField(field):
|
|
return (field.type == _FieldDescriptor.TYPE_MESSAGE and
|
|
- field.message_type.has_options and
|
|
- field.message_type.GetOptions().map_entry)
|
|
+ field.message_type._is_map_entry)
|
|
|
|
|
|
def _IsMessageMapField(field):
|
|
value_type = field.message_type.fields_by_name['value']
|
|
return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
|
|
|
|
-
|
|
def _AttachFieldHelpers(cls, field_descriptor):
|
|
+ is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
|
|
+ field_descriptor._default_constructor = _DefaultValueConstructorForField(
|
|
+ field_descriptor
|
|
+ )
|
|
+
|
|
+ def AddFieldByTag(wiretype, is_packed):
|
|
+ tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
|
|
+ cls._fields_by_tag[tag_bytes] = (field_descriptor, is_packed)
|
|
+
|
|
+ AddFieldByTag(
|
|
+ type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False
|
|
+ )
|
|
+
|
|
+ if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
|
|
+ # To support wire compatibility of adding packed = true, add a decoder for
|
|
+ # packed values regardless of the field's options.
|
|
+ AddFieldByTag(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
|
|
+
|
|
+
|
|
+def _MaybeAddEncoder(cls, field_descriptor):
|
|
+ if hasattr(field_descriptor, '_encoder'):
|
|
+ return
|
|
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
|
|
is_map_entry = _IsMapField(field_descriptor)
|
|
is_packed = field_descriptor.is_packed
|
|
@@ -301,11 +324,17 @@ def _AttachFieldHelpers(cls, field_descriptor):
|
|
|
|
field_descriptor._encoder = field_encoder
|
|
field_descriptor._sizer = sizer
|
|
- field_descriptor._default_constructor = _DefaultValueConstructorForField(
|
|
- field_descriptor)
|
|
|
|
- def AddDecoder(wiretype, is_packed):
|
|
- tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
|
|
+
|
|
+def _MaybeAddDecoder(cls, field_descriptor):
|
|
+ if hasattr(field_descriptor, '_decoders'):
|
|
+ return
|
|
+
|
|
+ is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
|
|
+ is_map_entry = _IsMapField(field_descriptor)
|
|
+ field_descriptor._decoders = {}
|
|
+
|
|
+ def AddDecoder(is_packed):
|
|
decode_type = field_descriptor.type
|
|
if (decode_type == _FieldDescriptor.TYPE_ENUM and
|
|
not field_descriptor.enum_type.is_closed):
|
|
@@ -337,15 +366,14 @@ def _AttachFieldHelpers(cls, field_descriptor):
|
|
field_descriptor, field_descriptor._default_constructor,
|
|
not field_descriptor.has_presence)
|
|
|
|
- cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
|
|
+ field_descriptor._decoders[is_packed] = field_decoder
|
|
|
|
- AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
|
|
- False)
|
|
+ AddDecoder(False)
|
|
|
|
if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
|
|
# To support wire compatibility of adding packed = true, add a decoder for
|
|
# packed values regardless of the field's options.
|
|
- AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
|
|
+ AddDecoder(True)
|
|
|
|
|
|
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
|
|
@@ -1031,12 +1059,17 @@ def _AddByteSizeMethod(message_descriptor, cls):
|
|
|
|
size = 0
|
|
descriptor = self.DESCRIPTOR
|
|
- if descriptor.GetOptions().map_entry:
|
|
+ if descriptor._is_map_entry:
|
|
# Fields of map entry should always be serialized.
|
|
- size = descriptor.fields_by_name['key']._sizer(self.key)
|
|
- size += descriptor.fields_by_name['value']._sizer(self.value)
|
|
+ key_field = descriptor.fields_by_name['key']
|
|
+ _MaybeAddEncoder(cls, key_field)
|
|
+ size = key_field._sizer(self.key)
|
|
+ value_field = descriptor.fields_by_name['value']
|
|
+ _MaybeAddEncoder(cls, value_field)
|
|
+ size += value_field._sizer(self.value)
|
|
else:
|
|
for field_descriptor, field_value in self.ListFields():
|
|
+ _MaybeAddEncoder(cls, field_descriptor)
|
|
size += field_descriptor._sizer(field_value)
|
|
for tag_bytes, value_bytes in self._unknown_fields:
|
|
size += len(tag_bytes) + len(value_bytes)
|
|
@@ -1079,14 +1112,17 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
|
|
deterministic = bool(deterministic)
|
|
|
|
descriptor = self.DESCRIPTOR
|
|
- if descriptor.GetOptions().map_entry:
|
|
+ if descriptor._is_map_entry:
|
|
# Fields of map entry should always be serialized.
|
|
- descriptor.fields_by_name['key']._encoder(
|
|
- write_bytes, self.key, deterministic)
|
|
- descriptor.fields_by_name['value']._encoder(
|
|
- write_bytes, self.value, deterministic)
|
|
+ key_field = descriptor.fields_by_name['key']
|
|
+ _MaybeAddEncoder(cls, key_field)
|
|
+ key_field._encoder(write_bytes, self.key, deterministic)
|
|
+ value_field = descriptor.fields_by_name['value']
|
|
+ _MaybeAddEncoder(cls, value_field)
|
|
+ value_field._encoder(write_bytes, self.value, deterministic)
|
|
else:
|
|
for field_descriptor, field_value in self.ListFields():
|
|
+ _MaybeAddEncoder(cls, field_descriptor)
|
|
field_descriptor._encoder(write_bytes, field_value, deterministic)
|
|
for tag_bytes, value_bytes in self._unknown_fields:
|
|
write_bytes(tag_bytes)
|
|
@@ -1114,7 +1150,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
|
|
|
|
local_ReadTag = decoder.ReadTag
|
|
local_SkipField = decoder.SkipField
|
|
- decoders_by_tag = cls._decoders_by_tag
|
|
+ fields_by_tag = cls._fields_by_tag
|
|
+ message_set_decoders_by_tag = cls._message_set_decoders_by_tag
|
|
|
|
def InternalParse(self, buffer, pos, end):
|
|
"""Create a message from serialized bytes.
|
|
@@ -1137,8 +1174,14 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
|
|
unknown_field_set = self._unknown_field_set
|
|
while pos != end:
|
|
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
|
|
- field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
|
|
- if field_decoder is None:
|
|
+ field_decoder, field_des = message_set_decoders_by_tag.get(
|
|
+ tag_bytes, (None, None)
|
|
+ )
|
|
+ if field_decoder:
|
|
+ pos = field_decoder(buffer, new_pos, end, self, field_dict)
|
|
+ continue
|
|
+ field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
|
|
+ if field_des is None:
|
|
if not self._unknown_fields: # pylint: disable=protected-access
|
|
self._unknown_fields = [] # pylint: disable=protected-access
|
|
if unknown_field_set is None:
|
|
@@ -1167,9 +1210,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
|
|
(tag_bytes, buffer[old_pos:new_pos].tobytes()))
|
|
pos = new_pos
|
|
else:
|
|
+ _MaybeAddDecoder(cls, field_des)
|
|
+ field_decoder = field_des._decoders[is_packed]
|
|
pos = field_decoder(buffer, new_pos, end, self, field_dict)
|
|
- if field_desc:
|
|
- self._UpdateOneofState(field_desc)
|
|
+ if field_des.containing_oneof:
|
|
+ self._UpdateOneofState(field_des)
|
|
return pos
|
|
cls._InternalParse = InternalParse
|
|
|
|
@@ -1205,8 +1250,7 @@ def _AddIsInitializedMethod(message_descriptor, cls):
|
|
for field, value in list(self._fields.items()): # dict can change size!
|
|
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
|
|
if field.label == _FieldDescriptor.LABEL_REPEATED:
|
|
- if (field.message_type.has_options and
|
|
- field.message_type.GetOptions().map_entry):
|
|
+ if (field.message_type._is_map_entry):
|
|
continue
|
|
for element in value:
|
|
if not element.IsInitialized():
|
|
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
|
|
index 0708d51e6..c5a600fa9 100755
|
|
--- a/python/google/protobuf/internal/reflection_test.py
|
|
+++ b/python/google/protobuf/internal/reflection_test.py
|
|
@@ -2053,6 +2053,49 @@ class Proto2ReflectionTest(unittest.TestCase):
|
|
# dependency on the C++ logging code.
|
|
self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
|
|
|
|
+ def testDescriptorProtoHasFileOptions(self):
|
|
+ self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
|
|
+ self.assertEqual(
|
|
+ descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
|
|
+ 'com.google.protobuf',
|
|
+ )
|
|
+
|
|
+ def testDescriptorProtoHasFieldOptions(self):
|
|
+ self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
|
|
+ self.assertEqual(
|
|
+ descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
|
|
+ 'com.google.protobuf',
|
|
+ )
|
|
+ packed_desc = (
|
|
+ descriptor_pb2.SourceCodeInfo.DESCRIPTOR.nested_types_by_name.get(
|
|
+ 'Location'
|
|
+ ).fields_by_name.get('path')
|
|
+ )
|
|
+ self.assertTrue(packed_desc.has_options)
|
|
+ self.assertTrue(packed_desc.GetOptions().packed)
|
|
+
|
|
+ def testDescriptorProtoHasFeatureOptions(self):
|
|
+ self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
|
|
+ self.assertEqual(
|
|
+ descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
|
|
+ 'com.google.protobuf',
|
|
+ )
|
|
+ presence_desc = descriptor_pb2.FeatureSet.DESCRIPTOR.fields_by_name.get(
|
|
+ 'field_presence'
|
|
+ )
|
|
+ self.assertTrue(presence_desc.has_options)
|
|
+ self.assertEqual(
|
|
+ presence_desc.GetOptions().retention,
|
|
+ descriptor_pb2.FieldOptions.OptionRetention.RETENTION_RUNTIME,
|
|
+ )
|
|
+ self.assertListsEqual(
|
|
+ presence_desc.GetOptions().targets,
|
|
+ [
|
|
+ descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FIELD,
|
|
+ descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FILE,
|
|
+ ],
|
|
+ )
|
|
+
|
|
def testStringUTF8Serialization(self):
|
|
proto = message_set_extensions_pb2.TestMessageSet()
|
|
extension_message = message_set_extensions_pb2.TestMessageSetExtension2
|
|
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
|
|
index ec1aa1b45..9a8d7d751 100755
|
|
--- a/python/google/protobuf/internal/unknown_fields_test.py
|
|
+++ b/python/google/protobuf/internal/unknown_fields_test.py
|
|
@@ -177,25 +177,6 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
|
|
self.empty_message = unittest_pb2.TestEmptyMessage()
|
|
self.empty_message.ParseFromString(self.all_fields_data)
|
|
|
|
- # InternalCheckUnknownField() is an additional Pure Python check which checks
|
|
- # a detail of unknown fields. It cannot be used by the C++
|
|
- # implementation because some protect members are called.
|
|
- # The test is added for historical reasons. It is not necessary as
|
|
- # serialized string is checked.
|
|
- # TODO(jieluo): Remove message._unknown_fields.
|
|
- def InternalCheckUnknownField(self, name, expected_value):
|
|
- if api_implementation.Type() != 'python':
|
|
- return
|
|
- field_descriptor = self.descriptor.fields_by_name[name]
|
|
- wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
|
|
- field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
|
|
- result_dict = {}
|
|
- for tag_bytes, value in self.empty_message._unknown_fields:
|
|
- if tag_bytes == field_tag:
|
|
- decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
|
|
- decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
|
|
- self.assertEqual(expected_value, result_dict[field_descriptor])
|
|
-
|
|
def CheckUnknownField(self, name, unknown_field_set, expected_value):
|
|
field_descriptor = self.descriptor.fields_by_name[name]
|
|
expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
|
|
@@ -223,50 +204,36 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
|
|
self.CheckUnknownField('optional_nested_enum',
|
|
unknown_field_set,
|
|
self.all_fields.optional_nested_enum)
|
|
- self.InternalCheckUnknownField('optional_nested_enum',
|
|
- self.all_fields.optional_nested_enum)
|
|
|
|
# Test repeated enum.
|
|
self.CheckUnknownField('repeated_nested_enum',
|
|
unknown_field_set,
|
|
self.all_fields.repeated_nested_enum)
|
|
- self.InternalCheckUnknownField('repeated_nested_enum',
|
|
- self.all_fields.repeated_nested_enum)
|
|
|
|
# Test varint.
|
|
self.CheckUnknownField('optional_int32',
|
|
unknown_field_set,
|
|
self.all_fields.optional_int32)
|
|
- self.InternalCheckUnknownField('optional_int32',
|
|
- self.all_fields.optional_int32)
|
|
|
|
# Test fixed32.
|
|
self.CheckUnknownField('optional_fixed32',
|
|
unknown_field_set,
|
|
self.all_fields.optional_fixed32)
|
|
- self.InternalCheckUnknownField('optional_fixed32',
|
|
- self.all_fields.optional_fixed32)
|
|
|
|
# Test fixed64.
|
|
self.CheckUnknownField('optional_fixed64',
|
|
unknown_field_set,
|
|
self.all_fields.optional_fixed64)
|
|
- self.InternalCheckUnknownField('optional_fixed64',
|
|
- self.all_fields.optional_fixed64)
|
|
|
|
# Test length delimited.
|
|
self.CheckUnknownField('optional_string',
|
|
unknown_field_set,
|
|
self.all_fields.optional_string.encode('utf-8'))
|
|
- self.InternalCheckUnknownField('optional_string',
|
|
- self.all_fields.optional_string)
|
|
|
|
# Test group.
|
|
self.CheckUnknownField('optionalgroup',
|
|
unknown_field_set,
|
|
(17, 0, 117))
|
|
- self.InternalCheckUnknownField('optionalgroup',
|
|
- self.all_fields.optionalgroup)
|
|
|
|
self.assertEqual(98, len(unknown_field_set))
|
|
|
|
--
|
|
2.49.0
|
|
|