diff --git a/py/builtinevex.c b/py/builtinevex.c index 173978ef52..97eab7fad9 100644 --- a/py/builtinevex.c +++ b/py/builtinevex.c @@ -45,12 +45,18 @@ STATIC MP_DEFINE_CONST_OBJ_TYPE( ); STATIC mp_obj_t code_execute(mp_obj_code_t *self, mp_obj_dict_t *globals, mp_obj_dict_t *locals) { - // save context and set new context - mp_obj_dict_t *old_globals = mp_globals_get(); - mp_obj_dict_t *old_locals = mp_locals_get(); + // save context + nlr_jump_callback_node_globals_locals_t ctx; + ctx.globals = mp_globals_get(); + ctx.locals = mp_locals_get(); + + // set new context mp_globals_set(globals); mp_locals_set(locals); + // set exception handler to restore context if an exception is raised + nlr_push_jump_callback(&ctx.callback, mp_globals_locals_set_from_nlr_jump_callback); + // a bit of a hack: fun_bc will re-set globals, so need to make sure it's // the correct one if (mp_obj_is_type(self->module_fun, &mp_type_fun_bc)) { @@ -59,19 +65,13 @@ STATIC mp_obj_t code_execute(mp_obj_code_t *self, mp_obj_dict_t *globals, mp_obj } // execute code - nlr_buf_t nlr; - if (nlr_push(&nlr) == 0) { - mp_obj_t ret = mp_call_function_0(self->module_fun); - nlr_pop(); - mp_globals_set(old_globals); - mp_locals_set(old_locals); - return ret; - } else { - // exception; restore context and re-raise same exception - mp_globals_set(old_globals); - mp_locals_set(old_locals); - nlr_jump(nlr.ret_val); - } + mp_obj_t ret = mp_call_function_0(self->module_fun); + + // deregister exception handler and restore context + nlr_pop_jump_callback(true); + + // return value + return ret; } STATIC mp_obj_t mp_builtin_compile(size_t n_args, const mp_obj_t *args) { diff --git a/py/builtinimport.c b/py/builtinimport.c index f6c2f7c348..de4ea17f38 100644 --- a/py/builtinimport.c +++ b/py/builtinimport.c @@ -184,28 +184,23 @@ STATIC void do_execute_raw_code(const mp_module_context_t *context, const mp_raw mp_obj_dict_t *mod_globals = context->module.globals; // save context - mp_obj_dict_t *volatile old_globals = mp_globals_get(); - mp_obj_dict_t *volatile old_locals = mp_locals_get(); + nlr_jump_callback_node_globals_locals_t ctx; + ctx.globals = mp_globals_get(); + ctx.locals = mp_locals_get(); // set new context mp_globals_set(mod_globals); mp_locals_set(mod_globals); - nlr_buf_t nlr; - if (nlr_push(&nlr) == 0) { - mp_obj_t module_fun = mp_make_function_from_raw_code(rc, context, NULL); - mp_call_function_0(module_fun); + // set exception handler to restore context if an exception is raised + nlr_push_jump_callback(&ctx.callback, mp_globals_locals_set_from_nlr_jump_callback); - // finish nlr block, restore context - nlr_pop(); - mp_globals_set(old_globals); - mp_locals_set(old_locals); - } else { - // exception; restore context and re-raise same exception - mp_globals_set(old_globals); - mp_locals_set(old_locals); - nlr_jump(nlr.ret_val); - } + // make and execute the function + mp_obj_t module_fun = mp_make_function_from_raw_code(rc, context, NULL); + mp_call_function_0(module_fun); + + // deregister exception handler and restore context + nlr_pop_jump_callback(true); } #endif diff --git a/py/runtime.c b/py/runtime.c index cbc7fb9311..2326dfb3ca 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -188,6 +188,12 @@ void mp_deinit(void) { #endif } +void mp_globals_locals_set_from_nlr_jump_callback(void *ctx_in) { + nlr_jump_callback_node_globals_locals_t *ctx = ctx_in; + mp_globals_set(ctx->globals); + mp_locals_set(ctx->locals); +} + mp_obj_t MICROPY_WRAP_MP_LOAD_NAME(mp_load_name)(qstr qst) { // logic: search locals, globals, builtins DEBUG_OP_printf("load name %s\n", qstr_str(qst)); @@ -1582,39 +1588,35 @@ void mp_import_all(mp_obj_t module) { mp_obj_t mp_parse_compile_execute(mp_lexer_t *lex, mp_parse_input_kind_t parse_input_kind, mp_obj_dict_t *globals, mp_obj_dict_t *locals) { // save context - mp_obj_dict_t *volatile old_globals = mp_globals_get(); - mp_obj_dict_t *volatile old_locals = mp_locals_get(); + nlr_jump_callback_node_globals_locals_t ctx; + ctx.globals = mp_globals_get(); + ctx.locals = mp_locals_get(); // set new context mp_globals_set(globals); mp_locals_set(locals); - nlr_buf_t nlr; - if (nlr_push(&nlr) == 0) { - qstr source_name = lex->source_name; - mp_parse_tree_t parse_tree = mp_parse(lex, parse_input_kind); - mp_obj_t module_fun = mp_compile(&parse_tree, source_name, parse_input_kind == MP_PARSE_SINGLE_INPUT); + // set exception handler to restore context if an exception is raised + nlr_push_jump_callback(&ctx.callback, mp_globals_locals_set_from_nlr_jump_callback); - mp_obj_t ret; - if (MICROPY_PY_BUILTINS_COMPILE && globals == NULL) { - // for compile only, return value is the module function - ret = module_fun; - } else { - // execute module function and get return value - ret = mp_call_function_0(module_fun); - } + qstr source_name = lex->source_name; + mp_parse_tree_t parse_tree = mp_parse(lex, parse_input_kind); + mp_obj_t module_fun = mp_compile(&parse_tree, source_name, parse_input_kind == MP_PARSE_SINGLE_INPUT); - // finish nlr block, restore context and return value - nlr_pop(); - mp_globals_set(old_globals); - mp_locals_set(old_locals); - return ret; + mp_obj_t ret; + if (MICROPY_PY_BUILTINS_COMPILE && globals == NULL) { + // for compile only, return value is the module function + ret = module_fun; } else { - // exception; restore context and re-raise same exception - mp_globals_set(old_globals); - mp_locals_set(old_locals); - nlr_jump(nlr.ret_val); + // execute module function and get return value + ret = mp_call_function_0(module_fun); } + + // deregister exception handler and restore context + nlr_pop_jump_callback(true); + + // return value + return ret; } #endif // MICROPY_ENABLE_COMPILER diff --git a/py/runtime.h b/py/runtime.h index 9b99e594b0..78194973d6 100644 --- a/py/runtime.h +++ b/py/runtime.h @@ -66,6 +66,13 @@ typedef struct _mp_sched_node_t { struct _mp_sched_node_t *next; } mp_sched_node_t; +// For use with mp_globals_locals_set_from_nlr_jump_callback. +typedef struct _nlr_jump_callback_node_globals_locals_t { + nlr_jump_callback_node_t callback; + mp_obj_dict_t *globals; + mp_obj_dict_t *locals; +} nlr_jump_callback_node_globals_locals_t; + // Tables mapping operator enums to qstrs, defined in objtype.c extern const byte mp_unary_op_method_name[]; extern const byte mp_binary_op_method_name[]; @@ -113,6 +120,8 @@ static inline void mp_globals_set(mp_obj_dict_t *d) { MP_STATE_THREAD(dict_globals) = d; } +void mp_globals_locals_set_from_nlr_jump_callback(void *ctx_in); + mp_obj_t mp_load_name(qstr qst); mp_obj_t mp_load_global(qstr qst); mp_obj_t mp_load_build_class(void);