diff --git a/py/builtinimport.c b/py/builtinimport.c index de4ea17f38..8827be6123 100644 --- a/py/builtinimport.c +++ b/py/builtinimport.c @@ -346,6 +346,17 @@ STATIC void evaluate_relative_import(mp_int_t level, const char **module_name, s *module_name_len = new_module_name_len; } +typedef struct _nlr_jump_callback_node_unregister_module_t { + nlr_jump_callback_node_t callback; + qstr name; +} nlr_jump_callback_node_unregister_module_t; + +STATIC void unregister_module_from_nlr_jump_callback(void *ctx_in) { + nlr_jump_callback_node_unregister_module_t *ctx = ctx_in; + mp_map_t *mp_loaded_modules_map = &MP_STATE_VM(mp_loaded_modules_dict).map; + mp_map_lookup(mp_loaded_modules_map, MP_OBJ_NEW_QSTR(ctx->name), MP_MAP_LOOKUP_REMOVE_IF_FOUND); +} + // Load a module at the specified absolute path, possibly as a submodule of the given outer module. // full_mod_name: The full absolute path up to this level (e.g. "foo.bar.baz"). // level_mod_name: The final component of the path (e.g. "baz"). @@ -467,8 +478,13 @@ STATIC mp_obj_t process_import_at_level(qstr full_mod_name, qstr level_mod_name, // Module was found on the filesystem/frozen, try and load it. DEBUG_printf("Found path to load: %.*s\n", (int)vstr_len(&path), vstr_str(&path)); - // Prepare for loading from the filesystem. Create a new shell module. + // Prepare for loading from the filesystem. Create a new shell module + // and register it in sys.modules. Also make sure we remove it if + // there is any problem below. module_obj = mp_obj_new_module(full_mod_name); + nlr_jump_callback_node_unregister_module_t ctx; + ctx.name = full_mod_name; + nlr_push_jump_callback(&ctx.callback, unregister_module_from_nlr_jump_callback); #if MICROPY_MODULE_OVERRIDE_MAIN_IMPORT // If this module is being loaded via -m on unix, then @@ -526,6 +542,8 @@ STATIC mp_obj_t process_import_at_level(qstr full_mod_name, qstr level_mod_name, mp_store_attr(outer_module_obj, level_mod_name, module_obj); } + nlr_pop_jump_callback(false); + return module_obj; } diff --git a/tests/cpydiff/core_import_prereg.py b/tests/cpydiff/core_import_prereg.py deleted file mode 100644 index 3ce2340c68..0000000000 --- a/tests/cpydiff/core_import_prereg.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -categories: Core,import -description: Failed to load modules are still registered as loaded -cause: To make module handling more efficient, it's not wrapped with exception handling. -workaround: Test modules before production use; during development, use ``del sys.modules["name"]``, or just soft or hard reset the board. -""" -import sys - -try: - from modules import foo -except NameError as e: - print(e) -try: - from modules import foo - - print("Should not get here") -except NameError as e: - print(e) diff --git a/tests/import/broken/pkg2_and_zerodiv.py b/tests/import/broken/pkg2_and_zerodiv.py new file mode 100644 index 0000000000..3580628ff5 --- /dev/null +++ b/tests/import/broken/pkg2_and_zerodiv.py @@ -0,0 +1,2 @@ +import pkg2 +import broken.zerodiv diff --git a/tests/import/broken/zerodiv.py b/tests/import/broken/zerodiv.py new file mode 100644 index 0000000000..72dca4d5e4 --- /dev/null +++ b/tests/import/broken/zerodiv.py @@ -0,0 +1 @@ +1 / 0 diff --git a/tests/import/circular/main.py b/tests/import/circular/main.py new file mode 100644 index 0000000000..5d63d507c3 --- /dev/null +++ b/tests/import/circular/main.py @@ -0,0 +1,4 @@ +x = 1 +import circular.sub + +print(circular.sub.y) diff --git a/tests/import/circular/sub.py b/tests/import/circular/sub.py new file mode 100644 index 0000000000..50d7afe07b --- /dev/null +++ b/tests/import/circular/sub.py @@ -0,0 +1,3 @@ +from circular.main import x + +y = x + 20 diff --git a/tests/import/import_broken.py b/tests/import/import_broken.py new file mode 100644 index 0000000000..3c7cf4a498 --- /dev/null +++ b/tests/import/import_broken.py @@ -0,0 +1,32 @@ +import sys, pkg + +# Modules we import are usually added to sys.modules. +print("pkg" in sys.modules) + +try: + from broken.zerodiv import x +except Exception as e: + print(e.__class__.__name__) + +# The broken module we tried to import should not be in sys.modules. +print("broken.zerodiv" in sys.modules) + +# If we try to import the module again, the code should +# run again and we should get the same error. +try: + from broken.zerodiv import x +except Exception as e: + print(e.__class__.__name__) + +# Import a module that successfully imports some other modules +# before importing the problematic module. +try: + import broken.pkg2_and_zerodiv +except ZeroDivisionError: + pass + +print("pkg2" in sys.modules) +print("pkg2.mod1" in sys.modules) +print("pkg2.mod2" in sys.modules) +print("broken.zerodiv" in sys.modules) +print("broken.pkg2_and_zerodiv" in sys.modules) diff --git a/tests/import/import_circular.py b/tests/import/import_circular.py new file mode 100644 index 0000000000..388efdd130 --- /dev/null +++ b/tests/import/import_circular.py @@ -0,0 +1 @@ +import circular.main