diff --git a/osc/conf.py b/osc/conf.py index 07dfc0b9..779799e6 100644 --- a/osc/conf.py +++ b/osc/conf.py @@ -126,7 +126,9 @@ HttpHeader = NewType("HttpHeader", Tuple[str, str]) class OscOptions(BaseModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.extra_fields = {} + self._allow_new_attributes = True + self._extra_fields = {} + self._allow_new_attributes = False # compat function with the config dict def _get_field_name(self, name): @@ -145,7 +147,7 @@ class OscOptions(BaseModel): field_name = self._get_field_name(name) if field_name is None and not hasattr(self, name): - return self.extra_fields[name] + return self._extra_fields[name] field_name = field_name or name try: @@ -158,7 +160,7 @@ class OscOptions(BaseModel): field_name = self._get_field_name(name) if field_name is None and not hasattr(self, name): - self.extra_fields[name] = value + self._extra_fields[name] = value return field_name = field_name or name diff --git a/osc/util/models.py b/osc/util/models.py index f0926000..125557ba 100644 --- a/osc/util/models.py +++ b/osc/util/models.py @@ -6,6 +6,7 @@ It works on python 3.6+. This module IS NOT a supported API, it is meant for osc internal use only. """ +import copy import inspect import sys import types @@ -82,9 +83,6 @@ class Field(property): # a flag indicating, whether the default is a callable with lazy evalution self.default_is_lazy = callable(self.default) - # whether the field was set - self.is_set = False - # the name of model's attribute associated with this field instance - set from the model self.name = None @@ -205,8 +203,13 @@ class Field(property): def get(self, obj): try: - return getattr(obj, f"_{self.name}") - except AttributeError: + return obj._values[self.name] + except KeyError: + pass + + try: + return obj._defaults[self.name] + except KeyError: pass if isinstance(self.default, FromParent): @@ -217,18 +220,23 @@ class Field(property): if self.default is NotSet: raise RuntimeError(f"The field '{self.name}' has no default") + # make a deepcopy to avoid problems with mutable defaults + default = copy.deepcopy(self.default) + # lazy evaluation of a factory function on first use - if callable(self.default): - self.default = self.default() + if callable(default): + default = default() # if this is a model field, convert dict to a model instance - if self.is_model and isinstance(self.default, dict): - new_value = self.origin_type() # pylint: disable=not-callable - for k, v in self.default.items(): + if self.is_model and isinstance(default, dict): + cls = self.origin_type + new_value = cls() # pylint: disable=not-callable + for k, v in default.items(): setattr(new_value, k, v) - self.default = new_value + default = new_value - return self.default + obj._defaults[self.name] = default + return default def set(self, obj, value): # if this is a model field, convert dict to a model instance @@ -239,8 +247,7 @@ class Field(property): value = new_value self.validate_type(value) - setattr(obj, f"_{self.name}", value) - self.is_set = True + obj._values[self.name] = value class ModelMeta(type): @@ -280,7 +287,16 @@ class ModelMeta(type): class BaseModel(metaclass=ModelMeta): __fields__: Dict[str, Field] + def __setattr__(self, name, value): + if getattr(self, "_allow_new_attributes", True) or hasattr(self.__class__, name) or hasattr(self, name): + # allow setting properties - test if they exist in the class + # also allow setting existing attributes that were previously initialized via __dict__ + return super().__setattr__(name, value) + raise AttributeError(f"Setting attribute '{self.__class__.__name__}.{name}' is not allowed") + def __init__(self, **kwargs): + self._defaults = {} # field defaults cached in field.get() + self._values = {} # field values explicitly set after initializing the model self._parent = kwargs.pop("_parent", None) uninitialized_fields = [] @@ -306,16 +322,15 @@ class BaseModel(metaclass=ModelMeta): for name, field in self.__fields__.items(): field.validate_type(getattr(self, name)) - def dict(self, exclude_unset=False): + self._allow_new_attributes = False + + def dict(self): result = {} for name, field in self.__fields__.items(): if field.exclude: continue - if exclude_unset and not field.is_set and field.is_optional: - # include only mandatory fields and optional fields that were set to an actual value - continue if field.is_model: - result[name] = getattr(self, name).dict(exclude_unset=exclude_unset) + result[name] = getattr(self, name).dict() else: result[name] = getattr(self, name) return result diff --git a/tests/test_build.py b/tests/test_build.py index 18e8c227..99dd5044 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -1,4 +1,3 @@ -import importlib import unittest import osc.conf @@ -8,12 +7,7 @@ from osc.oscerr import UserAbort class TestTrustedProjects(unittest.TestCase): def setUp(self): - # reset the global `config` in preparation for running the tests - importlib.reload(osc.conf) - - def tearDown(self): - # reset the global `config` to avoid impacting tests from other classes - importlib.reload(osc.conf) + osc.conf.config = osc.conf.Options() def test_name(self): apiurl = "https://example.com" diff --git a/tests/test_conf.py b/tests/test_conf.py index 5d6a41d2..bb7cba38 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -1,4 +1,3 @@ -import importlib import os import shutil import tempfile @@ -105,7 +104,6 @@ plugin-option = plugin-host-option class TestExampleConfig(unittest.TestCase): def setUp(self): - importlib.reload(osc.conf) self.tmpdir = tempfile.mkdtemp(prefix="osc_test_") self.oscrc = os.path.join(self.tmpdir, "oscrc") with open(self.oscrc, "w", encoding="utf-8") as f: @@ -116,6 +114,9 @@ class TestExampleConfig(unittest.TestCase): def tearDown(self): shutil.rmtree(self.tmpdir) + def test_invalid_attribute(self): + self.assertRaises(AttributeError, setattr, self.config, "new_attribute", "123") + def test_apiurl(self): self.assertEqual(self.config["apiurl"], "https://api.opensuse.org") @@ -407,26 +408,19 @@ class TestExampleConfig(unittest.TestCase): def test_extra_fields(self): self.assertEqual(self.config["plugin-option"], "plugin-general-option") - self.assertEqual(self.config.extra_fields, {"plugin-option": "plugin-general-option"}) - - # write to an existing attribute instead of extra_fields - self.config.attrib = 123 - self.assertEqual(self.config["attrib"], 123) - self.config["attrib"] = 456 - self.assertEqual(self.config["attrib"], 456) - self.assertEqual(self.config.extra_fields, {"plugin-option": "plugin-general-option"}) + self.assertEqual(self.config._extra_fields, {"plugin-option": "plugin-general-option"}) self.config["new-option"] = "value" self.assertEqual(self.config["new-option"], "value") - self.assertEqual(self.config.extra_fields, {"plugin-option": "plugin-general-option", "new-option": "value"}) + self.assertEqual(self.config._extra_fields, {"plugin-option": "plugin-general-option", "new-option": "value"}) host_options = self.config["api_host_options"][self.config["apiurl"]] self.assertEqual(host_options["plugin-option"], "plugin-host-option") - self.assertEqual(host_options.extra_fields, {"plugin-option": "plugin-host-option"}) + self.assertEqual(host_options._extra_fields, {"plugin-option": "plugin-host-option"}) host_options["new-option"] = "value" self.assertEqual(host_options["new-option"], "value") - self.assertEqual(host_options.extra_fields, {"plugin-option": "plugin-host-option", "new-option": "value"}) + self.assertEqual(host_options._extra_fields, {"plugin-option": "plugin-host-option", "new-option": "value"}) def test_apiurl_aliases(self): expected = {"https://api.opensuse.org": "https://api.opensuse.org", "osc": "https://api.opensuse.org"} @@ -485,6 +479,15 @@ class TestConf(unittest.TestCase): } osc.conf.write_initial_config(conffile, entries) + def test_api_host_options(self): + # test that instances do not share any references leaked from the defaults + conf1 = osc.conf.Options() + conf2 = osc.conf.Options() + + self.assertNotEqual(conf1, conf2) + self.assertNotEqual(id(conf1), id(conf2)) + self.assertNotEqual(id(conf1.api_host_options), id(conf2.api_host_options)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_grabber.py b/tests/test_grabber.py index f88cd872..25493b83 100644 --- a/tests/test_grabber.py +++ b/tests/test_grabber.py @@ -1,4 +1,3 @@ -import importlib import os import tempfile import unittest @@ -13,14 +12,10 @@ FIXTURES_DIR = os.path.join(os.path.dirname(__file__), "conf_fixtures") class TestMirrorGroup(unittest.TestCase): def setUp(self): self.tmpdir = tempfile.mkdtemp(prefix='osc_test') - # reset the global `config` in preparation for running the tests - importlib.reload(osc.conf) oscrc = os.path.join(self._get_fixtures_dir(), "oscrc") osc.conf.get_config(override_conffile=oscrc, override_no_keyring=True) def tearDown(self): - # reset the global `config` to avoid impacting tests from other classes - importlib.reload(osc.conf) try: shutil.rmtree(self.tmpdir) except: diff --git a/tests/test_models.py b/tests/test_models.py index e3d6895e..ccde8a4b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,16 +24,21 @@ class TestNotSet(unittest.TestCase): class Test(unittest.TestCase): - def test_modified(self): + def test_dict(self): + class TestSubmodel(BaseModel): + text: str = Field(default="default") + class TestModel(BaseModel): a: str = Field(default="default") b: Optional[str] = Field(default=None) + sub: Optional[List[TestSubmodel]] = Field(default=None) m = TestModel() - self.assertEqual(m.dict(exclude_unset=True), {"a": "default"}) + self.assertEqual(m.dict(), {"a": "default", "b": None, "sub": None}) - m = TestModel(b=None) - self.assertEqual(m.dict(exclude_unset=True), {"a": "default", "b": None}) + m.b = "B" + m.sub = [{"text": "one"}, {"text": "two"}] + self.assertEqual(m.dict(), {"a": "default", "b": "B", "sub": [{"text": "one"}, {"text": "two"}]}) def test_unknown_fields(self): class TestModel(BaseModel): @@ -79,11 +84,12 @@ class Test(unittest.TestCase): m = TestModel() - field = m.__fields__["field"] - self.assertEqual(field.is_set, False) + self.assertNotIn("field", m._values) self.assertEqual(m.field, None) + m.field = "text" - self.assertEqual(field.is_set, True) + + self.assertIn("field", m._values) self.assertEqual(m.field, "text") def test_str(self): @@ -95,7 +101,6 @@ class Test(unittest.TestCase): field = m.__fields__["field"] self.assertEqual(field.is_model, False) self.assertEqual(field.is_optional, False) - self.assertEqual(field.is_set, False) self.assertEqual(field.origin_type, str) self.assertEqual(m.field, "default") @@ -111,7 +116,6 @@ class Test(unittest.TestCase): field = m.__fields__["field"] self.assertEqual(field.is_model, False) self.assertEqual(field.is_optional, True) - self.assertEqual(field.is_set, False) self.assertEqual(field.origin_type, str) self.assertEqual(m.field, None) diff --git a/tests/test_output.py b/tests/test_output.py index 8c02adbc..1b7337ec 100644 --- a/tests/test_output.py +++ b/tests/test_output.py @@ -1,5 +1,4 @@ import contextlib -import importlib import io import unittest @@ -74,12 +73,7 @@ Key : Value class TestPrintMsg(unittest.TestCase): def setUp(self): - # reset the global `config` in preparation for running the tests - importlib.reload(osc.conf) - - def tearDown(self): - # reset the global `config` to avoid impacting tests from other classes - importlib.reload(osc.conf) + osc.conf.config = osc.conf.Options() def test_debug(self): osc.conf.config["debug"] = False diff --git a/tests/test_vc.py b/tests/test_vc.py index 491a677a..9da14b55 100644 --- a/tests/test_vc.py +++ b/tests/test_vc.py @@ -1,4 +1,3 @@ -import importlib import os import unittest @@ -11,7 +10,7 @@ from .common import patch class TestVC(unittest.TestCase): def setUp(self): - importlib.reload(osc.conf) + osc.conf.config = osc.conf.Options() config = osc.conf.config host_options = osc.conf.HostOptions( @@ -21,8 +20,6 @@ class TestVC(unittest.TestCase): config["apiurl"] = host_options["apiurl"] self.host_options = host_options - def tearDown(self): - importlib.reload(osc.conf) @patch.dict(os.environ, {}, clear=True) def test_vc_export_env_conf(self):