1
0
mirror of https://github.com/openSUSE/osc.git synced 2025-01-12 08:56:13 +01:00

Store cached model defaults in self._defaults, avoid sharing references to mutable defaults

This commit is contained in:
Daniel Mach 2024-01-03 21:21:40 +01:00
parent 587c094f61
commit 16cdc067a5
7 changed files with 37 additions and 43 deletions

View File

@ -6,6 +6,7 @@ It works on python 3.6+.
This module IS NOT a supported API, it is meant for osc internal use only. This module IS NOT a supported API, it is meant for osc internal use only.
""" """
import copy
import inspect import inspect
import sys import sys
import types import types
@ -82,9 +83,6 @@ class Field(property):
# a flag indicating, whether the default is a callable with lazy evalution # a flag indicating, whether the default is a callable with lazy evalution
self.default_is_lazy = callable(self.default) self.default_is_lazy = callable(self.default)
# whether the field was set
self.is_set = False
# the name of model's attribute associated with this field instance - set from the model # the name of model's attribute associated with this field instance - set from the model
self.name = None self.name = None
@ -209,6 +207,11 @@ class Field(property):
except KeyError: except KeyError:
pass pass
try:
return obj._defaults[self.name]
except KeyError:
pass
if isinstance(self.default, FromParent): if isinstance(self.default, FromParent):
if obj._parent is None: if obj._parent is None:
raise RuntimeError(f"The field '{self.name}' has default {self.default} but the model has no parent set") raise RuntimeError(f"The field '{self.name}' has default {self.default} but the model has no parent set")
@ -217,18 +220,23 @@ class Field(property):
if self.default is NotSet: if self.default is NotSet:
raise RuntimeError(f"The field '{self.name}' has no default") raise RuntimeError(f"The field '{self.name}' has no default")
# make a deepcopy to avoid problems with mutable defaults
default = copy.deepcopy(self.default)
# lazy evaluation of a factory function on first use # lazy evaluation of a factory function on first use
if callable(self.default): if callable(default):
self.default = self.default() default = default()
# if this is a model field, convert dict to a model instance # if this is a model field, convert dict to a model instance
if self.is_model and isinstance(self.default, dict): if self.is_model and isinstance(default, dict):
new_value = self.origin_type() # pylint: disable=not-callable cls = self.origin_type
for k, v in self.default.items(): new_value = cls() # pylint: disable=not-callable
for k, v in default.items():
setattr(new_value, k, v) setattr(new_value, k, v)
self.default = new_value default = new_value
return self.default obj._defaults[self.name] = default
return default
def set(self, obj, value): def set(self, obj, value):
# if this is a model field, convert dict to a model instance # if this is a model field, convert dict to a model instance
@ -240,7 +248,6 @@ class Field(property):
self.validate_type(value) self.validate_type(value)
obj._values[self.name] = value obj._values[self.name] = value
self.is_set = True
class ModelMeta(type): class ModelMeta(type):
@ -288,7 +295,8 @@ class BaseModel(metaclass=ModelMeta):
raise AttributeError(f"Setting attribute '{self.__class__.__name__}.{name}' is not allowed") raise AttributeError(f"Setting attribute '{self.__class__.__name__}.{name}' is not allowed")
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._values = {} self._defaults = {} # field defaults cached in field.get()
self._values = {} # field values explicitly set after initializing the model
self._parent = kwargs.pop("_parent", None) self._parent = kwargs.pop("_parent", None)
uninitialized_fields = [] uninitialized_fields = []
@ -321,7 +329,7 @@ class BaseModel(metaclass=ModelMeta):
for name, field in self.__fields__.items(): for name, field in self.__fields__.items():
if field.exclude: if field.exclude:
continue continue
if exclude_unset and not field.is_set and field.is_optional: if exclude_unset and field.name not in self._values and field.is_optional:
# include only mandatory fields and optional fields that were set to an actual value # include only mandatory fields and optional fields that were set to an actual value
continue continue
if field.is_model: if field.is_model:

View File

@ -1,4 +1,3 @@
import importlib
import unittest import unittest
import osc.conf import osc.conf
@ -8,12 +7,7 @@ from osc.oscerr import UserAbort
class TestTrustedProjects(unittest.TestCase): class TestTrustedProjects(unittest.TestCase):
def setUp(self): def setUp(self):
# reset the global `config` in preparation for running the tests osc.conf.config = osc.conf.Options()
importlib.reload(osc.conf)
def tearDown(self):
# reset the global `config` to avoid impacting tests from other classes
importlib.reload(osc.conf)
def test_name(self): def test_name(self):
apiurl = "https://example.com" apiurl = "https://example.com"

View File

