mirror of
https://github.com/openSUSE/osc.git
synced 2025-01-19 11:56:13 +01:00
Merge pull request #1462 from dmach/fix-model-references
Improve storing model values and cached defaults
This commit is contained in:
commit
130c1b4c14
@ -126,7 +126,9 @@ HttpHeader = NewType("HttpHeader", Tuple[str, str])
|
||||
class OscOptions(BaseModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.extra_fields = {}
|
||||
self._allow_new_attributes = True
|
||||
self._extra_fields = {}
|
||||
self._allow_new_attributes = False
|
||||
|
||||
# compat function with the config dict
|
||||
def _get_field_name(self, name):
|
||||
@ -145,7 +147,7 @@ class OscOptions(BaseModel):
|
||||
field_name = self._get_field_name(name)
|
||||
|
||||
if field_name is None and not hasattr(self, name):
|
||||
return self.extra_fields[name]
|
||||
return self._extra_fields[name]
|
||||
|
||||
field_name = field_name or name
|
||||
try:
|
||||
@ -158,7 +160,7 @@ class OscOptions(BaseModel):
|
||||
field_name = self._get_field_name(name)
|
||||
|
||||
if field_name is None and not hasattr(self, name):
|
||||
self.extra_fields[name] = value
|
||||
self._extra_fields[name] = value
|
||||
return
|
||||
|
||||
field_name = field_name or name
|
||||
|
@ -6,6 +6,7 @@ It works on python 3.6+.
|
||||
This module IS NOT a supported API, it is meant for osc internal use only.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
@ -82,9 +83,6 @@ class Field(property):
|
||||
# a flag indicating, whether the default is a callable with lazy evalution
|
||||
self.default_is_lazy = callable(self.default)
|
||||
|
||||
# whether the field was set
|
||||
self.is_set = False
|
||||
|
||||
# the name of model's attribute associated with this field instance - set from the model
|
||||
self.name = None
|
||||
|
||||
@ -205,8 +203,13 @@ class Field(property):
|
||||
|
||||
def get(self, obj):
|
||||
try:
|
||||
return getattr(obj, f"_{self.name}")
|
||||
except AttributeError:
|
||||
return obj._values[self.name]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return obj._defaults[self.name]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if isinstance(self.default, FromParent):
|
||||
@ -217,18 +220,23 @@ class Field(property):
|
||||
if self.default is NotSet:
|
||||
raise RuntimeError(f"The field '{self.name}' has no default")
|
||||
|
||||
# make a deepcopy to avoid problems with mutable defaults
|
||||
default = copy.deepcopy(self.default)
|
||||
|
||||
# lazy evaluation of a factory function on first use
|
||||
if callable(self.default):
|
||||
self.default = self.default()
|
||||
if callable(default):
|
||||
default = default()
|
||||
|
||||
# if this is a model field, convert dict to a model instance
|
||||
if self.is_model and isinstance(self.default, dict):
|
||||
new_value = self.origin_type() # pylint: disable=not-callable
|
||||
for k, v in self.default.items():
|
||||
if self.is_model and isinstance(default, dict):
|
||||
cls = self.origin_type
|
||||
new_value = cls() # pylint: disable=not-callable
|
||||
for k, v in default.items():
|
||||
setattr(new_value, k, v)
|
||||
self.default = new_value
|
||||
default = new_value
|
||||
|
||||
return self.default
|
||||
obj._defaults[self.name] = default
|
||||
return default
|
||||
|
||||
def set(self, obj, value):
|
||||
# if this is a model field, convert dict to a model instance
|
||||
@ -239,8 +247,7 @@ class Field(property):
|
||||
value = new_value
|
||||
|
||||
self.validate_type(value)
|
||||
setattr(obj, f"_{self.name}", value)
|
||||
self.is_set = True
|
||||
obj._values[self.name] = value
|
||||
|
||||
|
||||
class ModelMeta(type):
|
||||
@ -280,7 +287,16 @@ class ModelMeta(type):
|
||||
class BaseModel(metaclass=ModelMeta):
|
||||
__fields__: Dict[str, Field]
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if getattr(self, "_allow_new_attributes", True) or hasattr(self.__class__, name) or hasattr(self, name):
|
||||
# allow setting properties - test if they exist in the class
|
||||
# also allow setting existing attributes that were previously initialized via __dict__
|
||||
return super().__setattr__(name, value)
|
||||
raise AttributeError(f"Setting attribute '{self.__class__.__name__}.{name}' is not allowed")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._defaults = {} # field defaults cached in field.get()
|
||||
self._values = {} # field values explicitly set after initializing the model
|
||||
self._parent = kwargs.pop("_parent", None)
|
||||
|
||||
uninitialized_fields = []
|
||||
@ -306,16 +322,15 @@ class BaseModel(metaclass=ModelMeta):
|
||||
for name, field in self.__fields__.items():
|
||||
field.validate_type(getattr(self, name))
|
||||
|
||||
def dict(self, exclude_unset=False):
|
||||
self._allow_new_attributes = False
|
||||
|
||||
def dict(self):
|
||||
result = {}
|
||||
for name, field in self.__fields__.items():
|
||||
if field.exclude:
|
||||
continue
|
||||
if exclude_unset and not field.is_set and field.is_optional:
|
||||
# include only mandatory fields and optional fields that were set to an actual value
|
||||
continue
|
||||
if field.is_model:
|
||||
result[name] = getattr(self, name).dict(exclude_unset=exclude_unset)
|
||||
result[name] = getattr(self, name).dict()
|
||||
else:
|
||||
result[name] = getattr(self, name)
|
||||
return result
|
||||
|
@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
import unittest
|
||||
|
||||
import osc.conf
|
||||
@ -8,12 +7,7 @@ from osc.oscerr import UserAbort
|
||||
|
||||
class TestTrustedProjects(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# reset the global `config` in preparation for running the tests
|
||||
importlib.reload(osc.conf)
|
||||
|
||||
def tearDown(self):
|
||||
# reset the global `config` to avoid impacting tests from other classes
|
||||
importlib.reload(osc.conf)
|
||||
osc.conf.config = osc.conf.Options()
|
||||
|
||||
def test_name(self):
|
||||
apiurl = "https://example.com"
|
||||
|
@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
@ -105,7 +104,6 @@ plugin-option = plugin-host-option
|
||||
|
||||
class TestExampleConfig(unittest.TestCase):
|
||||
def setUp(self):
|
||||
importlib.reload(osc.conf)
|
||||
self.tmpdir = tempfile.mkdtemp(prefix="osc_test_")
|
||||
self.oscrc = os.path.join(self.tmpdir, "oscrc")
|
||||
with open(self.oscrc, "w", encoding="utf-8") as f:
|
||||
@ -116,6 +114,9 @@ class TestExampleConfig(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdir)
|
||||
|
||||
def test_invalid_attribute(self):
|
||||
self.assertRaises(AttributeError, setattr, self.config, "new_attribute", "123")
|
||||
|
||||
def test_apiurl(self):
|
||||
self.assertEqual(self.config["apiurl"], "https://api.opensuse.org")
|
||||
|
||||
@ -407,26 +408,19 @@ class TestExampleConfig(unittest.TestCase):
|
||||
|
||||
def test_extra_fields(self):
|
||||
self.assertEqual(self.config["plugin-option"], "plugin-general-option")
|
||||
self.assertEqual(self.config.extra_fields, {"plugin-option": "plugin-general-option"})
|
||||
|
||||
# write to an existing attribute instead of extra_fields
|
||||
self.config.attrib = 123
|
||||
self.assertEqual(self.config["attrib"], 123)
|
||||
self.config["attrib"] = 456
|
||||
self.assertEqual(self.config["attrib"], 456)
|
||||
self.assertEqual(self.config.extra_fields, {"plugin-option": "plugin-general-option"})
|
||||
self.assertEqual(self.config._extra_fields, {"plugin-option": "plugin-general-option"})
|
||||
|
||||
self.config["new-option"] = "value"
|
||||
self.assertEqual(self.config["new-option"], "value")
|
||||
self.assertEqual(self.config.extra_fields, {"plugin-option": "plugin-general-option", "new-option": "value"})
|
||||
self.assertEqual(self.config._extra_fields, {"plugin-option": "plugin-general-option", "new-option": "value"})
|
||||
|
||||
host_options = self.config["api_host_options"][self.config["apiurl"]]
|
||||
self.assertEqual(host_options["plugin-option"], "plugin-host-option")
|
||||
self.assertEqual(host_options.extra_fields, {"plugin-option": "plugin-host-option"})
|
||||
self.assertEqual(host_options._extra_fields, {"plugin-option": "plugin-host-option"})
|
||||
|
||||
host_options["new-option"] = "value"
|
||||
self.assertEqual(host_options["new-option"], "value")
|
||||
self.assertEqual(host_options.extra_fields, {"plugin-option": "plugin-host-option", "new-option": "value"})
|
||||
self.assertEqual(host_options._extra_fields, {"plugin-option": "plugin-host-option", "new-option": "value"})
|
||||
|
||||
def test_apiurl_aliases(self):
|
||||
expected = {"https://api.opensuse.org": "https://api.opensuse.org", "osc": "https://api.opensuse.org"}
|
||||
@ -485,6 +479,15 @@ class TestConf(unittest.TestCase):
|
||||
}
|
||||
osc.conf.write_initial_config(conffile, entries)
|
||||
|
||||
def test_api_host_options(self):
|
||||
# test that instances do not share any references leaked from the defaults
|
||||
conf1 = osc.conf.Options()
|
||||
conf2 = osc.conf.Options()
|
||||
|
||||
self.assertNotEqual(conf1, conf2)
|
||||
self.assertNotEqual(id(conf1), id(conf2))
|
||||
self.assertNotEqual(id(conf1.api_host_options), id(conf2.api_host_options))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -13,14 +12,10 @@ FIXTURES_DIR = os.path.join(os.path.dirname(__file__), "conf_fixtures")
|
||||
class TestMirrorGroup(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp(prefix='osc_test')
|
||||
# reset the global `config` in preparation for running the tests
|
||||
importlib.reload(osc.conf)
|
||||
oscrc = os.path.join(self._get_fixtures_dir(), "oscrc")
|
||||
osc.conf.get_config(override_conffile=oscrc, override_no_keyring=True)
|
||||
|
||||
def tearDown(self):
|
||||
# reset the global `config` to avoid impacting tests from other classes
|
||||
importlib.reload(osc.conf)
|
||||
try:
|
||||
shutil.rmtree(self.tmpdir)
|
||||
except:
|
||||
|
@ -24,16 +24,21 @@ class TestNotSet(unittest.TestCase):
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
def test_modified(self):
|
||||
def test_dict(self):
|
||||
class TestSubmodel(BaseModel):
|
||||
text: str = Field(default="default")
|
||||
|
||||
class TestModel(BaseModel):
|
||||
a: str = Field(default="default")
|
||||
b: Optional[str] = Field(default=None)
|
||||
sub: Optional[List[TestSubmodel]] = Field(default=None)
|
||||
|
||||
m = TestModel()
|
||||
self.assertEqual(m.dict(exclude_unset=True), {"a": "default"})
|
||||
self.assertEqual(m.dict(), {"a": "default", "b": None, "sub": None})
|
||||
|
||||
m = TestModel(b=None)
|
||||
self.assertEqual(m.dict(exclude_unset=True), {"a": "default", "b": None})
|
||||
m.b = "B"
|
||||
m.sub = [{"text": "one"}, {"text": "two"}]
|
||||
self.assertEqual(m.dict(), {"a": "default", "b": "B", "sub": [{"text": "one"}, {"text": "two"}]})
|
||||
|
||||
def test_unknown_fields(self):
|
||||
class TestModel(BaseModel):
|
||||
@ -79,11 +84,12 @@ class Test(unittest.TestCase):
|
||||
|
||||
m = TestModel()
|
||||
|
||||
field = m.__fields__["field"]
|
||||
self.assertEqual(field.is_set, False)
|
||||
self.assertNotIn("field", m._values)
|
||||
self.assertEqual(m.field, None)
|
||||
|
||||
m.field = "text"
|
||||
self.assertEqual(field.is_set, True)
|
||||
|
||||
self.assertIn("field", m._values)
|
||||
self.assertEqual(m.field, "text")
|
||||
|
||||
def test_str(self):
|
||||
@ -95,7 +101,6 @@ class Test(unittest.TestCase):
|
||||
field = m.__fields__["field"]
|
||||
self.assertEqual(field.is_model, False)
|
||||
self.assertEqual(field.is_optional, False)
|
||||
self.assertEqual(field.is_set, False)
|
||||
self.assertEqual(field.origin_type, str)
|
||||
|
||||
self.assertEqual(m.field, "default")
|
||||
@ -111,7 +116,6 @@ class Test(unittest.TestCase):
|
||||
field = m.__fields__["field"]
|
||||
self.assertEqual(field.is_model, False)
|
||||
self.assertEqual(field.is_optional, True)
|
||||
self.assertEqual(field.is_set, False)
|
||||
self.assertEqual(field.origin_type, str)
|
||||
|
||||
self.assertEqual(m.field, None)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import contextlib
|
||||
import importlib
|
||||
import io
|
||||
import unittest
|
||||
|
||||
@ -74,12 +73,7 @@ Key : Value
|
||||
|
||||
class TestPrintMsg(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# reset the global `config` in preparation for running the tests
|
||||
importlib.reload(osc.conf)
|
||||
|
||||
def tearDown(self):
|
||||
# reset the global `config` to avoid impacting tests from other classes
|
||||
importlib.reload(osc.conf)
|
||||
osc.conf.config = osc.conf.Options()
|
||||
|
||||
def test_debug(self):
|
||||
osc.conf.config["debug"] = False
|
||||
|
@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
import os
|
||||
import unittest
|
||||
|
||||
@ -11,7 +10,7 @@ from .common import patch
|
||||
|
||||
class TestVC(unittest.TestCase):
|
||||
def setUp(self):
|
||||
importlib.reload(osc.conf)
|
||||
osc.conf.config = osc.conf.Options()
|
||||
|
||||
config = osc.conf.config
|
||||
host_options = osc.conf.HostOptions(
|
||||
@ -21,8 +20,6 @@ class TestVC(unittest.TestCase):
|
||||
config["apiurl"] = host_options["apiurl"]
|
||||
self.host_options = host_options
|
||||
|
||||
def tearDown(self):
|
||||
importlib.reload(osc.conf)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_vc_export_env_conf(self):
|
||||
|
Loading…
Reference in New Issue
Block a user