From a79a6ab3641e3536e509278706e7ac4de01a8126 Mon Sep 17 00:00:00 2001 From: David Grayson Date: Sun, 4 Jun 2023 19:55:41 -0700 Subject: [PATCH] py/builtinimport: Remove partially-loaded modules from sys.modules. Prior to this commit, importing a module that exists but has a syntax error or some other problem that happens at import time would result in a potentially-incomplete module object getting added to sys.modules. Subsequent imports would use that object, resulting in confusing error messages that hide the root cause of the problem. This commit fixes that issue by removing the failed module from sys.modules using the new NLR callback mechanism. Note that it is still important to add the module to sys.modules while the import is happening so that we can support circular imports just like CPython does. Fixes issue #967. Signed-off-by: David Grayson --- py/builtinimport.c | 20 +++++++++++++++- tests/cpydiff/core_import_prereg.py | 18 -------------- tests/import/broken/pkg2_and_zerodiv.py | 2 ++ tests/import/broken/zerodiv.py | 1 + tests/import/circular/main.py | 4 ++++ tests/import/circular/sub.py | 3 +++ tests/import/import_broken.py | 32 +++++++++++++++++++++++++ tests/import/import_circular.py | 1 + 8 files changed, 62 insertions(+), 19 deletions(-) delete mode 100644 tests/cpydiff/core_import_prereg.py create mode 100644 tests/import/broken/pkg2_and_zerodiv.py create mode 100644 tests/import/broken/zerodiv.py create mode 100644 tests/import/circular/main.py create mode 100644 tests/import/circular/sub.py create mode 100644 tests/import/import_broken.py create mode 100644 tests/import/import_circular.py diff --git a/py/builtinimport.c b/py/builtinimport.c index de4ea17f38..8827be6123 100644 --- a/py/builtinimport.c +++ b/py/builtinimport.c @@ -346,6 +346,17 @@ STATIC void evaluate_relative_import(mp_int_t level, const char **module_name, s *module_name_len = new_module_name_len; } +typedef struct _nlr_jump_callback_node_unregister_module_t { + nlr_jump_callback_node_t callback; + qstr name; +} nlr_jump_callback_node_unregister_module_t; + +STATIC void unregister_module_from_nlr_jump_callback(void *ctx_in) { + nlr_jump_callback_node_unregister_module_t *ctx = ctx_in; + mp_map_t *mp_loaded_modules_map = &MP_STATE_VM(mp_loaded_modules_dict).map; + mp_map_lookup(mp_loaded_modules_map, MP_OBJ_NEW_QSTR(ctx->name), MP_MAP_LOOKUP_REMOVE_IF_FOUND); +} + // Load a module at the specified absolute path, possibly as a submodule of the given outer module. // full_mod_name: The full absolute path up to this level (e.g. "foo.bar.baz"). // level_mod_name: The final component of the path (e.g. "baz"). @@ -467,8 +478,13 @@ STATIC mp_obj_t process_import_at_level(qstr full_mod_name, qstr level_mod_name, // Module was found on the filesystem/frozen, try and load it. DEBUG_printf("Found path to load: %.*s\n", (int)vstr_len(&path), vstr_str(&path)); - // Prepare for loading from the filesystem. Create a new shell module. + // Prepare for loading from the filesystem. Create a new shell module + // and register it in sys.modules. Also make sure we remove it if + // there is any problem below. module_obj = mp_obj_new_module(full_mod_name); + nlr_jump_callback_node_unregister_module_t ctx; + ctx.name = full_mod_name; + nlr_push_jump_callback(&ctx.callback, unregister_module_from_nlr_jump_callback); #if MICROPY_MODULE_OVERRIDE_MAIN_IMPORT // If this module is being loaded via -m on unix, then @@ -526,6 +542,8 @@ STATIC mp_obj_t process_import_at_level(qstr full_mod_name, qstr level_mod_name, mp_store_attr(outer_module_obj, level_mod_name, module_obj); } + nlr_pop_jump_callback(false); + return module_obj; } diff --git a/tests/cpydiff/core_import_prereg.py b/tests/cpydiff/core_import_prereg.py deleted file mode 100644 index 3ce2340c68..0000000000 --- a/tests/cpydiff/core_import_prereg.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -categories: Core,import -description: Failed to load modules are still registered as loaded -cause: To make module handling more efficient, it's not wrapped with exception handling. -workaround: Test modules before production use; during development, use ``del sys.modules["name"]``, or just soft or hard reset the board. -""" -import sys - -try: - from modules import foo -except NameError as e: - print(e) -try: - from modules import foo - - print("Should not get here") -except NameError as e: - print(e) diff --git a/tests/import/broken/pkg2_and_zerodiv.py b/tests/import/broken/pkg2_and_zerodiv.py new file mode 100644 index 0000000000..3580628ff5 --- /dev/null +++ b/tests/import/broken/pkg2_and_zerodiv.py @@ -0,0 +1,2 @@ +import pkg2 +import broken.zerodiv diff --git a/tests/import/broken/zerodiv.py b/tests/import/broken/zerodiv.py new file mode 100644 index 0000000000..72dca4d5e4 --- /dev/null +++ b/tests/import/broken/zerodiv.py @@ -0,0 +1 @@ +1 / 0 diff --git a/tests/import/circular/main.py b/tests/import/circular/main.py new file mode 100644 index 0000000000..5d63d507c3 --- /dev/null +++ b/tests/import/circular/main.py @@ -0,0 +1,4 @@ +x = 1 +import circular.sub + +print(circular.sub.y) diff --git a/tests/import/circular/sub.py b/tests/import/circular/sub.py new file mode 100644 index 0000000000..50d7afe07b --- /dev/null +++ b/tests/import/circular/sub.py @@ -0,0 +1,3 @@ +from circular.main import x + +y = x + 20 diff --git a/tests/import/import_broken.py b/tests/import/import_broken.py new file mode 100644 index 0000000000..3c7cf4a498 --- /dev/null +++ b/tests/import/import_broken.py @@ -0,0 +1,32 @@ +import sys, pkg + +# Modules we import are usually added to sys.modules. +print("pkg" in sys.modules) + +try: + from broken.zerodiv import x +except Exception as e: + print(e.__class__.__name__) + +# The broken module we tried to import should not be in sys.modules. +print("broken.zerodiv" in sys.modules) + +# If we try to import the module again, the code should +# run again and we should get the same error. +try: + from broken.zerodiv import x +except Exception as e: + print(e.__class__.__name__) + +# Import a module that successfully imports some other modules +# before importing the problematic module. +try: + import broken.pkg2_and_zerodiv +except ZeroDivisionError: + pass + +print("pkg2" in sys.modules) +print("pkg2.mod1" in sys.modules) +print("pkg2.mod2" in sys.modules) +print("broken.zerodiv" in sys.modules) +print("broken.pkg2_and_zerodiv" in sys.modules) diff --git a/tests/import/import_circular.py b/tests/import/import_circular.py new file mode 100644 index 0000000000..388efdd130 --- /dev/null +++ b/tests/import/import_circular.py @@ -0,0 +1 @@ +import circular.main