From f1d701cd2c411ee40bb1fe383afe7f365f35abf0 Mon Sep 17 00:00:00 2001
From: Andreas Eriksen
Date: Thu, 18 Dec 2025 16:48:26 +0100
Subject: [PATCH] Merge commit from fork

* track depth of recursive encode/decode, clear shared refs on start

* test that shared refs are cleared on start

* add fix-shared-state-reset to version history

* clear shared state _after_ encode/decode

* use PY_SSIZE_T_MAX to clear shareables list

* use context manager for python decoder depth tracking

* use context manager for python encoder depth tracking
---
 cbor2/_decoder.py       | 38 +++++++++++++++++-----
 cbor2/_encoder.py       | 44 +++++++++++++++++++++-----
 docs/versionhistory.rst |  5 +++
 source/decoder.c        | 28 ++++++++++++++++-
 source/decoder.h        |  1 +
 source/encoder.c        | 23 ++++++++++++--
 source/encoder.h        |  1 +
 tests/test_decoder.py   | 62 ++++++++++++++++++++++++++++++++++++
 tests/test_encoder.py   | 70 +++++++++++++++++++++++++++++++++++++++++
 9 files changed, 255 insertions(+), 17 deletions(-)

Index: cbor2-5.6.5/cbor2/_decoder.py
===================================================================
--- cbor2-5.6.5.orig/cbor2/_decoder.py
+++ cbor2-5.6.5/cbor2/_decoder.py
@@ -5,6 +5,7 @@ import struct
 import sys
 from codecs import getincrementaldecoder
 from collections.abc import Callable, Mapping, Sequence
+from contextlib import contextmanager
 from datetime import date, datetime, timedelta, timezone
 from io import BytesIO
 from typing import IO, TYPE_CHECKING, Any, TypeVar, cast, overload
@@ -59,6 +60,7 @@ class CBORDecoder:
         "_immutable",
         "_str_errors",
         "_stringref_namespace",
+        "_decode_depth",
     )
 
     _fp: IO[bytes]
@@ -100,6 +102,7 @@ class CBORDecoder:
         self._shareables: list[object] = []
         self._stringref_namespace: list[str | bytes] | None = None
         self._immutable = False
+        self._decode_depth = 0
 
     @property
     def immutable(self) -> bool:
@@ -225,13 +228,33 @@ class CBORDecoder:
             if unshared:
                 self._share_index = old_index
 
+    @contextmanager
+    def _decoding_context(self):
+        """
+        Context manager for tracking decode depth and clearing shared state.
+
+        Shared state is cleared at the end of each top-level decode to prevent
+        shared references from leaking between independent decode operations.
+        Nested calls (from hooks) must preserve the state.
+        """
+        self._decode_depth += 1
+        try:
+            yield
+        finally:
+            self._decode_depth -= 1
+            assert self._decode_depth >= 0
+            if self._decode_depth == 0:
+                self._shareables.clear()
+                self._share_index = None
+
     def decode(self) -> object:
         """
         Decode the next value from the stream.
 
         :raises CBORDecodeError: if there is any problem decoding the stream
         """
-        return self._decode()
+        with self._decoding_context():
+            return self._decode()
 
     def decode_from_bytes(self, buf: bytes) -> object:
         """
@@ -242,12 +265,13 @@ class CBORDecoder:
         object needs to be decoded separately from the rest but while still
         taking advantage of the shared value registry.
         """
-        with BytesIO(buf) as fp:
-            old_fp = self.fp
-            self.fp = fp
-            retval = self._decode()
-            self.fp = old_fp
-            return retval
+        with self._decoding_context(), BytesIO(buf) as fp:
+            old_fp = self.fp
+            self.fp = fp
+            try:
+                return self._decode()
+            finally:
+                self.fp = old_fp
 
     @overload
     def _decode_length(self, subtype: int) -> int: ...
Index: cbor2-5.6.5/cbor2/_encoder.py
===================================================================
--- cbor2-5.6.5.orig/cbor2/_encoder.py
+++ cbor2-5.6.5/cbor2/_encoder.py
@@ -123,6 +123,7 @@ class CBOREncoder:
         "string_referencing",
         "string_namespacing",
         "_string_references",
+        "_encode_depth",
     )
 
     _fp: IO[bytes]
@@ -183,6 +184,7 @@ class CBOREncoder:
             int, tuple[object, int | None]
         ] = {}  # indexes used for value sharing
         self._string_references: dict[str | bytes, int] = {}  # indexes used for string references
