From 49cec95fddf2e2326ef02ccb043ac18b2678b169 Mon Sep 17 00:00:00 2001 From: David Lechner Date: Sun, 7 May 2023 16:19:33 -0500 Subject: [PATCH] py/{compile,runtime}: Fix *args after kwarg. This fixes a compiler/runtime bug where *args after a kwarg was not handled correctly. Prior to this change, if `*args` was encountered in a function call after a keyword argument, the compiler would push a single object and increment the positional arg count. However, two objects for the keyword argument key and value had already been pushed. This caused inconsistencies that the runtime could not resolve since it expects all of the positional args first followed by key/value pairs for the keyword args. To fix it, we need to conditionally change what happens when `*args` is encountered depending on whether it is before or after the first keyword argument. If it is before, everything is handled as before. If after, instead of pushing a single object and incrementing the positional arg count, we push two objects and increment the keyword arg count. This makes it possible for the runtime to handle it with minimal changes. In the runtime, we have to add some extra checks to handle the new possibility that one of the `n_kw` args is a `*arg`. We already have a case where `**arg` is handled as a keyword argument where the key is `MP_OBJ_NULL`. We now do the same for `*arg` as well. The existing `star_args` flags are used to determine if the value corresponding to a key of `MP_OBJ_NULL` is `*arg` or `**arg`. A couple of tests that failed before this fix are added. 
Fixes: https://github.com/micropython/micropython/issues/11439 Signed-off-by: David Lechner --- py/compile.c | 11 +++++++- py/runtime.c | 52 ++++++++++++++++++++++++++++-------- tests/basics/fun_callstar.py | 4 +++ 3 files changed, 55 insertions(+), 12 deletions(-) diff --git a/py/compile.c b/py/compile.c index bb7c1117fa..33a930e649 100644 --- a/py/compile.c +++ b/py/compile.c @@ -2401,8 +2401,17 @@ STATIC void compile_trailer_paren_helper(compiler_t *comp, mp_parse_node_t pn_ar } star_flags |= MP_EMIT_STAR_FLAG_SINGLE; star_args |= (mp_uint_t)1 << i; + + if (n_keyword == 0) { + // star-args before kwargs encoded as positional arg + n_positional++; + } else { + // star-args after kwargs encoded as kw arg with key=NULL + EMIT(load_null); + n_keyword++; + } + compile_node(comp, pns_arg->nodes[0]); - n_positional++; } else if (MP_PARSE_NODE_STRUCT_KIND(pns_arg) == PN_arglist_dbl_star) { star_flags |= MP_EMIT_STAR_FLAG_DOUBLE; // double-star args are stored as kw arg with key of None diff --git a/py/runtime.c b/py/runtime.c index 3c7c0350c1..2211a5e38c 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -735,17 +735,30 @@ void mp_call_prepare_args_n_kw_var(bool have_self, size_t n_args_n_kw, const mp_ mp_obj_t *args2; size_t args2_alloc; size_t args2_len = 0; + size_t n_args_star_args = n_args; // Try to get a hint for unpacked * args length ssize_t list_len = 0; - if (star_args != 0) { - for (size_t i = 0; i < n_args; i++) { - if ((star_args >> i) & 1) { - mp_obj_t len = mp_obj_len_maybe(args[i]); - if (len != MP_OBJ_NULL) { + if (star_args) { + // kw can also contain star args. + n_args_star_args += n_kw; + + for (size_t i = 0; i < n_args_star_args; i++) { + if (!((star_args >> i) & 1)) { + continue; + } + + mp_obj_t arg = i >= n_args ? 
args[n_args + 2 * (i - n_args) + 1] : args[i]; + + mp_obj_t len = mp_obj_len_maybe(arg); + + if (len != MP_OBJ_NULL) { + list_len += mp_obj_get_int(len); + + if (i < n_args) { // -1 accounts for 1 of n_args occupied by this arg - list_len += mp_obj_get_int(len) - 1; + list_len--; } } } @@ -757,9 +770,20 @@ void mp_call_prepare_args_n_kw_var(bool have_self, size_t n_args_n_kw, const mp_ for (size_t i = 0; i < n_kw; i++) { mp_obj_t key = args[n_args + i * 2]; mp_obj_t value = args[n_args + i * 2 + 1]; - if (key == MP_OBJ_NULL && value != MP_OBJ_NULL && mp_obj_is_type(value, &mp_type_dict)) { + + if (key == MP_OBJ_NULL) { // -1 accounts for 1 of n_kw occupied by this arg - kw_dict_len += mp_obj_dict_len(value) - 1; + kw_dict_len--; + + if (((star_args >> (n_args + i)) & 1)) { + // star args were already handled above + continue; + } + + // double-star args + if (mp_obj_is_type(value, &mp_type_dict)) { + kw_dict_len += mp_obj_dict_len(value); + } } } @@ -792,8 +816,9 @@ void mp_call_prepare_args_n_kw_var(bool have_self, size_t n_args_n_kw, const mp_ args2[args2_len++] = self; } - for (size_t i = 0; i < n_args; i++) { - mp_obj_t arg = args[i]; + for (size_t i = 0; i < n_args_star_args; i++) { + mp_obj_t arg = i >= n_args ? 
args[n_args + 2 * (i - n_args) + 1] : args[i]; + if ((star_args >> i) & 1) { // star arg if (mp_obj_is_type(arg, &mp_type_tuple) || mp_obj_is_type(arg, &mp_type_list)) { @@ -824,7 +849,7 @@ void mp_call_prepare_args_n_kw_var(bool have_self, size_t n_args_n_kw, const mp_ args2[args2_len++] = item; } } - } else { + } else if (i < n_args) { // normal argument assert(args2_len < args2_alloc); args2[args2_len++] = arg; @@ -848,6 +873,11 @@ void mp_call_prepare_args_n_kw_var(bool have_self, size_t n_args_n_kw, const mp_ mp_obj_t kw_key = args[n_args + i * 2]; mp_obj_t kw_value = args[n_args + i * 2 + 1]; if (kw_key == MP_OBJ_NULL) { + if ((star_args >> (n_args + i)) & 1) { + // star args have already been handled above + continue; + } + // double-star args if (mp_obj_is_type(kw_value, &mp_type_dict)) { // dictionary diff --git a/tests/basics/fun_callstar.py b/tests/basics/fun_callstar.py index 53d2ece3e1..9c3b199297 100644 --- a/tests/basics/fun_callstar.py +++ b/tests/basics/fun_callstar.py @@ -23,6 +23,10 @@ foo(*range(3)) # pos then iterator foo(1, *range(2, 4)) +# star after kw +foo(1, 2, c=3, *()) +foo(b=2, *(1,), c=3) + # an iterator with many elements def foo(*rest): print(rest)