From 3e2d0f3d0615bf67eb2fe213c9daeb929e3bd2fa Mon Sep 17 00:00:00 2001
From: Luca Versari
Date: Thu, 21 Nov 2024 16:33:08 +0100
Subject: [PATCH] Check height limit in modular trees. (#3943)

Also rewrite the implementation to use iterative checking instead of
recursive checking of tree property values, to ensure stack usage is
low. Before, it was possible for appropriately-crafted files to use a
significant amount of stack (in the order of hundreds of MB).

(cherry picked from commit bf4781a2eed2eef664790170977d1d3d8347efb9)
---
 lib/jxl/modular/encoding/dec_ma.cc | 65 ++++++++++++++++++++----------
 1 file changed, 44 insertions(+), 21 deletions(-)

diff --git a/lib/jxl/modular/encoding/dec_ma.cc b/lib/jxl/modular/encoding/dec_ma.cc
index 66562f7d..2d88e9b1 100644
--- a/lib/jxl/modular/encoding/dec_ma.cc
+++ b/lib/jxl/modular/encoding/dec_ma.cc
@@ -14,23 +14,49 @@ namespace jxl {
 
 namespace {
 
-Status ValidateTree(
-    const Tree &tree,
-    const std::vector<std::pair<pixel_type, pixel_type>> &prop_bounds,
-    size_t root) {
-  if (tree[root].property == -1) return true;
-  size_t p = tree[root].property;
-  int val = tree[root].splitval;
-  if (prop_bounds[p].first > val) return JXL_FAILURE("Invalid tree");
-  // Splitting at max value makes no sense: left range will be exactly same
-  // as parent, right range will be invalid (min > max).
-  if (prop_bounds[p].second <= val) return JXL_FAILURE("Invalid tree");
-  auto new_bounds = prop_bounds;
-  new_bounds[p].first = val + 1;
-  JXL_RETURN_IF_ERROR(ValidateTree(tree, new_bounds, tree[root].lchild));
-  new_bounds[p] = prop_bounds[p];
-  new_bounds[p].second = val;
-  return ValidateTree(tree, new_bounds, tree[root].rchild);
+Status ValidateTree(const Tree &tree) {
+  int num_properties = 0;
+  for (auto node : tree) {
+    if (node.property >= num_properties) {
+      num_properties = node.property + 1;
+    }
+  }
+  std::vector<int> height(tree.size());
+  std::vector<std::pair<pixel_type, pixel_type>> property_ranges(
+      num_properties * tree.size());
+  for (int i = 0; i < num_properties; i++) {
+    property_ranges[i].first = std::numeric_limits<pixel_type>::min();
+    property_ranges[i].second = std::numeric_limits<pixel_type>::max();
+  }
+  const int kHeightLimit = 2048;
+  for (size_t i = 0; i < tree.size(); i++) {
+    if (height[i] > kHeightLimit) {
+      return JXL_FAILURE("Tree too tall: %d", height[i]);
+    }
+    if (tree[i].property == -1) continue;
+    height[tree[i].lchild] = height[i] + 1;
+    height[tree[i].rchild] = height[i] + 1;
+    for (size_t p = 0; p < static_cast<size_t>(num_properties); p++) {
+      if (p == static_cast<size_t>(tree[i].property)) {
+        pixel_type l = property_ranges[i * num_properties + p].first;
+        pixel_type u = property_ranges[i * num_properties + p].second;
+        pixel_type val = tree[i].splitval;
+        if (l > val || u <= val) {
+          return JXL_FAILURE("Invalid tree");
+        }
+        property_ranges[tree[i].lchild * num_properties + p] =
+            std::make_pair(val + 1, u);
+        property_ranges[tree[i].rchild * num_properties + p] =
+            std::make_pair(l, val);
+      } else {
+        property_ranges[tree[i].lchild * num_properties + p] =
+            property_ranges[i * num_properties + p];
+        property_ranges[tree[i].rchild * num_properties + p] =
+            property_ranges[i * num_properties + p];
+      }
+    }
+  }
+  return true;
 }
 
 Status DecodeTree(BitReader *br, ANSSymbolReader *reader,
@@ -79,10 +105,7 @@ Status DecodeTree(BitReader *br, ANSSymbolReader *reader,
                        tree->size() + to_decode + 2, Predictor::Zero, 0, 1);
     to_decode += 2;
   }
-  std::vector<std::pair<pixel_type, pixel_type>> prop_bounds;
-  prop_bounds.resize(256, {std::numeric_limits<pixel_type>::min(),
-                           std::numeric_limits<pixel_type>::max()});
-  return ValidateTree(*tree, prop_bounds, 0);
+  return ValidateTree(*tree);
 }
 
 }  // namespace
-- 
2.47.1