+        self._encode_depth = 0
         self._encoders = default_encoders.copy()
         if canonical:
             self._encoders.update(canonical_encoders)
@@ -298,6 +300,24 @@ class CBOREncoder:
         """
         self._fp_write(data)
 
+    @contextmanager
+    def _encoding_context(self):
+        """
+        Context manager for tracking encode depth and clearing shared state.
+
+        Shared state is cleared at the end of each top-level encode to prevent
+        shared references from leaking between independent encode operations.
+        Nested calls (from hooks) must preserve the state.
+        """
+        self._encode_depth += 1
+        try:
+            yield
+        finally:
+            self._encode_depth -= 1
+            if self._encode_depth == 0:
+                self._shared_containers.clear()
+                self._string_references.clear()
+
     def encode(self, obj: Any) -> None:
         """
         Encode the given object using CBOR.
@@ -305,6 +325,16 @@ class CBOREncoder:
         :param obj:
             the object to encode
         """
+        with self._encoding_context():
+            self._encode_value(obj)
+
+    def _encode_value(self, obj: Any) -> None:
+        """
+        Internal fast path for encoding - used by built-in encoders.
+
+        External code should use encode() instead, which properly manages
+        shared state between independent encode operations.
+        """
         obj_type = obj.__class__
         encoder = self._encoders.get(obj_type) or self._find_encoder(obj_type) or self._default
         if not encoder:
@@ -448,14 +478,14 @@ class CBOREncoder:
     def encode_array(self, value: Sequence[Any]) -> None:
         self.encode_length(4, len(value))
         for item in value:
-            self.encode(item)
+            self._encode_value(item)
 
     @container_encoder
     def encode_map(self, value: Mapping[Any, Any]) -> None:
         self.encode_length(5, len(value))
         for key, val in value.items():
-            self.encode(key)
-            self.encode(val)
+            self._encode_value(key)
+            self._encode_value(val)
 
     def encode_sortable_key(self, value: Any) -> tuple[int, bytes]:
         """
@@ -477,10 +507,10 @@ class CBOREncoder:
                 # String referencing requires that the order encoded is
                 # the same as the order emitted so string references are
                 # generated after an order is determined
-                self.encode(realkey)
+                self._encode_value(realkey)
             else:
                 self._fp_write(sortkey[1])
-            self.encode(value)
+            self._encode_value(value)
 
     def encode_semantic(self, value: CBORTag) -> None:
         # Nested string reference domains are distinct
@@ -491,7 +521,7 @@ class CBOREncoder:
             self._string_references = {}
 
         self.encode_length(6, value.tag)
-        self.encode(value.value)
+        self._encode_value(value.value)
 
         self.string_referencing = old_string_referencing
         self._string_references = old_string_references
@@ -554,7 +584,7 @@ class CBOREncoder:
     def encode_stringref(self, value: str | bytes) -> None:
         # Semantic tag 25
         if not self._stringref(value):
-            self.encode(value)
+            self._encode_value(value)
 
     def encode_rational(self, value: Fraction) -> None:
         # Semantic tag 30
Index: cbor2-5.6.5/source/decoder.c
===================================================================
--- cbor2-5.6.5.orig/source/decoder.c
+++ cbor2-5.6.5/source/decoder.c
@@ -142,6 +142,7 @@ CBORDecoder_new(PyTypeObject *type, PyOb
         self->str_errors = PyBytes_FromString("strict");
         self->immutable = false;
         self->shared_index = -1;
+        self->decode_depth = 0;
     }
     return (PyObject *) self;
 error:
@@ -2058,11 +2059,30 @@ decode(CBORDecoderObject *self, DecodeOp
 }
 
 
+// Reset shared state at the end of each top-level decode to prevent
+// shared references from leaking between independent decode operations.
+// Nested calls (from hooks) must preserve the state.
+static inline void
+clear_shareable_state(CBORDecoderObject *self)
+{
+    PyList_SetSlice(self->shareables, 0, PY_SSIZE_T_MAX, NULL);
+    self->shared_index = -1;
+}
+
+
 // CBORDecoder.decode(self) -> obj
 PyObject *
 CBORDecoder_decode(CBORDecoderObject *self)
 {
-    return decode(self, DECODE_NORMAL);
+    PyObject *ret;
+    self->decode_depth++;
+    ret = decode(self, DECODE_NORMAL);
+    self->decode_depth--;
+    assert(self->decode_depth >= 0);
+    if (self->decode_depth == 0) {
+        clear_shareable_state(self);
+    }
+    return ret;
 }
 
 
