diff --git a/extmod/modbtree.c b/extmod/modbtree.c index 5c13115328..8b76885809 100644 --- a/extmod/modbtree.c +++ b/extmod/modbtree.c @@ -282,7 +282,7 @@ STATIC mp_obj_t btree_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) { STATIC mp_obj_t btree_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) { mp_obj_btree_t *self = MP_OBJ_TO_PTR(lhs_in); switch (op) { - case MP_BINARY_OP_IN: { + case MP_BINARY_OP_CONTAINS: { DBT key, val; key.data = (void*)mp_obj_str_get_data(rhs_in, &key.size); int res = __bt_get(self->db, &key, &val, 0); diff --git a/py/objarray.c b/py/objarray.c index 7003ec9e7d..a35539484b 100644 --- a/py/objarray.c +++ b/py/objarray.c @@ -269,8 +269,7 @@ STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs return lhs_in; } - case MP_BINARY_OP_IN: { - /* NOTE `a in b` is `b.__contains__(a)` */ + case MP_BINARY_OP_CONTAINS: { mp_buffer_info_t lhs_bufinfo; mp_buffer_info_t rhs_bufinfo; diff --git a/py/objdict.c b/py/objdict.c index 1553a83b46..d0f95e41ad 100644 --- a/py/objdict.c +++ b/py/objdict.c @@ -115,7 +115,7 @@ STATIC mp_obj_t dict_unary_op(mp_unary_op_t op, mp_obj_t self_in) { STATIC mp_obj_t dict_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) { mp_obj_dict_t *o = MP_OBJ_TO_PTR(lhs_in); switch (op) { - case MP_BINARY_OP_IN: { + case MP_BINARY_OP_CONTAINS: { mp_map_elem_t *elem = mp_map_lookup(&o->map, rhs_in, MP_MAP_LOOKUP); return mp_obj_new_bool(elem != NULL); } @@ -485,7 +485,7 @@ STATIC mp_obj_t dict_view_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t if (o->kind != MP_DICT_VIEW_KEYS) { return MP_OBJ_NULL; // op not supported } - if (op != MP_BINARY_OP_IN) { + if (op != MP_BINARY_OP_CONTAINS) { return MP_OBJ_NULL; // op not supported } return dict_binary_op(op, o->dict, rhs_in); diff --git a/py/objint_mpz.c b/py/objint_mpz.c index 7b5cb0b9d4..17e3ee6d24 100644 --- a/py/objint_mpz.c +++ b/py/objint_mpz.c @@ -207,7 +207,7 @@ mp_obj_t mp_obj_int_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_i return mp_obj_new_float(flhs / frhs); #endif - } else if (op >= MP_BINARY_OP_INPLACE_OR) { + } else if (op >= MP_BINARY_OP_INPLACE_OR && op < MP_BINARY_OP_CONTAINS) { mp_obj_int_t *res = mp_obj_int_new_mpz(); switch (op) { diff --git a/py/objset.c b/py/objset.c index 6ed15c7914..3e98c30e8f 100644 --- a/py/objset.c +++ b/py/objset.c @@ -461,7 +461,7 @@ STATIC mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { #else bool update = true; #endif - if (op != MP_BINARY_OP_IN && !is_set_or_frozenset(rhs)) { + if (op != MP_BINARY_OP_CONTAINS && !is_set_or_frozenset(rhs)) { // For all ops except containment the RHS must be a set/frozenset return MP_OBJ_NULL; } @@ -507,7 +507,7 @@ STATIC mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { return set_issubset(lhs, rhs); case MP_BINARY_OP_MORE_EQUAL: return set_issuperset(lhs, rhs); - case MP_BINARY_OP_IN: { + case MP_BINARY_OP_CONTAINS: { mp_obj_set_t *o = MP_OBJ_TO_PTR(lhs); mp_obj_t elem = mp_set_lookup(&o->set, rhs, MP_MAP_LOOKUP); return mp_obj_new_bool(elem != MP_OBJ_NULL); diff --git a/py/objstr.c b/py/objstr.c index 1ff5132d29..b4f15b38d5 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -384,8 +384,7 @@ mp_obj_t mp_obj_str_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_i return mp_obj_new_str_from_vstr(lhs_type, &vstr); } - case MP_BINARY_OP_IN: - /* NOTE `a in b` is `b.__contains__(a)` */ + case MP_BINARY_OP_CONTAINS: return mp_obj_new_bool(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL); //case MP_BINARY_OP_NOT_EQUAL: // This is never passed here diff --git a/py/objtype.c b/py/objtype.c index 01d248256a..267cae8156 100644 --- a/py/objtype.c +++ b/py/objtype.c @@ -424,7 +424,7 @@ const byte mp_binary_op_method_name[MP_BINARY_OP_NUM_RUNTIME] = { [MP_BINARY_OP_LESS_EQUAL] = MP_QSTR___le__, [MP_BINARY_OP_MORE_EQUAL] = MP_QSTR___ge__, // MP_BINARY_OP_NOT_EQUAL, // a != b calls a == b and inverts result - [MP_BINARY_OP_IN] = MP_QSTR___contains__, + [MP_BINARY_OP_CONTAINS] = MP_QSTR___contains__, // All inplace methods are optional, and normal methods will be used // as a fallback. diff --git a/py/opmethods.c b/py/opmethods.c index 31901bb521..247fa5bbc8 100644 --- a/py/opmethods.c +++ b/py/opmethods.c @@ -47,6 +47,6 @@ MP_DEFINE_CONST_FUN_OBJ_2(mp_op_delitem_obj, op_delitem); STATIC mp_obj_t op_contains(mp_obj_t lhs_in, mp_obj_t rhs_in) { mp_obj_type_t *type = mp_obj_get_type(lhs_in); - return type->binary_op(MP_BINARY_OP_IN, lhs_in, rhs_in); + return type->binary_op(MP_BINARY_OP_CONTAINS, lhs_in, rhs_in); } MP_DEFINE_CONST_FUN_OBJ_2(mp_op_contains_obj, op_contains); diff --git a/py/runtime.c b/py/runtime.c index 08a35c2e60..c7fe393675 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -523,30 +523,12 @@ mp_obj_t mp_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { } } - /* deal with `in` - * - * NOTE `a in b` is `b.__contains__(a)`, hence why the generic dispatch - * needs to go below with swapped arguments - */ + // Convert MP_BINARY_OP_IN to MP_BINARY_OP_CONTAINS with swapped args. if (op == MP_BINARY_OP_IN) { - mp_obj_type_t *type = mp_obj_get_type(rhs); - if (type->binary_op != NULL) { - mp_obj_t res = type->binary_op(op, rhs, lhs); - if (res != MP_OBJ_NULL) { - return res; - } - } - - // final attempt, walk the iterator (will raise if rhs is not iterable) - mp_obj_iter_buf_t iter_buf; - mp_obj_t iter = mp_getiter(rhs, &iter_buf); - mp_obj_t next; - while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { - if (mp_obj_equal(next, lhs)) { - return mp_const_true; - } - } - return mp_const_false; + op = MP_BINARY_OP_CONTAINS; + mp_obj_t temp = lhs; + lhs = rhs; + rhs = temp; } // generic binary_op supplied by type @@ -575,6 +557,20 @@ generic_binary_op: } #endif + if (op == MP_BINARY_OP_CONTAINS) { + // If type didn't support containment then explicitly walk the iterator. + // mp_getiter will raise the appropriate exception if lhs is not iterable. + mp_obj_iter_buf_t iter_buf; + mp_obj_t iter = mp_getiter(lhs, &iter_buf); + mp_obj_t next; + while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { + if (mp_obj_equal(next, rhs)) { + return mp_const_true; + } + } + return mp_const_false; + } + unsupported_op: if (MICROPY_ERROR_REPORTING == MICROPY_ERROR_REPORTING_TERSE) { mp_raise_TypeError("unsupported type for operator"); diff --git a/py/runtime0.h b/py/runtime0.h index a72b7feb7a..960532d176 100644 --- a/py/runtime0.h +++ b/py/runtime0.h @@ -131,6 +131,10 @@ typedef enum { #endif , + // The runtime will convert MP_BINARY_OP_IN to this operator with swapped args. + // A type should implement this containment operator instead of MP_BINARY_OP_IN. + MP_BINARY_OP_CONTAINS, + MP_BINARY_OP_NUM_RUNTIME, // These 2 are not supported by the runtime and must be synthesised by the emitter diff --git a/tests/micropython/viper_error.py.exp b/tests/micropython/viper_error.py.exp index 96be5a5902..a44fb3ff0a 100644 --- a/tests/micropython/viper_error.py.exp +++ b/tests/micropython/viper_error.py.exp @@ -18,7 +18,7 @@ ViperTypeError('must raise an object',) ViperTypeError('unary op __pos__ not implemented',) ViperTypeError('unary op __neg__ not implemented',) ViperTypeError('unary op __invert__ not implemented',) -ViperTypeError('binary op __contains__ not implemented',) +ViperTypeError('binary op not implemented',) NotImplementedError('native yield',) NotImplementedError('native yield from',) NotImplementedError('conversion to object',)