diff --git a/py/obj.c b/py/obj.c index 07b1612552..ed047acc39 100644 --- a/py/obj.c +++ b/py/obj.c @@ -371,7 +371,7 @@ mp_float_t mp_obj_get_float(mp_obj_t arg) { } #if MICROPY_PY_BUILTINS_COMPLEX -void mp_obj_get_complex(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) { +bool mp_obj_get_complex_maybe(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) { if (arg == mp_const_false) { *real = 0; *imag = 0; @@ -392,6 +392,13 @@ void mp_obj_get_complex(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) { } else if (mp_obj_is_type(arg, &mp_type_complex)) { mp_obj_complex_get(arg, real, imag); } else { + return false; + } + return true; +} + +void mp_obj_get_complex(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) { + if (!mp_obj_get_complex_maybe(arg, real, imag)) { #if MICROPY_ERROR_REPORTING == MICROPY_ERROR_REPORTING_TERSE mp_raise_TypeError(MP_ERROR_TEXT("can't convert to complex")); #else diff --git a/py/obj.h b/py/obj.h index 590b9c4b6a..1fa24eb18c 100644 --- a/py/obj.h +++ b/py/obj.h @@ -778,6 +778,7 @@ bool mp_obj_get_int_maybe(mp_const_obj_t arg, mp_int_t *value); mp_float_t mp_obj_get_float(mp_obj_t self_in); bool mp_obj_get_float_maybe(mp_obj_t arg, mp_float_t *value); void mp_obj_get_complex(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag); +bool mp_obj_get_complex_maybe(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag); #endif void mp_obj_get_array(mp_obj_t o, size_t *len, mp_obj_t **items); // *items may point inside a GC block void mp_obj_get_array_fixed_n(mp_obj_t o, size_t len, mp_obj_t **items); // *items may point inside a GC block diff --git a/py/objcomplex.c b/py/objcomplex.c index 91e4402309..f4c4aeffcb 100644 --- a/py/objcomplex.c +++ b/py/objcomplex.c @@ -178,7 +178,10 @@ void mp_obj_complex_get(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag) { mp_obj_t mp_obj_complex_binary_op(mp_binary_op_t op, mp_float_t lhs_real, mp_float_t lhs_imag, mp_obj_t rhs_in) { mp_float_t rhs_real, rhs_imag; - mp_obj_get_complex(rhs_in, &rhs_real, &rhs_imag); // can be any type, this function will convert to float (if possible) + if (!mp_obj_get_complex_maybe(rhs_in, &rhs_real, &rhs_imag)) { + return MP_OBJ_NULL; // op not supported + } + switch (op) { case MP_BINARY_OP_ADD: case MP_BINARY_OP_INPLACE_ADD: diff --git a/tests/float/cmath_fun.py b/tests/float/cmath_fun.py index 7b5e692452..15b72e7a62 100644 --- a/tests/float/cmath_fun.py +++ b/tests/float/cmath_fun.py @@ -57,3 +57,9 @@ for f_name, f, test_vals in functions: if abs(real) < 1e-6: real = 0.0 print("complex(%.5g, %.5g)" % (real, ret.imag)) + +# test invalid type passed to cmath function +try: + log([]) +except TypeError: + print("TypeError") diff --git a/tests/float/complex_special_mehods.py b/tests/float/complex_special_mehods.py new file mode 100644 index 0000000000..6789013fa6 --- /dev/null +++ b/tests/float/complex_special_mehods.py @@ -0,0 +1,15 @@ +# test complex interacting with special methods + + +class A: + def __add__(self, x): + print("__add__") + return 1 + + def __radd__(self, x): + print("__radd__") + return 2 + + +print(A() + 1j) +print(1j + A()) diff --git a/tests/run-tests b/tests/run-tests index f9e4de4b34..102b0f7790 100755 --- a/tests/run-tests +++ b/tests/run-tests @@ -355,6 +355,7 @@ def run_tests(pyb, tests, args, base_path="."): if not has_complex: skip_tests.add('float/complex1.py') skip_tests.add('float/complex1_intbig.py') + skip_tests.add('float/complex_special_mehods.py') skip_tests.add('float/int_big_float.py') skip_tests.add('float/true_value.py') skip_tests.add('float/types.py')