@@ -2075,6 +2095,7 @@ CBORDecoder_decode_from_bytes(CBORDecode
     if (!_CBOR2_BytesIO && _CBOR2_init_BytesIO() == -1)
         return NULL;
 
+    self->decode_depth++;
     save_read = self->read;
     buf = PyObject_CallFunctionObjArgs(_CBOR2_BytesIO, data, NULL);
     if (buf) {
@@ -2086,6 +2107,11 @@ CBORDecoder_decode_from_bytes(CBORDecode
         Py_DECREF(buf);
     }
     self->read = save_read;
+    self->decode_depth--;
+    assert(self->decode_depth >= 0);
+    if (self->decode_depth == 0) {
+        clear_shareable_state(self);
+    }
     return ret;
 }
 
Index: cbor2-5.6.5/source/decoder.h
===================================================================
--- cbor2-5.6.5.orig/source/decoder.h
+++ cbor2-5.6.5/source/decoder.h
@@ -13,6 +13,7 @@ typedef struct {
     PyObject *str_errors;
     bool immutable;
     Py_ssize_t shared_index;
+    Py_ssize_t decode_depth;
 } CBORDecoderObject;
 
 extern PyTypeObject CBORDecoderType;
Index: cbor2-5.6.5/source/encoder.c
===================================================================
--- cbor2-5.6.5.orig/source/encoder.c
+++ cbor2-5.6.5/source/encoder.c
@@ -113,6 +113,7 @@ CBOREncoder_new(PyTypeObject *type, PyOb
         self->shared_handler = NULL;
         self->string_referencing = false;
         self->string_namespacing = false;
+        self->encode_depth = 0;
     }
     return (PyObject *) self;
 }
@@ -2027,17 +2028,35 @@ encode(CBOREncoderObject *self, PyObject
 }
 
 
+// Reset shared state at the end of each top-level encode to prevent
+// shared references from leaking between independent encode operations.
+// Nested calls (from hooks or recursive encoding) must preserve the state.
+static inline void
+clear_shared_state(CBOREncoderObject *self)
+{
+    PyDict_Clear(self->shared);
+    PyDict_Clear(self->string_references);
+}
+
+
 // CBOREncoder.encode(self, value)
 PyObject *
 CBOREncoder_encode(CBOREncoderObject *self, PyObject *value)
 {
     PyObject *ret;
 
-    // TODO reset shared dict?
-    if (Py_EnterRecursiveCall(" in CBOREncoder.encode"))
+    self->encode_depth++;
+    if (Py_EnterRecursiveCall(" in CBOREncoder.encode")) {
+        self->encode_depth--;
         return NULL;
+    }
     ret = encode(self, value);
     Py_LeaveRecursiveCall();
+    self->encode_depth--;
+    assert(self->encode_depth >= 0);
+    if (self->encode_depth == 0) {
+        clear_shared_state(self);
+    }
     return ret;
 }
 
Index: cbor2-5.6.5/source/encoder.h
===================================================================
--- cbor2-5.6.5.orig/source/encoder.h
+++ cbor2-5.6.5/source/encoder.h
@@ -24,6 +24,7 @@ typedef struct {
     bool value_sharing;
     bool string_referencing;
     bool string_namespacing;
+    Py_ssize_t encode_depth;
 } CBOREncoderObject;
 
 extern PyTypeObject CBOREncoderType;
