Files
protobuf/internal-pure-python-fixes.patch
John Paul Adrian Glaubitz 5984203f71 Backport fixes for CVE-2025-4565
- Add internal-pure-python-fixes.patch to backport the prerequisite changes required by the CVE fix
- Add CVE-2025-4565.patch to fix a crash due to RecursionError when parsing
  untrusted Protocol Buffers data containing an arbitrary number of recursive
  groups or messages (bsc#1244663, CVE-2025-4565)
2025-11-14 15:35:13 +01:00

422 lines
16 KiB
Diff

From dac2e91e36408087d769be89a72fbafe1ea5039c Mon Sep 17 00:00:00 2001
From: Protobuf Team Bot <protobuf-github-bot@google.com>
Date: Tue, 4 Mar 2025 13:16:32 -0800
Subject: [PATCH 1/2] Internal pure python fixes
PiperOrigin-RevId: 733441339
---
python/google/protobuf/internal/decoder.py | 98 ++++++++++++++-----
.../google/protobuf/internal/message_test.py | 1 +
.../protobuf/internal/python_message.py | 7 +-
.../protobuf/internal/self_recursive.proto | 9 +-
4 files changed, 86 insertions(+), 29 deletions(-)
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index dcde1d942..89d829142 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -184,7 +184,10 @@ def _SimpleDecoder(wire_type, decode_value):
clear_if_default=False):
if is_packed:
local_DecodeVarint = _DecodeVarint
- def DecodePackedField(buffer, pos, end, message, field_dict):
+ def DecodePackedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -199,11 +202,15 @@ def _SimpleDecoder(wire_type, decode_value):
del value[-1] # Discard corrupt value.
raise _DecodeError('Packed element was truncated.')
return pos
+
return DecodePackedField
elif is_repeated:
tag_bytes = encoder.TagBytes(field_number, wire_type)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -218,9 +225,12 @@ def _SimpleDecoder(wire_type, decode_value):
if new_pos > end:
raise _DecodeError('Truncated message.')
return new_pos
+
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+ del current_depth # unused
(new_value, pos) = decode_value(buffer, pos)
if pos > end:
raise _DecodeError('Truncated message.')
@@ -229,6 +239,7 @@ def _SimpleDecoder(wire_type, decode_value):
else:
field_dict[key] = new_value
return pos
+
return DecodeField
return SpecificDecoder
@@ -364,7 +375,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
enum_type = key.enum_type
if is_packed:
local_DecodeVarint = _DecodeVarint
- def DecodePackedField(buffer, pos, end, message, field_dict):
+ def DecodePackedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
"""Decode serialized packed enum to its value and a new position.
Args:
@@ -377,6 +390,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
Returns:
int, new position in serialized data.
"""
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -407,11 +421,14 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
# pylint: enable=protected-access
raise _DecodeError('Packed element was truncated.')
return pos
+
return DecodePackedField
elif is_repeated:
tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
"""Decode serialized repeated enum to its value and a new position.
Args:
@@ -424,6 +441,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
Returns:
int, new position in serialized data.
"""
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -446,9 +464,11 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
if new_pos > end:
raise _DecodeError('Truncated message.')
return new_pos
+
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
"""Decode serialized repeated enum to its value and a new position.
Args:
@@ -461,6 +481,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
Returns:
int, new position in serialized data.
"""
+ del current_depth # unused
value_start_pos = pos
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
if pos > end:
@@ -480,6 +501,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
(tag_bytes, buffer[value_start_pos:pos].tobytes()))
# pylint: enable=protected-access
return pos
+
return DecodeField
@@ -538,7 +560,10 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_LENGTH_DELIMITED)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -553,9 +578,12 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
# Prediction failed. Return.
return new_pos
+
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+ del current_depth # unused
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
if new_pos > end:
@@ -565,6 +593,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
else:
field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
return new_pos
+
return DecodeField
@@ -579,7 +608,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_LENGTH_DELIMITED)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
+ del current_depth # unused
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -594,9 +626,12 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
# Prediction failed. Return.
return new_pos
+
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+ del current_depth # unused
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
if new_pos > end:
@@ -606,6 +641,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
else:
field_dict[key] = buffer[pos:new_pos].tobytes()
return new_pos
+
return DecodeField
@@ -621,7 +657,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_START_GROUP)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -630,7 +668,7 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read sub-message.
- pos = value.add()._InternalParse(buffer, pos, end)
+ pos = value.add()._InternalParse(buffer, pos, end, current_depth)
# Read end tag.
new_pos = pos+end_tag_len
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -640,19 +678,22 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
# Prediction failed. Return.
return new_pos
+
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read sub-message.
- pos = value._InternalParse(buffer, pos, end)
+ pos = value._InternalParse(buffer, pos, end, current_depth)
# Read end tag.
new_pos = pos+end_tag_len
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
raise _DecodeError('Missing group end tag.')
return new_pos
+
return DecodeField
@@ -666,7 +707,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_LENGTH_DELIMITED)
tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ def DecodeRepeatedField(
+ buffer, pos, end, message, field_dict, current_depth=0
+ ):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -677,7 +720,10 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
if new_pos > end:
raise _DecodeError('Truncated message.')
# Read sub-message.
- if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
+ if (
+ value.add()._InternalParse(buffer, pos, new_pos, current_depth)
+ != new_pos
+ ):
# The only reason _InternalParse would return early is if it
# encountered an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
@@ -686,9 +732,11 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
# Prediction failed. Return.
return new_pos
+
return DecodeRepeatedField
else:
- def DecodeField(buffer, pos, end, message, field_dict):
+
+ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -698,11 +746,12 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
if new_pos > end:
raise _DecodeError('Truncated message.')
# Read sub-message.
- if value._InternalParse(buffer, pos, new_pos) != new_pos:
+ if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
# The only reason _InternalParse would return early is if it encountered
# an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
return new_pos
+
return DecodeField
@@ -851,7 +900,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
# Can't read _concrete_class yet; might not be initialized.
message_type = field_descriptor.message_type
- def DecodeMap(buffer, pos, end, message, field_dict):
+ def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
+ del current_depth # Unused.
submsg = message_type._concrete_class()
value = field_dict.get(key)
if value is None:
@@ -934,7 +984,7 @@ def _SkipGroup(buffer, pos, end):
pos = new_pos
-def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
+def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
unknown_field_set = containers.UnknownFieldSet()
@@ -944,14 +994,16 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
field_number, wire_type = wire_format.UnpackTag(tag)
if wire_type == wire_format.WIRETYPE_END_GROUP:
break
- (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
+ (data, pos) = _DecodeUnknownField(buffer, pos, wire_type, current_depth)
# pylint: disable=protected-access
unknown_field_set._add(field_number, wire_type, data)
return (unknown_field_set, pos)
-def _DecodeUnknownField(buffer, pos, wire_type):
+def _DecodeUnknownField(
+ buffer, pos, wire_type, current_depth=0
+):
"""Decode a unknown field. Returns the UnknownField and new position."""
if wire_type == wire_format.WIRETYPE_VARINT:
@@ -965,7 +1017,7 @@ def _DecodeUnknownField(buffer, pos, wire_type):
data = buffer[pos:pos+size].tobytes()
pos += size
elif wire_type == wire_format.WIRETYPE_START_GROUP:
- (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
+ (data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
elif wire_type == wire_format.WIRETYPE_END_GROUP:
return (0, -1)
else:
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 2a723eabb..48e6df806 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -30,6 +30,7 @@ import warnings
cmp = lambda x, y: (x > y) - (x < y)
from google.protobuf.internal import api_implementation # pylint: disable=g-import-not-at-top
+from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import more_extensions_pb2
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index fabc6aa07..62c059cd2 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -1194,7 +1194,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
fields_by_tag = cls._fields_by_tag
message_set_decoders_by_tag = cls._message_set_decoders_by_tag
- def InternalParse(self, buffer, pos, end):
+ def InternalParse(self, buffer, pos, end, current_depth=0):
"""Create a message from serialized bytes.
Args:
@@ -1244,10 +1244,13 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
else:
_MaybeAddDecoder(cls, field_des)
field_decoder = field_des._decoders[is_packed]
- pos = field_decoder(buffer, new_pos, end, self, field_dict)
+ pos = field_decoder(
+ buffer, new_pos, end, self, field_dict, current_depth
+ )
if field_des.containing_oneof:
self._UpdateOneofState(field_des)
return pos
+
cls._InternalParse = InternalParse
diff --git a/python/google/protobuf/internal/self_recursive.proto b/python/google/protobuf/internal/self_recursive.proto
index dbfcaf971..20bc2b4d3 100644
--- a/python/google/protobuf/internal/self_recursive.proto
+++ b/python/google/protobuf/internal/self_recursive.proto
@@ -5,18 +5,19 @@
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
-syntax = "proto2";
+edition = "2023";
package google.protobuf.python.internal;
message SelfRecursive {
- optional SelfRecursive sub = 1;
+ SelfRecursive sub = 1;
+ int32 i = 2;
}
message IndirectRecursive {
- optional IntermediateRecursive intermediate = 1;
+ IntermediateRecursive intermediate = 1;
}
message IntermediateRecursive {
- optional IndirectRecursive indirect = 1;
+ IndirectRecursive indirect = 1;
}
--
2.51.1