From 570744d06c5ba9dba59b4c3f432ca4f0abd396b6 Mon Sep 17 00:00:00 2001 From: Damien George Date: Mon, 12 Jan 2026 11:04:09 +1100 Subject: [PATCH] py/runtime: Make import-all support non-modules via __dict__/__all__. Prior to this fix, `mp_import_all()` assumed that its argument was exactly a native module instance. That would lead to a crash if something else was passed in, e.g. a user class injected via a custom `__import__` implementation or by writing to `sys.modules`. MicroPython already supports injecting non-module objects into the import machinery, so it makes sense to round out that implementation by supporting `from x import *` where `x` is a non-module object. Fixes issue #18639. Signed-off-by: Damien George --- py/runtime.c | 21 ++++++--- tests/basics/import_star_nonmodule.py | 65 +++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 5 deletions(-) create mode 100644 tests/basics/import_star_nonmodule.py diff --git a/py/runtime.c b/py/runtime.c index 8d7e11e5f0dab..3fc35c8c2daf6 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -1592,19 +1592,19 @@ mp_obj_t mp_import_from(mp_obj_t module, qstr name) { void mp_import_all(mp_obj_t module) { DEBUG_printf("import all %p\n", module); - mp_map_t *map = &mp_obj_module_get_globals(module)->map; + mp_obj_t dest[2]; #if MICROPY_MODULE___ALL__ - mp_map_elem_t *elem = mp_map_lookup(map, MP_OBJ_NEW_QSTR(MP_QSTR___all__), MP_MAP_LOOKUP); - if (elem != NULL) { + + mp_load_method_maybe(module, MP_QSTR___all__, dest); + if (dest[0] != MP_OBJ_NULL) { // When __all__ is defined, we must explicitly load all specified // symbols, possibly invoking the module __getattr__ function size_t len; mp_obj_t *items; - mp_obj_get_array(elem->value, &len, &items); + mp_obj_get_array(dest[0], &len, &items); for (size_t i = 0; i < len; i++) { qstr qname = mp_obj_str_get_qstr(items[i]); - mp_obj_t dest[2]; mp_load_method(module, qname, dest); mp_store_name(qname, dest[0]); } @@ -1612,8 +1612,19 @@ void mp_import_all(mp_obj_t module) { } 
#endif + #if MICROPY_CPYTHON_COMPAT + // Load the dict from the module. In MicroPython, if __dict__ is + // available then it always returns a native mp_obj_dict_t instance. + mp_load_method(module, MP_QSTR___dict__, dest); + #else + // Without MICROPY_CPYTHON_COMPAT __dict__ is not available, so just + // assume the given module is actually an mp_obj_module_t instance. + dest[0] = MP_OBJ_FROM_PTR(mp_obj_module_get_globals(module)); + #endif + // By default, the set of public names includes all names found in the module's // namespace which do not begin with an underscore character ('_') + mp_map_t *map = mp_obj_dict_get_map(dest[0]); for (size_t i = 0; i < map->alloc; i++) { if (mp_map_slot_is_filled(map, i)) { // Entry in module global scope may be generated programmatically diff --git a/tests/basics/import_star_nonmodule.py b/tests/basics/import_star_nonmodule.py new file mode 100644 index 0000000000000..8a98ef26ce544 --- /dev/null +++ b/tests/basics/import_star_nonmodule.py @@ -0,0 +1,65 @@ +# Test "from x import *" where x is something other than a module. + +import sys + +try: + next(iter([]), 42) +except TypeError: + # Two-argument version of next() not supported. We are probably not at + # MICROPY_CONFIG_ROM_LEVEL_BASIC_FEATURES which is needed for "import *". 
+ print("SKIP") + raise SystemExit + +print("== test with a class as a module ==") + + +class M: + x = "a1" + + def __init__(self): + self.x = "a2" + + +sys.modules["mod"] = M +from mod import * + +print(x) + +sys.modules["mod"] = M() +from mod import * + +print(x) + +print("== test with a class as a module that overrides __all__ ==") + + +class M: + __all__ = ("y",) + x = "b1" + y = "b2" + + def __init__(self): + self.__all__ = ("x",) + self.x = "b3" + self.y = "b4" + + +sys.modules["mod"] = M +x = None +from mod import * + +print(x, y) + +sys.modules["mod"] = M() +from mod import * + +print(x, y) + +print("== test with objects that don't have a __dict__ ==") + +sys.modules["mod"] = 1 +try: + from mod import * + # MicroPython raises AttributeError, CPython raises ImportError. +except (AttributeError, ImportError): + print("ImportError")