diff --git a/py/compile.c b/py/compile.c index 47741b8da8..5959a8cbc9 100644 --- a/py/compile.c +++ b/py/compile.c @@ -61,6 +61,8 @@ typedef struct _compiler_t { int param_pass_num_dict_params; int param_pass_num_default_params; + bool func_arg_is_super; // used to compile special case of super() function call + scope_t *scope_head; scope_t *scope_cur; @@ -959,6 +961,7 @@ void compile_decorated(compiler_t *comp, mp_parse_node_struct_t *pns) { // nodes[1] contains arguments to the decorator function, if any if (!MP_PARSE_NODE_IS_NULL(pns_decorator->nodes[1])) { // call the decorator function with the arguments in nodes[1] + comp->func_arg_is_super = false; compile_node(comp, pns_decorator->nodes[1]); } } @@ -2062,9 +2065,38 @@ void compile_factor_2(compiler_t *comp, mp_parse_node_struct_t *pns) { } } +void compile_power(compiler_t *comp, mp_parse_node_struct_t *pns) { + // this is to handle special super() call + comp->func_arg_is_super = MP_PARSE_NODE_IS_ID(pns->nodes[0]) && MP_PARSE_NODE_LEAF_ARG(pns->nodes[0]) == MP_QSTR_super; + + compile_generic_all_nodes(comp, pns); +} + void compile_trailer_paren_helper(compiler_t *comp, mp_parse_node_struct_t *pns, bool is_method_call) { // function to call is on top of stack +#if !MICROPY_EMIT_CPYTHON + // this is to handle special super() call + if (MP_PARSE_NODE_IS_NULL(pns->nodes[0]) && comp->func_arg_is_super && comp->scope_cur->kind == SCOPE_FUNCTION) { + EMIT_ARG(load_id, MP_QSTR___class__); + // get first argument to function + bool found = false; + for (int i = 0; i < comp->scope_cur->id_info_len; i++) { + if (comp->scope_cur->id_info[i].param) { + EMIT_ARG(load_fast, MP_QSTR_, comp->scope_cur->id_info[i].local_num); + found = true; + break; + } + } + if (!found) { + printf("TypeError: super() call cannot find self\n"); + return; + } + EMIT_ARG(call_function, 2, 0, false, false); + return; + } +#endif + int old_n_arg_keyword = comp->n_arg_keyword; bool old_have_star_arg = comp->have_star_arg; bool old_have_dbl_star_arg = comp->have_dbl_star_arg; @@ -2107,6 +2139,7 @@ void compile_power_trailers(compiler_t *comp, mp_parse_node_struct_t *pns) { } else { compile_node(comp, pns->nodes[i]); } + comp->func_arg_is_super = false; } } @@ -2836,7 +2869,7 @@ void compile_scope(compiler_t *comp, scope_t *scope, pass_kind_t pass) { #if MICROPY_EMIT_CPYTHON EMIT_ARG(load_closure, MP_QSTR___class__, 0); // XXX check this is the correct local num #else - EMIT_ARG(load_fast, MP_QSTR___class__, 0); // XXX check this is the correct local num + EMIT_ARG(load_fast, MP_QSTR___class__, id->local_num); #endif } EMIT(return_value); @@ -3044,6 +3077,8 @@ mp_obj_t mp_compile(mp_parse_node_t pn, qstr source_file, bool is_repl) { comp->break_continue_except_level = 0; comp->cur_except_level = 0; + comp->func_arg_is_super = false; + comp->scope_head = NULL; comp->scope_cur = NULL; @@ -3054,7 +3089,7 @@ mp_obj_t mp_compile(mp_parse_node_t pn, qstr source_file, bool is_repl) { scope_t *module_scope = scope_new_and_link(comp, SCOPE_MODULE, pn, EMIT_OPT_NONE); // compile pass 1 - comp->emit = emit_pass1_new(MP_QSTR___class__); + comp->emit = emit_pass1_new(); comp->emit_method_table = &emit_pass1_method_table; comp->emit_inline_asm = NULL; comp->emit_inline_asm_method_table = NULL; diff --git a/py/emit.h b/py/emit.h index 062b38ef9c..ce0c98ba78 100644 --- a/py/emit.h +++ b/py/emit.h @@ -117,7 +117,7 @@ extern const emit_method_table_t emit_bc_method_table; extern const emit_method_table_t emit_native_x64_method_table; extern const emit_method_table_t emit_native_thumb_method_table; -emit_t *emit_pass1_new(qstr qstr___class__); +emit_t *emit_pass1_new(void); emit_t *emit_cpython_new(uint max_num_labels); emit_t *emit_bc_new(uint max_num_labels); emit_t *emit_native_x64_new(uint max_num_labels); diff --git a/py/emitpass1.c b/py/emitpass1.c index 38115a51c1..634d090518 100644 --- a/py/emitpass1.c +++ b/py/emitpass1.c @@ -15,13 +15,11 @@ #include "emit.h" struct _emit_t { - qstr qstr___class__; scope_t *scope; }; -emit_t *emit_pass1_new(qstr qstr___class__) { +emit_t *emit_pass1_new(void) { emit_t *emit = m_new(emit_t, 1); - emit->qstr___class__ = qstr___class__; return emit; } @@ -45,18 +43,21 @@ static void emit_pass1_load_id(emit_t *emit, qstr qstr) { bool added; id_info_t *id = scope_find_or_add_id(emit->scope, qstr, &added); if (added) { - if (strcmp(qstr_str(qstr), "super") == 0 && emit->scope->kind == SCOPE_FUNCTION) { +#if MICROPY_EMIT_CPYTHON + if (qstr == MP_QSTR_super && emit->scope->kind == SCOPE_FUNCTION) { // special case, super is a global, and also counts as use of __class__ id->kind = ID_INFO_KIND_GLOBAL_EXPLICIT; - id_info_t *id2 = scope_find_local_in_parent(emit->scope, emit->qstr___class__); + id_info_t *id2 = scope_find_local_in_parent(emit->scope, MP_QSTR___class__); if (id2 != NULL) { - id2 = scope_find_or_add_id(emit->scope, emit->qstr___class__, &added); + id2 = scope_find_or_add_id(emit->scope, MP_QSTR___class__, &added); if (added) { id2->kind = ID_INFO_KIND_FREE; - scope_close_over_in_parents(emit->scope, emit->qstr___class__); + scope_close_over_in_parents(emit->scope, MP_QSTR___class__); } } - } else { + } else +#endif + { id_info_t *id2 = scope_find_local_in_parent(emit->scope, qstr); if (id2 != NULL && (id2->kind == ID_INFO_KIND_LOCAL || id2->kind == ID_INFO_KIND_CELL || id2->kind == ID_INFO_KIND_FREE)) { id->kind = ID_INFO_KIND_FREE; diff --git a/py/grammar.h b/py/grammar.h index 32be6c66ca..c58ad9e069 100644 --- a/py/grammar.h +++ b/py/grammar.h @@ -214,7 +214,7 @@ DEF_RULE(term_op, nc, or(4), tok(OP_STAR), tok(OP_SLASH), tok(OP_PERCENT), tok(O DEF_RULE(factor, nc, or(2), rule(factor_2), rule(power)) DEF_RULE(factor_2, c(factor_2), and(2), rule(factor_op), rule(factor)) DEF_RULE(factor_op, nc, or(3), tok(OP_PLUS), tok(OP_MINUS), tok(OP_TILDE)) -DEF_RULE(power, c(generic_all_nodes), and(3), rule(atom), opt_rule(power_trailers), opt_rule(power_dbl_star)) +DEF_RULE(power, c(power), and(3), rule(atom), opt_rule(power_trailers), opt_rule(power_dbl_star)) DEF_RULE(power_trailers, c(power_trailers), one_or_more, rule(trailer)) DEF_RULE(power_dbl_star, c(power_dbl_star), and(2), tok(OP_DBL_STAR), rule(factor)) diff --git a/py/obj.h b/py/obj.h index 0ba4ae1b5d..2f4d441264 100644 --- a/py/obj.h +++ b/py/obj.h @@ -231,6 +231,7 @@ mp_obj_t mp_obj_new_list(uint n, mp_obj_t *items); mp_obj_t mp_obj_new_dict(int n_args); mp_obj_t mp_obj_new_set(int n_args, mp_obj_t *items); mp_obj_t mp_obj_new_slice(mp_obj_t start, mp_obj_t stop, mp_obj_t step); +mp_obj_t mp_obj_new_super(mp_obj_t type, mp_obj_t obj); mp_obj_t mp_obj_new_bound_meth(mp_obj_t meth, mp_obj_t self); mp_obj_t mp_obj_new_getitem_iter(mp_obj_t *args); mp_obj_t mp_obj_new_module(qstr module_name); @@ -371,6 +372,9 @@ void mp_obj_fun_bc_get(mp_obj_t self_in, int *n_args, uint *n_state, const byte mp_obj_t mp_identity(mp_obj_t self); +// super +extern const mp_obj_type_t super_type; + // generator extern const mp_obj_type_t gen_instance_type; diff --git a/py/objtype.c b/py/objtype.c index 24d7af6010..d35c9a98d6 100644 --- a/py/objtype.c +++ b/py/objtype.c @@ -362,6 +362,96 @@ mp_obj_t mp_obj_new_type(const char *name, mp_obj_t bases_tuple, mp_obj_t locals return o; } +/******************************************************************************/ +// super object + +typedef struct _mp_obj_super_t { + mp_obj_base_t base; + mp_obj_t type; + mp_obj_t obj; +} mp_obj_super_t; + +static void super_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj_t self_in, mp_print_kind_t kind) { + mp_obj_super_t *self = self_in; + print(env, "type, PRINT_STR); + print(env, ", "); + mp_obj_print_helper(print, env, self->obj, PRINT_STR); + print(env, ">"); +} + +static mp_obj_t super_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const mp_obj_t *args) { + if (n_args != 2 || n_kw != 0) { + // 0 arguments are turned into 2 in the compiler + // 1 argument is not yet implemented + nlr_jump(mp_obj_new_exception_msg(MP_QSTR_TypeError, "super() requires 2 arguments")); + } + return mp_obj_new_super(args[0], args[1]); +} + +// for fail, do nothing; for attr, dest[0] = value; for method, dest[0] = method, dest[1] = self +static void super_load_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) { + assert(MP_OBJ_IS_TYPE(self_in, &super_type)); + mp_obj_super_t *self = self_in; + + assert(MP_OBJ_IS_TYPE(self->type, &mp_const_type)); + + mp_obj_type_t *type = self->type; + + // for a const struct, this entry might be NULL + if (type->bases_tuple == MP_OBJ_NULL) { + return; + } + + uint len; + mp_obj_t *items; + mp_obj_tuple_get(type->bases_tuple, &len, &items); + for (uint i = 0; i < len; i++) { + assert(MP_OBJ_IS_TYPE(items[i], &mp_const_type)); + mp_obj_t member = mp_obj_class_lookup((mp_obj_type_t*)items[i], attr); + if (member != MP_OBJ_NULL) { + // XXX this and the code in class_load_attr need to be factored out + if (mp_obj_is_callable(member)) { + // class member is callable so build a bound method + // check if the methods are functions, static or class methods + // see http://docs.python.org/3.3/howto/descriptor.html + // TODO check that this is the correct place to have this logic + if (MP_OBJ_IS_TYPE(member, &mp_type_staticmethod)) { + // return just the function + dest[0] = ((mp_obj_staticmethod_t*)member)->fun; + } else if (MP_OBJ_IS_TYPE(member, &mp_type_classmethod)) { + // return a bound method, with self being the type of this object + dest[0] = ((mp_obj_classmethod_t*)member)->fun; + dest[1] = mp_obj_get_type(self->obj); + } else { + // return a bound method, with self being this object + dest[0] = member; + dest[1] = self->obj; + } + return; + } else { + // class member is a value, so just return that value + dest[0] = member; + return; + } + } + } +} + +const mp_obj_type_t super_type = { + { &mp_const_type }, + "super", + .print = super_print, + .make_new = super_make_new, + .load_attr = super_load_attr, +}; + +mp_obj_t mp_obj_new_super(mp_obj_t type, mp_obj_t obj) { + mp_obj_super_t *o = m_new_obj(mp_obj_super_t); + *o = (mp_obj_super_t){{&super_type}, type, obj}; + return o; +} + /******************************************************************************/ // built-ins specific to types diff --git a/py/qstrdefs.h b/py/qstrdefs.h index bf575e25d1..fe1de07252 100644 --- a/py/qstrdefs.h +++ b/py/qstrdefs.h @@ -80,6 +80,7 @@ Q(repr) Q(set) Q(sorted) Q(sum) +Q(super) Q(str) Q(sys) Q(tuple) diff --git a/py/runtime.c b/py/runtime.c index c84a28e4cb..9327f0d6ac 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -129,6 +129,7 @@ void rt_init(void) { mp_map_add_qstr(&map_builtins, MP_QSTR_list, (mp_obj_t)&list_type); mp_map_add_qstr(&map_builtins, MP_QSTR_map, (mp_obj_t)&map_type); mp_map_add_qstr(&map_builtins, MP_QSTR_set, (mp_obj_t)&set_type); + mp_map_add_qstr(&map_builtins, MP_QSTR_super, (mp_obj_t)&super_type); mp_map_add_qstr(&map_builtins, MP_QSTR_tuple, (mp_obj_t)&tuple_type); mp_map_add_qstr(&map_builtins, MP_QSTR_type, (mp_obj_t)&mp_const_type); mp_map_add_qstr(&map_builtins, MP_QSTR_zip, (mp_obj_t)&zip_type); @@ -852,7 +853,10 @@ static void rt_load_method_maybe(mp_obj_t base, qstr attr, mp_obj_t *dest) { // if nothing found yet, look for built-in and generic names if (dest[0] == MP_OBJ_NULL) { - if (attr == MP_QSTR___next__ && type->iternext != NULL) { + if (attr == MP_QSTR___class__) { + // a.__class__ is equivalent to type(a) + dest[0] = type; + } else if (attr == MP_QSTR___next__ && type->iternext != NULL) { dest[0] = (mp_obj_t)&mp_builtin_next_obj; dest[1] = base; } else if (type->load_attr == NULL) { diff --git a/tests/bytecode/mp-tests/class6.py b/tests/bytecode/mp-tests/class6.py new file mode 100644 index 0000000000..05a2454f50 --- /dev/null +++ b/tests/bytecode/mp-tests/class6.py @@ -0,0 +1,7 @@ +class A: + def f(self): + pass + +class B(A): + def f(self): + super().f() diff --git a/tests/bytecode/mp-tests/class7.py b/tests/bytecode/mp-tests/class7.py new file mode 100644 index 0000000000..3de41dbb52 --- /dev/null +++ b/tests/bytecode/mp-tests/class7.py @@ -0,0 +1,6 @@ +# accessing super, but not as a function call + +class A: + def f(): + #x = super + print(super)