17
0
Files
python-azure-core/CVE-2026-21226.patch
John Paul Adrian Glaubitz b426eedd52 Backport upstream patch to fix CVE-2026-21226
- Add CVE-2026-21226.patch to fix deserialization of untrusted data in
  Azure Core shared client library for Python allowing an authorized
  attacker to execute code over a network (bsc#1257703, CVE-2026-21226)
2026-02-06 10:23:00 +01:00

851 lines
34 KiB
Diff

From cd9ebabe25cd795b5fad91cc780e01f03a2c0fea Mon Sep 17 00:00:00 2001
From: Xiang Yan <xiangsjtu@gmail.com>
Date: Thu, 8 Jan 2026 12:48:36 -0800
Subject: [PATCH] Introduce new version of continuation token (#44574)
* Introduce new version of continuation token
* updates
* fix mgmt tests
* address feedback
* Update sdk/core/azure-core/azure/core/polling/_utils.py
Co-authored-by: Kashif Khan <361477+kashifkhan@users.noreply.github.com>
* update tests
* update
* update
---------
Co-authored-by: Kashif Khan <361477+kashifkhan@users.noreply.github.com>
---
sdk/core/azure-core/TROUBLESHOOTING.md | 40 +++++
.../azure-core/azure/core/polling/_poller.py | 12 +-
.../azure-core/azure/core/polling/_utils.py | 140 +++++++++++++++
.../azure/core/polling/base_polling.py | 169 +++++++++++++++++-
.../async_tests/test_base_polling_async.py | 9 +-
.../azure-core/tests/test_base_polling.py | 169 +++++++++++++++++-
sdk/core/azure-core/tests/test_polling.py | 48 +++++
.../coretestserver/test_routes/polling.py | 54 ++++++
8 files changed, 624 insertions(+), 17 deletions(-)
create mode 100644 sdk/core/azure-core/TROUBLESHOOTING.md
create mode 100644 sdk/core/azure-core/azure/core/polling/_utils.py
diff --git a/sdk/core/azure-core/TROUBLESHOOTING.md b/sdk/core/azure-core/TROUBLESHOOTING.md
new file mode 100644
index 0000000000..e2638fe241
--- /dev/null
+++ b/sdk/core/azure-core/TROUBLESHOOTING.md
@@ -0,0 +1,40 @@
+# Troubleshooting Azure Core
+
+This document provides solutions to common issues you may encounter when using the Azure Core library.
+
+## Continuation Token Compatibility Issues
+
+### Error: "This continuation token is not compatible with this version of azure-core"
+
+**Symptoms:**
+
+You may encounter an error message like:
+
+```
+ValueError: This continuation token is not compatible with this version of azure-core. It may have been generated by a previous version.
+```
+
+**Cause:**
+
+Starting from azure-core version 1.38.0, the continuation token format was changed. This change was made to improve security and portability. Continuation tokens are opaque strings and their internal format is not guaranteed to be stable across versions.
+
+Continuation tokens generated by previous versions of azure-core are not compatible with version 1.38.0 and later.
+
+**Solution:**
+
+Unfortunately, old continuation tokens cannot be migrated to the new version. You will need to:
+
+1. **Start a new long-running operation**: Instead of using the old continuation token, initiate a new request for your long-running operation.
+
+2. **Check operation status via Azure Portal or CLI**: If you need the status of an operation whose continuation token was generated by an older version of azure-core, you can use the Azure Portal or Azure CLI to check the operation status directly.
+
+3. **Update or pin your dependencies**: Ensure that any new continuation tokens are generated and consumed using the same version of azure-core (1.38.0 or later).
+
+**Prevention:**
+
+To avoid this issue in the future:
+
+- When upgrading azure-core, ensure that any stored continuation tokens are either consumed before the upgrade or discarded.
+- Design your application to handle the case where a continuation token may become invalid.
+
+For more information, see the [CHANGELOG](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/CHANGELOG.md) for version 1.38.0.
diff --git a/sdk/core/azure-core/azure/core/polling/_poller.py b/sdk/core/azure-core/azure/core/polling/_poller.py
index 8b8e651eeb..8294c7ce14 100644
--- a/sdk/core/azure-core/azure/core/polling/_poller.py
+++ b/sdk/core/azure-core/azure/core/polling/_poller.py
@@ -23,7 +23,6 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
-import base64
import logging
import threading
import uuid
@@ -31,6 +30,7 @@ from typing import TypeVar, Generic, Any, Callable, Optional, Tuple, List
from azure.core.exceptions import AzureError
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.common import with_current_context
+from ._utils import _encode_continuation_token, _decode_continuation_token
PollingReturnType_co = TypeVar("PollingReturnType_co", covariant=True)
@@ -108,9 +108,12 @@ class _SansIONoPolling(Generic[PollingReturnType_co]):
return self._deserialization_callback(self._initial_response)
def get_continuation_token(self) -> str:
- import pickle
+ """Return a continuation token that allows to restart the poller later.
- return base64.b64encode(pickle.dumps(self._initial_response)).decode("ascii")
+ :rtype: str
+ :return: An opaque continuation token
+ """
+ return _encode_continuation_token(self._initial_response)
@classmethod
def from_continuation_token(
@@ -120,9 +123,8 @@ class _SansIONoPolling(Generic[PollingReturnType_co]):
deserialization_callback = kwargs["deserialization_callback"]
except KeyError:
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") from None
- import pickle
- initial_response = pickle.loads(base64.b64decode(continuation_token)) # nosec
+ initial_response = _decode_continuation_token(continuation_token)
return None, initial_response, deserialization_callback
diff --git a/sdk/core/azure-core/azure/core/polling/_utils.py b/sdk/core/azure-core/azure/core/polling/_utils.py
new file mode 100644
index 0000000000..86d6907c05
--- /dev/null
+++ b/sdk/core/azure-core/azure/core/polling/_utils.py
@@ -0,0 +1,140 @@
+# --------------------------------------------------------------------------
+#
+# Copyright (c) Microsoft Corporation. All rights reserved.
+#
+# The MIT License (MIT)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the ""Software""), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+#
+# --------------------------------------------------------------------------
+"""Shared utilities for polling continuation token serialization."""
+
+import base64
+import binascii
+import json
+from typing import Any, Dict, Mapping
+
+
+# Current continuation token version
+_CONTINUATION_TOKEN_VERSION = 1
+
+# Error message for incompatible continuation tokens from older versions
+_INCOMPATIBLE_TOKEN_ERROR_MESSAGE = (
+ "This continuation token is not compatible with this version of azure-core. "
+ "It may have been generated by a previous version. "
+ "See https://aka.ms/azsdk/python/core/troubleshoot for more information."
+)
+
+# Headers that are needed for LRO rehydration.
+# We use an allowlist approach for security - only include headers we actually need.
+_LRO_HEADERS = frozenset(
+ [
+ "operation-location",
+ # azure-asyncoperation is included only for back compat with mgmt-core<=1.6.0
+ "azure-asyncoperation",
+ "location",
+ "content-type",
+ "retry-after",
+ ]
+)
+
+
+def _filter_sensitive_headers(headers: Mapping[str, str]) -> Dict[str, str]:
+ """Filter headers to only include those needed for LRO rehydration.
+
+ Uses an allowlist approach - only headers required for polling are included.
+
+ :param headers: The headers to filter.
+ :type headers: Mapping[str, str]
+ :return: A new dictionary with only allowed headers.
+ :rtype: dict[str, str]
+ """
+ return {k: v for k, v in headers.items() if k.lower() in _LRO_HEADERS}
+
+
+def _is_pickle_format(data: bytes) -> bool:
+ """Check if the data appears to be in pickle format.
+
+ Protocol 2+ pickles start with the PROTO opcode \\x80 followed by a version byte; bytes 1-5 are accepted.
+
+ :param data: The bytes to check.
+ :type data: bytes
+ :return: True if the data appears to be pickled, False otherwise.
+ :rtype: bool
+ """
+ if not data or len(data) < 2:
+ return False
+ # Check for pickle protocol marker (0x80) followed by protocol version 1-5
+ return data[0:1] == b"\x80" and 1 <= data[1] <= 5
+
+
+def _decode_continuation_token(continuation_token: str) -> Dict[str, Any]:
+ """Decode a base64-encoded JSON continuation token.
+
+ :param continuation_token: The base64-encoded continuation token.
+ :type continuation_token: str
+ :return: The "data" field of the decoded token (a dict for LRO pollers, any JSON value otherwise).
+ :rtype: any
+ :raises ValueError: If the token is invalid or in an unsupported format.
+ """
+ try:
+ decoded_bytes = base64.b64decode(continuation_token)
+ token = json.loads(decoded_bytes.decode("utf-8"))
+ except binascii.Error:
+ # Invalid base64 input
+ raise ValueError("This doesn't look like a continuation token the sdk created.") from None
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ # Check if the data appears to be from an older version
+ if _is_pickle_format(decoded_bytes):
+ raise ValueError(_INCOMPATIBLE_TOKEN_ERROR_MESSAGE) from None
+ raise ValueError("Invalid continuation token format.") from None
+
+ # Validate token schema - must be a dict with a version field
+ if not isinstance(token, dict) or "version" not in token:
+ raise ValueError("Invalid continuation token format.") from None
+
+ # For now, we only support version 1
+ # Future versions can add handling for older versions here if needed
+ if token["version"] != _CONTINUATION_TOKEN_VERSION:
+ raise ValueError(_INCOMPATIBLE_TOKEN_ERROR_MESSAGE) from None
+
+ return token["data"]
+
+
+def _encode_continuation_token(data: Any) -> str:
+ """Encode data as a base64-encoded JSON continuation token.
+
+ The token includes a version field for future compatibility checking.
+
+ :param data: The data to encode. Must be JSON-serializable.
+ :type data: any
+ :return: The base64-encoded JSON string.
+ :rtype: str
+ :raises TypeError: If the data is not JSON-serializable.
+ """
+ token = {
+ "version": _CONTINUATION_TOKEN_VERSION,
+ "data": data,
+ }
+ try:
+ return base64.b64encode(json.dumps(token, separators=(",", ":")).encode("utf-8")).decode("ascii")
+ except (TypeError, ValueError) as err:
+ raise TypeError(
+ "Unable to generate a continuation token for this operation. Payload is not JSON-serializable."
+ ) from err
diff --git a/sdk/core/azure-core/azure/core/polling/base_polling.py b/sdk/core/azure-core/azure/core/polling/base_polling.py
index 4824b7dda6..780279cc93 100644
--- a/sdk/core/azure-core/azure/core/polling/base_polling.py
+++ b/sdk/core/azure-core/azure/core/polling/base_polling.py
@@ -33,6 +33,7 @@ from typing import (
Tuple,
Callable,
Dict,
+ Mapping,
Sequence,
Generic,
TypeVar,
@@ -46,7 +47,9 @@ from ..pipeline.policies._utils import get_retry_after
from ..pipeline._tools import is_rest
from .._enum_meta import CaseInsensitiveEnumMeta
from .. import PipelineClient
-from ..pipeline import PipelineResponse
+from ..pipeline import PipelineResponse, PipelineContext
+from ..rest._helpers import decode_to_text, get_charset_encoding
+from ..utils._utils import case_insensitive_dict
from ..pipeline.transport import (
HttpTransport,
HttpRequest as LegacyHttpRequest,
@@ -54,6 +57,11 @@ from ..pipeline.transport import (
AsyncHttpResponse as LegacyAsyncHttpResponse,
)
from ..rest import HttpRequest, HttpResponse, AsyncHttpResponse
+from ._utils import (
+ _encode_continuation_token,
+ _decode_continuation_token,
+ _filter_sensitive_headers,
+)
HttpRequestType = Union[LegacyHttpRequest, HttpRequest]
@@ -80,6 +88,56 @@ _FAILED = frozenset(["canceled", "failed"])
_SUCCEEDED = frozenset(["succeeded"])
+class _ContinuationTokenHttpResponse:
+ """A minimal HTTP response class for reconstructing responses from continuation tokens.
+
+ This class provides just enough interface to be used with LRO polling operations
+ when restoring from a continuation token.
+
+ :param request: The HTTP request (optional, may be None if not available in the continuation token)
+ :type request: ~azure.core.rest.HttpRequest or None
+ :param status_code: The HTTP status code
+ :type status_code: int
+ :param headers: The response headers
+ :type headers: dict
+ :param content: The response content
+ :type content: bytes
+ """
+
+ def __init__(
+ self,
+ request: Optional[HttpRequest],
+ status_code: int,
+ headers: Dict[str, str],
+ content: bytes,
+ ):
+ self.request = request
+ self.status_code = status_code
+ self.headers = case_insensitive_dict(headers)
+ self._content = content
+
+ @property
+ def content(self) -> bytes:
+ """Return the response content.
+
+ :return: The response content
+ :rtype: bytes
+ """
+ return self._content
+
+ def text(self) -> str:
+ """Return the response content as text.
+
+ Uses the charset from Content-Type header if available, otherwise falls back
+ to UTF-8 with replacement for invalid characters.
+
+ :return: The response content as text
+ :rtype: str
+ """
+ encoding = get_charset_encoding(self)
+ return decode_to_text(encoding, self._content)
+
+
def _get_content(response: AllHttpResponseType) -> bytes:
"""Get the content of this response. This is designed specifically to avoid
a warning of mypy for body() access, as this method is deprecated.
@@ -645,10 +703,82 @@ class _SansIOLROBasePolling(
except OperationFailed as err:
raise HttpResponseError(response=initial_response.http_response, error=err) from err
+ def _filter_headers_for_continuation_token(self, headers: Mapping[str, str]) -> Dict[str, str]:
+ """Filter headers to include in the continuation token.
+
+ Subclasses can override this method to include additional headers needed
+ for their specific LRO implementation.
+
+ :param headers: The response headers to filter.
+ :type headers: Mapping[str, str]
+ :return: A filtered dictionary of headers to include in the continuation token.
+ :rtype: dict[str, str]
+ """
+ return _filter_sensitive_headers(headers)
+
def get_continuation_token(self) -> str:
- import pickle
+ """Get a continuation token that can be used to recreate this poller.
- return base64.b64encode(pickle.dumps(self._initial_response)).decode("ascii")
+ :rtype: str
+ :return: An opaque continuation token.
+ :raises ValueError: If the initial response is not set.
+ """
+ response = self._initial_response.http_response
+ request = response.request
+ # Serialize the essential parts of the PipelineResponse to JSON.
+ if request:
+ request_headers = {}
+ # Preserve x-ms-client-request-id for request correlation
+ if "x-ms-client-request-id" in request.headers:
+ request_headers["x-ms-client-request-id"] = request.headers["x-ms-client-request-id"]
+ request_state = {
+ "method": request.method,
+ "url": request.url,
+ "headers": request_headers,
+ }
+ else:
+ request_state = None
+ # Get response content, handling the case where it might not be read yet
+ try:
+ content = _get_content(response) or b""
+ except Exception: # pylint: disable=broad-except
+ content = b""
+ # Get deserialized data from context if available (optimization).
+ # If context doesn't have it, fall back to parsing the response body directly.
+ # Note: deserialized_data is only included if it's JSON-serializable.
+ # Non-JSON-serializable types (e.g., XML ElementTree) are skipped and set to None.
+ # In such cases, the data can still be re-parsed from the raw content bytes.
+ deserialized_data = None
+ raw_deserialized = None
+ if self._initial_response.context is not None:
+ raw_deserialized = self._initial_response.context.get("deserialized_data")
+ # Fallback: try to get deserialized data from the response body if context didn't have it
+ if raw_deserialized is None and content:
+ try:
+ raw_deserialized = json.loads(content)
+ except (json.JSONDecodeError, ValueError, TypeError):
+ # Response body is not valid JSON, leave as None
+ pass
+ if raw_deserialized is not None:
+ try:
+ # Test if the data is JSON-serializable
+ json.dumps(raw_deserialized)
+ deserialized_data = raw_deserialized
+ except (TypeError, ValueError):
+ # Skip non-JSON-serializable data (e.g., XML ElementTree objects)
+ deserialized_data = None
+ state = {
+ "request": request_state,
+ "response": {
+ "status_code": response.status_code,
+ "headers": self._filter_headers_for_continuation_token(response.headers),
+ "content": base64.b64encode(content).decode("ascii"),
+ },
+ "context": {
+ "deserialized_data": deserialized_data,
+ },
+ }
+ return _encode_continuation_token(state)
@classmethod
def from_continuation_token(
@@ -664,11 +794,34 @@ class _SansIOLROBasePolling(
except KeyError:
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") from None
- import pickle
-
- initial_response = pickle.loads(base64.b64decode(continuation_token)) # nosec
- # Restore the transport in the context
- initial_response.context.transport = client._pipeline._transport # pylint: disable=protected-access
+ state = _decode_continuation_token(continuation_token)
+ # Reconstruct HttpRequest if present
+ request_state = state.get("request")
+ http_request = None
+ if request_state is not None:
+ http_request = HttpRequest(
+ method=request_state["method"],
+ url=request_state["url"],
+ headers=request_state.get("headers", {}),
+ )
+ # Reconstruct HttpResponse using the minimal response class
+ response_state = state["response"]
+ http_response = _ContinuationTokenHttpResponse(
+ request=http_request,
+ status_code=response_state["status_code"],
+ headers=response_state["headers"],
+ content=base64.b64decode(response_state["content"]),
+ )
+ # Reconstruct PipelineResponse
+ context = PipelineContext(client._pipeline._transport) # pylint: disable=protected-access
+ context_state = state.get("context", {})
+ if context_state.get("deserialized_data") is not None:
+ context["deserialized_data"] = context_state["deserialized_data"]
+ initial_response = PipelineResponse(
+ http_request=http_request,
+ http_response=http_response,
+ context=context,
+ )
return client, initial_response, deserialization_callback
def status(self) -> str:
diff --git a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py
index 16cbfd48f1..56e5e44a4e 100644
--- a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py
+++ b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py
@@ -25,7 +25,6 @@
# --------------------------------------------------------------------------
import base64
import json
-import pickle
import re
from utils import HTTP_REQUESTS
from azure.core.pipeline._tools import is_rest
@@ -690,7 +689,13 @@ async def test_long_running_negative(http_request, http_response):
poll = async_poller(CLIENT, response, TestBasePolling.mock_outputs, AsyncLROBasePolling(0))
with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization
await poll
- assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode("ascii")
+ # Verify continuation token is set and is a valid JSON-encoded token
+ assert error.value.continuation_token is not None
+ assert isinstance(error.value.continuation_token, str)
+ # Verify the token can be decoded
+ decoded = json.loads(base64.b64decode(error.value.continuation_token).decode("utf-8"))
+ assert "request" in decoded["data"]
+ assert "response" in decoded["data"]
LOCATION_BODY = json.dumps({"name": TEST_NAME})
POLLING_STATUS = 200
diff --git a/sdk/core/azure-core/tests/test_base_polling.py b/sdk/core/azure-core/tests/test_base_polling.py
index 3a4b6d75f1..a23f4824ac 100644
--- a/sdk/core/azure-core/tests/test_base_polling.py
+++ b/sdk/core/azure-core/tests/test_base_polling.py
@@ -28,7 +28,6 @@ import datetime
import json
import re
import types
-import pickle
import platform
try:
@@ -177,6 +176,67 @@ def test_base_polling_continuation_token(client, polling_response, http_response
new_polling.initialize(*polling_args)
+def test_base_polling_continuation_token_pickle_incompatibility(client):
+ """Test that from_continuation_token raises ValueError with helpful message for old pickle tokens."""
+ import pickle
+
+ # Simulate an old pickle-based continuation token (would have been a pickled PipelineResponse)
+ old_pickle_data = pickle.dumps({"some": "data"})
+ old_continuation_token = base64.b64encode(old_pickle_data).decode("ascii")
+
+ with pytest.raises(ValueError) as excinfo:
+ LROBasePolling.from_continuation_token(
+ old_continuation_token,
+ deserialization_callback=lambda x: x,
+ client=client,
+ )
+
+ error_message = str(excinfo.value)
+ assert "aka.ms" in error_message
+
+
+@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
+def test_base_polling_continuation_token_with_stream_response(port, http_request, deserialization_cb):
+ """Test that get_continuation_token works correctly with real server responses."""
+ client = MockRestClient(port)
+ request = http_request(
+ "POST",
+ "http://localhost:{}/polling/continuation-token-stream".format(port),
+ )
+ initial_response = client._client._pipeline.run(request)
+
+ # Create polling operation
+ polling = LROBasePolling(timeout=0)
+ polling.initialize(
+ client._client,
+ initial_response,
+ deserialization_cb,
+ )
+
+ # get_continuation_token should work with real server response
+ continuation_token = polling.get_continuation_token()
+ assert isinstance(continuation_token, str)
+
+ # Verify the token can be decoded and contains expected structure
+ decoded = json.loads(base64.b64decode(continuation_token).decode("utf-8"))
+ assert decoded["version"] == 1
+ assert "request" in decoded["data"]
+ assert "response" in decoded["data"]
+ assert decoded["data"]["response"]["status_code"] == 202
+ # Content should be preserved
+ content_bytes = base64.b64decode(decoded["data"]["response"]["content"])
+ assert b"InProgress" in content_bytes
+
+ # Verify we can restore from the continuation token
+ polling_args = LROBasePolling.from_continuation_token(
+ continuation_token,
+ deserialization_callback=deserialization_cb,
+ client=client._client,
+ )
+ new_polling = LROBasePolling()
+ new_polling.initialize(*polling_args)
+
+
@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES)
def test_delay_extraction_int(polling_response, http_response):
polling = polling_response(http_response, {"Retry-After": "10"})
@@ -714,7 +774,13 @@ class TestBasePolling(object):
poll = LROPoller(CLIENT, response, TestBasePolling.mock_outputs, LROBasePolling(0))
with pytest.raises(HttpResponseError) as error: # TODO: Node.js raises on deserialization
poll.result()
- assert error.value.continuation_token == base64.b64encode(pickle.dumps(response)).decode("ascii")
+ # Verify continuation token is set and is a valid JSON-encoded token
+ assert error.value.continuation_token is not None
+ assert isinstance(error.value.continuation_token, str)
+ # Verify the token can be decoded
+ decoded = json.loads(base64.b64decode(error.value.continuation_token).decode("utf-8"))
+ assert "request" in decoded["data"]
+ assert "response" in decoded["data"]
LOCATION_BODY = json.dumps({"name": TEST_NAME})
POLLING_STATUS = 200
@@ -833,3 +899,102 @@ def test_post_check_patch(http_request):
with pytest.raises(AttributeError) as ex:
algorithm.get_final_get_url(None)
assert "'NoneType' object has no attribute 'http_response'" in str(ex.value)
+
+
+def test_continuation_token_with_non_json_serializable_data(port, deserialization_cb):
+ """Test that continuation token gracefully handles non-JSON-serializable data like XML."""
+ import base64
+ import json
+ import xml.etree.ElementTree as ET
+
+ from azure.core.polling.base_polling import LROBasePolling
+ from azure.core.rest import HttpRequest
+
+ client = MockRestClient(port)
+ request = HttpRequest(
+ "POST",
+ "http://localhost:{}/polling/continuation-token-xml".format(port),
+ )
+ initial_response = client._client._pipeline.run(request)
+
+ # Simulate XML deserialized data (non-JSON-serializable)
+ xml_element = ET.fromstring(b"<root><status>InProgress</status></root>")
+ initial_response.context["deserialized_data"] = xml_element
+
+ # Create polling operation
+ polling = LROBasePolling(timeout=0)
+ polling.initialize(
+ client._client,
+ initial_response,
+ deserialization_cb,
+ )
+
+ # Get continuation token - this should NOT raise an error
+ token = polling.get_continuation_token()
+
+ # Decode and verify the token structure
+ decoded = json.loads(base64.b64decode(token).decode("utf-8"))
+
+ # deserialized_data should be None because XML is not JSON-serializable
+ assert decoded["data"]["context"]["deserialized_data"] is None
+
+ # The raw content should still be preserved
+ content_bytes = base64.b64decode(decoded["data"]["response"]["content"])
+ assert b"<root>" in content_bytes or b"InProgress" in content_bytes
+
+
+@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
+def test_continuation_token_excludes_request_headers(port, http_request, deserialization_cb):
+ """Test that continuation token does not include sensitive request headers for security."""
+ import base64
+ import json
+
+ from azure.core.polling.base_polling import LROBasePolling
+
+ client = MockRestClient(port)
+ request = http_request(
+ "POST",
+ "http://localhost:{}/polling/continuation-token".format(port),
+ )
+ # Add headers that should NOT be included in the continuation token
+ request.headers["Authorization"] = "Bearer super-secret-token"
+ request.headers["x-ms-authorization-auxiliary"] = "auxiliary-secret"
+ request.headers["x-custom-header"] = "custom-value"
+ # Add header that SHOULD be included for request correlation
+ request.headers["x-ms-client-request-id"] = "test-request-id-12345"
+
+ initial_response = client._client._pipeline.run(request)
+
+ # Create polling operation
+ polling = LROBasePolling(timeout=0)
+ polling.initialize(
+ client._client,
+ initial_response,
+ deserialization_cb,
+ )
+
+ token = polling.get_continuation_token()
+
+ # Decode and verify sensitive request headers are not included
+ decoded = json.loads(base64.b64decode(token).decode("utf-8"))
+
+ # Request should contain method, url, and only safe headers (x-ms-client-request-id)
+ request_state = decoded["data"]["request"]
+ assert request_state["method"] == "POST"
+ assert "continuation-token" in request_state["url"]
+ # Only x-ms-client-request-id should be in headers
+ assert "headers" in request_state
+ assert request_state["headers"].get("x-ms-client-request-id") == "test-request-id-12345"
+ # Sensitive headers should NOT be included
+ assert "Authorization" not in request_state["headers"]
+ assert "x-ms-authorization-auxiliary" not in request_state["headers"]
+ assert "x-custom-header" not in request_state["headers"]
+
+ # Verify we can restore from the continuation token
+ polling_args = LROBasePolling.from_continuation_token(
+ token,
+ deserialization_callback=deserialization_cb,
+ client=client._client,
+ )
+ new_polling = LROBasePolling()
+ new_polling.initialize(*polling_args)
diff --git a/sdk/core/azure-core/tests/test_polling.py b/sdk/core/azure-core/tests/test_polling.py
index e3fb92e124..cdabd4575b 100644
--- a/sdk/core/azure-core/tests/test_polling.py
+++ b/sdk/core/azure-core/tests/test_polling.py
@@ -23,6 +23,7 @@
# THE SOFTWARE.
#
# --------------------------------------------------------------------------
+import base64
import time
try:
@@ -99,6 +100,53 @@ def test_no_polling(client):
assert no_polling_revived.resource() == "Treated: " + initial_response
+def test_no_polling_continuation_token_missing_callback():
+ """Test that from_continuation_token raises ValueError when deserialization_callback is missing."""
+ no_polling = NoPolling()
+ no_polling.initialize(None, "test", lambda x: x)
+
+ continuation_token = no_polling.get_continuation_token()
+
+ with pytest.raises(ValueError) as excinfo:
+ NoPolling.from_continuation_token(continuation_token)
+ assert "deserialization_callback" in str(excinfo.value)
+
+
+def test_no_polling_continuation_token_pickle_incompatibility():
+ """Test that from_continuation_token raises ValueError with helpful message for old pickle tokens."""
+ import pickle
+
+ # Simulate an old pickle-based continuation token
+ old_pickle_data = pickle.dumps({"some": "data"})
+ old_continuation_token = base64.b64encode(old_pickle_data).decode("ascii")
+
+ with pytest.raises(ValueError) as excinfo:
+ NoPolling.from_continuation_token(old_continuation_token, deserialization_callback=lambda x: x)
+
+ error_message = str(excinfo.value)
+ assert "aka.ms" in error_message
+
+
+def test_no_polling_continuation_token_non_serializable():
+ """Test that get_continuation_token raises TypeError for non-JSON-serializable initial responses."""
+ no_polling = NoPolling()
+
+ # Create a non-JSON-serializable object
+ class CustomObject:
+ def __init__(self, value):
+ self.value = value
+
+ initial_response = CustomObject("test")
+
+ no_polling.initialize(None, initial_response, lambda x: x)
+
+ with pytest.raises(TypeError) as excinfo:
+ no_polling.get_continuation_token()
+
+ error_message = str(excinfo.value)
+ assert "not JSON-serializable" in error_message
+
+
def test_polling_with_path_format_arguments(client):
method = LROBasePolling(timeout=0, path_format_arguments={"host": "host:3000", "accountName": "local"})
client._base_url = "http://{accountName}{host}"
diff --git a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py
index 005867a38e..527fe77e37 100644
--- a/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py
+++ b/sdk/core/azure-core/tests/testserver_tests/coretestserver/coretestserver/test_routes/polling.py
@@ -169,3 +169,57 @@ def polling_with_options_first():
@polling_api.route("/final-get-with-location", methods=["GET"])
def polling_with_options_final_get_with_location():
return Response('{"returnedFrom": "locationHeaderUrl"}', status=200)
+
+
+@polling_api.route("/continuation-token", methods=["POST"])
+def continuation_token_initial():
+ """Initial LRO response for continuation token tests."""
+ base_url = get_base_url(request)
+ return Response(
+ '{"properties":{"provisioningState": "InProgress"}}',
+ headers={
+ "operation-location": "{}/polling/continuation-token-status".format(base_url),
+ "x-ms-request-id": "test-request-id-12345",
+ },
+ status=202,
+ )
+
+
+@polling_api.route("/continuation-token-status", methods=["GET"])
+def continuation_token_status():
+ """Status endpoint for continuation token tests."""
+ return Response('{"status": "Succeeded"}', status=200)
+
+
+@polling_api.route("/continuation-token-xml", methods=["POST"])
+def continuation_token_xml_initial():
+ """Initial LRO response with XML body for continuation token tests."""
+ base_url = get_base_url(request)
+ return Response(
+ "<root><status>InProgress</status></root>",
+ headers={
+ "operation-location": "{}/polling/continuation-token-xml-status".format(base_url),
+ "content-type": "application/xml",
+ },
+ status=202,
+ )
+
+
+@polling_api.route("/continuation-token-xml-status", methods=["GET"])
+def continuation_token_xml_status():
+ """Status endpoint for XML continuation token tests."""
+ return Response('{"status": "Succeeded"}', status=200)
+
+
+@polling_api.route("/continuation-token-stream", methods=["POST"])
+def continuation_token_stream_initial():
+ """Initial LRO response for stream continuation token tests."""
+ base_url = get_base_url(request)
+ return Response(
+ '{"status": "InProgress"}',
+ headers={
+ "operation-location": "{}/polling/continuation-token-status".format(base_url),
+ "content-type": "application/json",
+ },
+ status=202,
+ )
--
2.52.0