Index: cbor2-5.6.5/tests/test_decoder.py
===================================================================
--- cbor2-5.6.5.orig/tests/test_decoder.py
+++ cbor2-5.6.5/tests/test_decoder.py
@@ -983,3 +983,65 @@ def test_oversized_read(impl, payload: b
         dummy_path.write_bytes(payload)
         with dummy_path.open("rb") as f:
             impl.load(f)
+
+
+class TestDecoderReuse:
+    """
+    Tests for correct behavior when reusing CBORDecoder instances.
+    """
+
+    def test_decoder_reuse_resets_shared_refs(self, impl):
+        """
+        Shared references should be scoped to a single decode operation,
+        not persist across multiple decodes on the same decoder instance.
+        """
+        # Message with shareable tag (28)
+        msg1 = impl.dumps(impl.CBORTag(28, "first_value"))
+
+        # Message with sharedref tag (29) referencing index 0
+        msg2 = impl.dumps(impl.CBORTag(29, 0))
+
+        # Reuse decoder across messages
+        decoder = impl.CBORDecoder(BytesIO(msg1))
+        result1 = decoder.decode()
+        assert result1 == "first_value"
+
+        # Second decode should fail - sharedref(0) doesn't exist in this context
+        decoder.fp = BytesIO(msg2)
+        with pytest.raises(impl.CBORDecodeValueError, match="shared reference"):
+            decoder.decode()
+
+    def test_decode_from_bytes_resets_shared_refs(self, impl):
+        """
+        decode_from_bytes should also reset shared references between calls.
+        """
+        msg1 = impl.dumps(impl.CBORTag(28, "value"))
+        msg2 = impl.dumps(impl.CBORTag(29, 0))
+
+        decoder = impl.CBORDecoder(BytesIO(b""))
+        decoder.decode_from_bytes(msg1)
+
+        with pytest.raises(impl.CBORDecodeValueError, match="shared reference"):
+            decoder.decode_from_bytes(msg2)
+
+    def test_shared_refs_within_single_decode(self, impl):
+        """
+        Shared references must work correctly within a single decode operation.
+
+        Note: This tests non-cyclic sibling references [shareable(x), sharedref(0)],
+        which is a different pattern from test_cyclic_array/test_cyclic_map that
+        test self-referencing structures like shareable([sharedref(0)]).
+        """
+        # [shareable("hello"), sharedref(0)] -> ["hello", "hello"]
+        data = unhexlify(
+            "82"  # array(2)
+            "d81c"  # tag(28) shareable
+            "65"  # text(5)
+            "68656c6c6f"  # "hello"
+            "d81d"  # tag(29) sharedref
+            "00"  # unsigned(0)
+        )
+
+        result = impl.loads(data)
+        assert result == ["hello", "hello"]
+        assert result[0] is result[1]  # Same object reference
Index: cbor2-5.6.5/tests/test_encoder.py
===================================================================
--- cbor2-5.6.5.orig/tests/test_encoder.py
+++ cbor2-5.6.5/tests/test_encoder.py
@@ -654,3 +654,73 @@ def test_invariant_encode_decode(impl, v
     undergoing an encode and decode)
     """
     assert impl.loads(impl.dumps(val)) == val
+
+
+class TestEncoderReuse:
+    """
+    Tests for correct behavior when reusing CBOREncoder instances.
+    """
+
+    def test_encoder_reuse_resets_shared_containers(self, impl):
+        """
+        Shared container tracking should be scoped to a single encode operation,
+        not persist across multiple encodes on the same encoder instance.
+        """
+        fp = BytesIO()
+        encoder = impl.CBOREncoder(fp, value_sharing=True)
+        shared_obj = ["hello"]
+
+        # First encode: object is tracked in shared containers
+        encoder.encode([shared_obj, shared_obj])
+
+        # Second encode on new fp: should produce valid standalone CBOR
+        # (not a sharedref pointing to stale first-encode data)
+        encoder.fp = BytesIO()
+        encoder.encode(shared_obj)
+        second_output = encoder.fp.getvalue()
+
+        # The second output must be decodable on its own
+        result = impl.loads(second_output)
+        assert result == ["hello"]
+
+    def test_encode_to_bytes_resets_shared_containers(self, impl):
+        """
+        encode_to_bytes should also reset shared container tracking between calls.
+        """
+        fp = BytesIO()
+        encoder = impl.CBOREncoder(fp, value_sharing=True)
+        shared_obj = ["hello"]
+
+        # First encode
+        encoder.encode_to_bytes([shared_obj, shared_obj])
+
+        # Second encode should produce valid standalone CBOR
+        result_bytes = encoder.encode_to_bytes(shared_obj)
+        result = impl.loads(result_bytes)
+        assert result == ["hello"]
+
+    def test_encoder_hook_does_not_reset_state(self, impl):
+        """
+        When a custom encoder hook calls encode(), the shared container
+        tracking should be preserved (not reset mid-operation).
+        """
+
+        class Custom:
+            def __init__(self, value):
+                self.value = value
+
+        def custom_encoder(encoder, obj):
+            # Hook encodes the wrapped value
+            encoder.encode(obj.value)
+
+        # Encode a Custom wrapping a list
+        data = impl.dumps(Custom(["a", "b"]), default=custom_encoder)
+
+        # Verify the output decodes correctly
+        result = impl.loads(data)
+        assert result == ["a", "b"]
+
+        # Test nested Custom objects - hook should work recursively
+        data2 = impl.dumps(Custom(Custom(["x"])), default=custom_encoder)
+        result2 = impl.loads(data2)
+        assert result2 == ["x"]