diff --git a/osc/util/models.py b/osc/util/models.py index c2ab36c5..d0493da8 100644 --- a/osc/util/models.py +++ b/osc/util/models.py @@ -71,8 +71,9 @@ NotSet = NotSetClass() class FromParent(NotSetClass): - def __init__(self, field_name): + def __init__(self, field_name, *, fallback=NotSet): self.field_name = field_name + self.fallback = fallback def __repr__(self): return f"FromParent(field_name={self.field_name})" @@ -256,7 +257,7 @@ class Field(property): for num, i in enumerate(result): if isinstance(i, dict): klass = self.inner_type - result[num] = klass(**i) + result[num] = klass(**i, _parent=obj) if self.get_callback is not None: result = self.get_callback(obj, result) @@ -279,7 +280,10 @@ class Field(property): if isinstance(self.default, FromParent): if obj._parent is None: - raise RuntimeError(f"The field '{self.name}' has default {self.default} but the model has no parent set") + if self.default.fallback is not NotSet: + return self.default.fallback + else: + raise RuntimeError(f"The field '{self.name}' has default {self.default} but the model has no parent set") return getattr(obj._parent, self.default.field_name or self.name) if self.default is NotSet: @@ -308,20 +312,23 @@ class Field(property): if self.is_model and isinstance(value, dict): # initialize a model instance from a dictionary klass = self.origin_type - value = klass(**value) # pylint: disable=not-callable + value = klass(**value, _parent=obj) # pylint: disable=not-callable elif self.is_model_list and isinstance(value, list): new_value = [] for i in value: if isinstance(i, dict): klass = self.inner_type - new_value.append(klass(**i)) + new_value.append(klass(**i, _parent=obj)) else: + i._parent = obj new_value.append(i) value = new_value elif self.is_model and isinstance(value, str) and hasattr(self.origin_type, "XML_TAG_FIELD"): klass = self.origin_type key = getattr(self.origin_type, "XML_TAG_FIELD") - value = klass(**{key: value}) + value = klass(**{key: value}, _parent=obj) + elif self.is_model and value is not None: + value._parent = obj self.validate_type(value) obj._values[self.name] = value diff --git a/tests/test_models.py b/tests/test_models.py index 84ae542b..40cf2ff7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -291,6 +291,49 @@ class Test(unittest.TestCase): self.assertEqual(c.field, "new-text") self.assertEqual(c.field2, "text") + def test_parent_fallback(self): + class SubModel(BaseModel): + field: str = Field(default=FromParent("field", fallback="submodel-fallback")) + + class Model(BaseModel): + field: str = Field(default=FromParent("field", fallback="model-fallback")) + sub: Optional[SubModel] = Field() + sub_list: Optional[List[SubModel]] = Field() + + m = Model() + s = SubModel(_parent=m) + m.sub = s + self.assertEqual(m.field, "model-fallback") + self.assertEqual(m.sub.field, "model-fallback") + + m = Model(sub={}) + self.assertEqual(m.field, "model-fallback") + self.assertEqual(m.sub.field, "model-fallback") + + m = Model(sub=SubModel()) + self.assertEqual(m.field, "model-fallback") + self.assertEqual(m.sub.field, "model-fallback") + + m = Model() + s = SubModel(_parent=m) + m.sub_list = [s] + self.assertEqual(m.field, "model-fallback") + self.assertEqual(m.sub_list[0].field, "model-fallback") + + m = Model(sub_list=[{}]) + self.assertEqual(m.field, "model-fallback") + self.assertEqual(m.sub_list[0].field, "model-fallback") + + m = Model(sub_list=[SubModel()]) + self.assertEqual(m.field, "model-fallback") + self.assertEqual(m.sub_list[0].field, "model-fallback") + + m = Model() + m.sub_list = [] + m.sub_list.append({}) + self.assertEqual(m.field, "model-fallback") + self.assertEqual(m.sub_list[0].field, "model-fallback") + def test_get_callback(self): class Model(BaseModel): quiet: bool = Field(