@ -1,4 +1,3 @@
import importlib
import os import os
import shutil import shutil
import tempfile import tempfile
@ -105,7 +104,6 @@ plugin-option = plugin-host-option
class TestExampleConfig(unittest.TestCase): class TestExampleConfig(unittest.TestCase):
def setUp(self): def setUp(self):
importlib.reload(osc.conf)
self.tmpdir = tempfile.mkdtemp(prefix="osc_test_") self.tmpdir = tempfile.mkdtemp(prefix="osc_test_")
self.oscrc = os.path.join(self.tmpdir, "oscrc") self.oscrc = os.path.join(self.tmpdir, "oscrc")
with open(self.oscrc, "w", encoding="utf-8") as f: with open(self.oscrc, "w", encoding="utf-8") as f:
@ -481,6 +479,15 @@ class TestConf(unittest.TestCase):
} }
osc.conf.write_initial_config(conffile, entries) osc.conf.write_initial_config(conffile, entries)
def test_api_host_options(self):
# test that instances do not share any references leaked from the defaults
conf1 = osc.conf.Options()
conf2 = osc.conf.Options()
self.assertNotEqual(conf1, conf2)
self.assertNotEqual(id(conf1), id(conf2))
self.assertNotEqual(id(conf1.api_host_options), id(conf2.api_host_options))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,4 +1,3 @@
import importlib
import os import os
import tempfile import tempfile
import unittest import unittest
@ -13,14 +12,10 @@ FIXTURES_DIR = os.path.join(os.path.dirname(__file__), "conf_fixtures")
class TestMirrorGroup(unittest.TestCase): class TestMirrorGroup(unittest.TestCase):
def setUp(self): def setUp(self):
self.tmpdir = tempfile.mkdtemp(prefix='osc_test') self.tmpdir = tempfile.mkdtemp(prefix='osc_test')
# reset the global `config` in preparation for running the tests
importlib.reload(osc.conf)
oscrc = os.path.join(self._get_fixtures_dir(), "oscrc") oscrc = os.path.join(self._get_fixtures_dir(), "oscrc")
osc.conf.get_config(override_conffile=oscrc, override_no_keyring=True) osc.conf.get_config(override_conffile=oscrc, override_no_keyring=True)
def tearDown(self): def tearDown(self):
# reset the global `config` to avoid impacting tests from other classes
importlib.reload(osc.conf)
try: try:
shutil.rmtree(self.tmpdir) shutil.rmtree(self.tmpdir)
except: except:

View File

@ -79,11 +79,12 @@ class Test(unittest.TestCase):
m = TestModel() m = TestModel()
field = m.__fields__["field"] self.assertNotIn("field", m._values)
self.assertEqual(field.is_set, False)
self.assertEqual(m.field, None) self.assertEqual(m.field, None)
m.field = "text" m.field = "text"
self.assertEqual(field.is_set, True)
self.assertIn("field", m._values)
self.assertEqual(m.field, "text") self.assertEqual(m.field, "text")
def test_str(self): def test_str(self):
@ -95,7 +96,6 @@ class Test(unittest.TestCase):
field = m.__fields__["field"] field = m.__fields__["field"]
self.assertEqual(field.is_model, False) self.assertEqual(field.is_model, False)
self.assertEqual(field.is_optional, False) self.assertEqual(field.is_optional, False)
self.assertEqual(field.is_set, False)
self.assertEqual(field.origin_type, str) self.assertEqual(field.origin_type, str)
self.assertEqual(m.field, "default") self.assertEqual(m.field, "default")
@ -111,7 +111,6 @@ class Test(unittest.TestCase):
field = m.__fields__["field"] field = m.__fields__["field"]
self.assertEqual(field.is_model, False) self.assertEqual(field.is_model, False)
self.assertEqual(field.is_optional, True) self.assertEqual(field.is_optional, True)
self.assertEqual(field.is_set, False)
self.assertEqual(field.origin_type, str) self.assertEqual(field.origin_type, str)
self.assertEqual(m.field, None) self.assertEqual(m.field, None)

View File

@ -1,5 +1,4 @@
import contextlib import contextlib
import importlib
import io import io
import unittest import unittest
@ -74,12 +73,7 @@ Key : Value
class TestPrintMsg(unittest.TestCase): class TestPrintMsg(unittest.TestCase):
def setUp(self): def setUp(self):
# reset the global `config` in preparation for running the tests osc.conf.config = osc.conf.Options()
importlib.reload(osc.conf)
def tearDown(self):
# reset the global `config` to avoid impacting tests from other classes
importlib.reload(osc.conf)
def test_debug(self): def test_debug(self):
osc.conf.config["debug"] = False osc.conf.config["debug"] = False

View File

@ -1,4 +1,3 @@
import importlib
import os import os
import unittest import unittest
@ -11,7 +10,7 @@ from .common import patch
class TestVC(unittest.TestCase): class TestVC(unittest.TestCase):
def setUp(self): def setUp(self):
importlib.reload(osc.conf) osc.conf.config = osc.conf.Options()
config = osc.conf.config config = osc.conf.config
host_options = osc.conf.HostOptions( host_options = osc.conf.HostOptions(
@ -21,8 +20,6 @@ class TestVC(unittest.TestCase):
config["apiurl"] = host_options["apiurl"] config["apiurl"] = host_options["apiurl"]
self.host_options = host_options self.host_options = host_options
def tearDown(self):
importlib.reload(osc.conf)
@patch.dict(os.environ, {}, clear=True) @patch.dict(os.environ, {}, clear=True)
def test_vc_export_env_conf(self): def test_vc_export_env_conf(self):