From 8a38a9da8202fe4ca9e0dbb8f2e9bff6e64d4a8c Mon Sep 17 00:00:00 2001 From: Daniel Mach Date: Wed, 24 Jan 2024 09:49:23 +0100 Subject: [PATCH] Implement get_callback that allows modifying returned value to the Field class --- osc/util/models.py | 16 ++++++++++++++-- tests/test_models.py | 25 +++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/osc/util/models.py b/osc/util/models.py index 46f2f422..d0f8dfb0 100644 --- a/osc/util/models.py +++ b/osc/util/models.py @@ -10,6 +10,7 @@ import copy import inspect import sys import types +from typing import Callable from typing import get_type_hints # supported types @@ -76,6 +77,7 @@ class Field(property): default: Any = NotSet, description: Optional[str] = None, exclude: bool = False, + get_callback: Optional[Callable] = None, **extra, ): # the default value; it can be a factory function that is lazily evaluated on the first use @@ -106,6 +108,10 @@ class Field(property): # whether to exclude this field from export self.exclude = exclude + # optional callback to postprocess returned field value + # it takes (model_instance, value) and returns modified value + self.get_callback = get_callback + # extra fields self.extra = extra @@ -235,12 +241,18 @@ class Field(property): def get(self, obj): try: - return obj._values[self.name] + result = obj._values[self.name] + if self.get_callback is not None: + result = self.get_callback(obj, result) + return result except KeyError: pass try: - return obj._defaults[self.name] + result = obj._defaults[self.name] + if self.get_callback is not None: + result = self.get_callback(obj, result) + return result except KeyError: pass diff --git a/tests/test_models.py b/tests/test_models.py index add9bd55..df6e5125 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -291,6 +291,31 @@ class Test(unittest.TestCase): self.assertEqual(c.field, "new-text") self.assertEqual(c.field2, "text") + def test_get_callback(self): + class Model(BaseModel): + quiet: bool = Field( + default=False, + ) + verbose: bool = Field( + default=False, + # return False if ``quiet`` is True; return the actual value otherwise + get_callback=lambda obj, value: False if obj.quiet else value, + ) + + m = Model() + self.assertEqual(m.quiet, False) + self.assertEqual(m.verbose, False) + + m.quiet = True + m.verbose = True + self.assertEqual(m.quiet, True) + self.assertEqual(m.verbose, False) + + m.quiet = False + m.verbose = True + self.assertEqual(m.quiet, False) + self.assertEqual(m.verbose, True) + if __name__ == "__main__": unittest.main()