mirror of
https://github.com/openSUSE/osc.git
synced 2024-12-25 17:36:13 +01:00
Store cached model defaults in self._defaults, avoid sharing references to mutable defaults
This commit is contained in:
parent
587c094f61
commit
16cdc067a5
@ -6,6 +6,7 @@ It works on python 3.6+.
|
|||||||
This module IS NOT a supported API, it is meant for osc internal use only.
|
This module IS NOT a supported API, it is meant for osc internal use only.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
@ -82,9 +83,6 @@ class Field(property):
|
|||||||
# a flag indicating, whether the default is a callable with lazy evalution
|
# a flag indicating, whether the default is a callable with lazy evalution
|
||||||
self.default_is_lazy = callable(self.default)
|
self.default_is_lazy = callable(self.default)
|
||||||
|
|
||||||
# whether the field was set
|
|
||||||
self.is_set = False
|
|
||||||
|
|
||||||
# the name of model's attribute associated with this field instance - set from the model
|
# the name of model's attribute associated with this field instance - set from the model
|
||||||
self.name = None
|
self.name = None
|
||||||
|
|
||||||
@ -209,6 +207,11 @@ class Field(property):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
return obj._defaults[self.name]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
if isinstance(self.default, FromParent):
|
if isinstance(self.default, FromParent):
|
||||||
if obj._parent is None:
|
if obj._parent is None:
|
||||||
raise RuntimeError(f"The field '{self.name}' has default {self.default} but the model has no parent set")
|
raise RuntimeError(f"The field '{self.name}' has default {self.default} but the model has no parent set")
|
||||||
@ -217,18 +220,23 @@ class Field(property):
|
|||||||
if self.default is NotSet:
|
if self.default is NotSet:
|
||||||
raise RuntimeError(f"The field '{self.name}' has no default")
|
raise RuntimeError(f"The field '{self.name}' has no default")
|
||||||
|
|
||||||
|
# make a deepcopy to avoid problems with mutable defaults
|
||||||
|
default = copy.deepcopy(self.default)
|
||||||
|
|
||||||
# lazy evaluation of a factory function on first use
|
# lazy evaluation of a factory function on first use
|
||||||
if callable(self.default):
|
if callable(default):
|
||||||
self.default = self.default()
|
default = default()
|
||||||
|
|
||||||
# if this is a model field, convert dict to a model instance
|
# if this is a model field, convert dict to a model instance
|
||||||
if self.is_model and isinstance(self.default, dict):
|
if self.is_model and isinstance(default, dict):
|
||||||
new_value = self.origin_type() # pylint: disable=not-callable
|
cls = self.origin_type
|
||||||
for k, v in self.default.items():
|
new_value = cls() # pylint: disable=not-callable
|
||||||
|
for k, v in default.items():
|
||||||
setattr(new_value, k, v)
|
setattr(new_value, k, v)
|
||||||
self.default = new_value
|
default = new_value
|
||||||
|
|
||||||
return self.default
|
obj._defaults[self.name] = default
|
||||||
|
return default
|
||||||
|
|
||||||
def set(self, obj, value):
|
def set(self, obj, value):
|
||||||
# if this is a model field, convert dict to a model instance
|
# if this is a model field, convert dict to a model instance
|
||||||
@ -240,7 +248,6 @@ class Field(property):
|
|||||||
|
|
||||||
self.validate_type(value)
|
self.validate_type(value)
|
||||||
obj._values[self.name] = value
|
obj._values[self.name] = value
|
||||||
self.is_set = True
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMeta(type):
|
class ModelMeta(type):
|
||||||
@ -288,7 +295,8 @@ class BaseModel(metaclass=ModelMeta):
|
|||||||
raise AttributeError(f"Setting attribute '{self.__class__.__name__}.{name}' is not allowed")
|
raise AttributeError(f"Setting attribute '{self.__class__.__name__}.{name}' is not allowed")
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self._values = {}
|
self._defaults = {} # field defaults cached in field.get()
|
||||||
|
self._values = {} # field values explicitly set after initializing the model
|
||||||
self._parent = kwargs.pop("_parent", None)
|
self._parent = kwargs.pop("_parent", None)
|
||||||
|
|
||||||
uninitialized_fields = []
|
uninitialized_fields = []
|
||||||
@ -321,7 +329,7 @@ class BaseModel(metaclass=ModelMeta):
|
|||||||
for name, field in self.__fields__.items():
|
for name, field in self.__fields__.items():
|
||||||
if field.exclude:
|
if field.exclude:
|
||||||
continue
|
continue
|
||||||
if exclude_unset and not field.is_set and field.is_optional:
|
if exclude_unset and field.name not in self._values and field.is_optional:
|
||||||
# include only mandatory fields and optional fields that were set to an actual value
|
# include only mandatory fields and optional fields that were set to an actual value
|
||||||
continue
|
continue
|
||||||
if field.is_model:
|
if field.is_model:
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import importlib
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import osc.conf
|
import osc.conf
|
||||||
@ -8,12 +7,7 @@ from osc.oscerr import UserAbort
|
|||||||
|
|
||||||
class TestTrustedProjects(unittest.TestCase):
|
class TestTrustedProjects(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# reset the global `config` in preparation for running the tests
|
osc.conf.config = osc.conf.Options()
|
||||||
importlib.reload(osc.conf)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
# reset the global `config` to avoid impacting tests from other classes
|
|
||||||
importlib.reload(osc.conf)
|
|
||||||
|
|
||||||
def test_name(self):
|
def test_name(self):
|
||||||
apiurl = "https://example.com"
|
apiurl = "https://example.com"
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import importlib
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -105,7 +104,6 @@ plugin-option = plugin-host-option
|
|||||||
|
|
||||||
class TestExampleConfig(unittest.TestCase):
|
class TestExampleConfig(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
importlib.reload(osc.conf)
|
|
||||||
self.tmpdir = tempfile.mkdtemp(prefix="osc_test_")
|
self.tmpdir = tempfile.mkdtemp(prefix="osc_test_")
|
||||||
self.oscrc = os.path.join(self.tmpdir, "oscrc")
|
self.oscrc = os.path.join(self.tmpdir, "oscrc")
|
||||||
with open(self.oscrc, "w", encoding="utf-8") as f:
|
with open(self.oscrc, "w", encoding="utf-8") as f:
|
||||||
@ -481,6 +479,15 @@ class TestConf(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
osc.conf.write_initial_config(conffile, entries)
|
osc.conf.write_initial_config(conffile, entries)
|
||||||
|
|
||||||
|
def test_api_host_options(self):
|
||||||
|
# test that instances do not share any references leaked from the defaults
|
||||||
|
conf1 = osc.conf.Options()
|
||||||
|
conf2 = osc.conf.Options()
|
||||||
|
|
||||||
|
self.assertNotEqual(conf1, conf2)
|
||||||
|
self.assertNotEqual(id(conf1), id(conf2))
|
||||||
|
self.assertNotEqual(id(conf1.api_host_options), id(conf2.api_host_options))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import importlib
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@ -13,14 +12,10 @@ FIXTURES_DIR = os.path.join(os.path.dirname(__file__), "conf_fixtures")
|
|||||||
class TestMirrorGroup(unittest.TestCase):
|
class TestMirrorGroup(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tmpdir = tempfile.mkdtemp(prefix='osc_test')
|
self.tmpdir = tempfile.mkdtemp(prefix='osc_test')
|
||||||
# reset the global `config` in preparation for running the tests
|
|
||||||
importlib.reload(osc.conf)
|
|
||||||
oscrc = os.path.join(self._get_fixtures_dir(), "oscrc")
|
oscrc = os.path.join(self._get_fixtures_dir(), "oscrc")
|
||||||
osc.conf.get_config(override_conffile=oscrc, override_no_keyring=True)
|
osc.conf.get_config(override_conffile=oscrc, override_no_keyring=True)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# reset the global `config` to avoid impacting tests from other classes
|
|
||||||
importlib.reload(osc.conf)
|
|
||||||
try:
|
try:
|
||||||
shutil.rmtree(self.tmpdir)
|
shutil.rmtree(self.tmpdir)
|
||||||
except:
|
except:
|
||||||
|
@ -79,11 +79,12 @@ class Test(unittest.TestCase):
|
|||||||
|
|
||||||
m = TestModel()
|
m = TestModel()
|
||||||
|
|
||||||
field = m.__fields__["field"]
|
self.assertNotIn("field", m._values)
|
||||||
self.assertEqual(field.is_set, False)
|
|
||||||
self.assertEqual(m.field, None)
|
self.assertEqual(m.field, None)
|
||||||
|
|
||||||
m.field = "text"
|
m.field = "text"
|
||||||
self.assertEqual(field.is_set, True)
|
|
||||||
|
self.assertIn("field", m._values)
|
||||||
self.assertEqual(m.field, "text")
|
self.assertEqual(m.field, "text")
|
||||||
|
|
||||||
def test_str(self):
|
def test_str(self):
|
||||||
@ -95,7 +96,6 @@ class Test(unittest.TestCase):
|
|||||||
field = m.__fields__["field"]
|
field = m.__fields__["field"]
|
||||||
self.assertEqual(field.is_model, False)
|
self.assertEqual(field.is_model, False)
|
||||||
self.assertEqual(field.is_optional, False)
|
self.assertEqual(field.is_optional, False)
|
||||||
self.assertEqual(field.is_set, False)
|
|
||||||
self.assertEqual(field.origin_type, str)
|
self.assertEqual(field.origin_type, str)
|
||||||
|
|
||||||
self.assertEqual(m.field, "default")
|
self.assertEqual(m.field, "default")
|
||||||
@ -111,7 +111,6 @@ class Test(unittest.TestCase):
|
|||||||
field = m.__fields__["field"]
|
field = m.__fields__["field"]
|
||||||
self.assertEqual(field.is_model, False)
|
self.assertEqual(field.is_model, False)
|
||||||
self.assertEqual(field.is_optional, True)
|
self.assertEqual(field.is_optional, True)
|
||||||
self.assertEqual(field.is_set, False)
|
|
||||||
self.assertEqual(field.origin_type, str)
|
self.assertEqual(field.origin_type, str)
|
||||||
|
|
||||||
self.assertEqual(m.field, None)
|
self.assertEqual(m.field, None)
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import importlib
|
|
||||||
import io
|
import io
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@ -74,12 +73,7 @@ Key : Value
|
|||||||
|
|
||||||
class TestPrintMsg(unittest.TestCase):
|
class TestPrintMsg(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# reset the global `config` in preparation for running the tests
|
osc.conf.config = osc.conf.Options()
|
||||||
importlib.reload(osc.conf)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
# reset the global `config` to avoid impacting tests from other classes
|
|
||||||
importlib.reload(osc.conf)
|
|
||||||
|
|
||||||
def test_debug(self):
|
def test_debug(self):
|
||||||
osc.conf.config["debug"] = False
|
osc.conf.config["debug"] = False
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import importlib
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@ -11,7 +10,7 @@ from .common import patch
|
|||||||
|
|
||||||
class TestVC(unittest.TestCase):
|
class TestVC(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
importlib.reload(osc.conf)
|
osc.conf.config = osc.conf.Options()
|
||||||
|
|
||||||
config = osc.conf.config
|
config = osc.conf.config
|
||||||
host_options = osc.conf.HostOptions(
|
host_options = osc.conf.HostOptions(
|
||||||
@ -21,8 +20,6 @@ class TestVC(unittest.TestCase):
|
|||||||
config["apiurl"] = host_options["apiurl"]
|
config["apiurl"] = host_options["apiurl"]
|
||||||
self.host_options = host_options
|
self.host_options = host_options
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
importlib.reload(osc.conf)
|
|
||||||
|
|
||||||
@patch.dict(os.environ, {}, clear=True)
|
@patch.dict(os.environ, {}, clear=True)
|
||||||
def test_vc_export_env_conf(self):
|
def test_vc_export_env_conf(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user