From d4b75f6b6822885e331c69a74e56e23af40a6264 Mon Sep 17 00:00:00 2001 From: Damien George Date: Mon, 4 Sep 2017 14:16:27 +1000 Subject: [PATCH] py/obj: Fix comparison of float/complex NaN with itself. IEEE floating point is specified such that a comparison of NaN with itself returns false, and Python respects these semantics. This patch makes uPy also have these semantics. The fix has a minor impact on the speed of the object-equality fast-path, but that seems to be unavoidable and it's much more important to have correct behaviour (especially in this case where the wrong answer for nan==nan is silently returned). --- py/obj.c | 11 ++++++++++- tests/float/complex1.py | 5 +++++ tests/float/float1.py | 5 +++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/py/obj.c b/py/obj.c index 90ce47e8fb..857fe373f2 100644 --- a/py/obj.c +++ b/py/obj.c @@ -162,7 +162,16 @@ bool mp_obj_is_callable(mp_obj_t o_in) { // comparison returns NotImplemented, == and != are decided by comparing the object // pointer." bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) { - if (o1 == o2) { + // Float (and complex) NaN is never equal to anything, not even itself, + // so we must have a special check here to cover those cases. + if (o1 == o2 + #if MICROPY_PY_BUILTINS_FLOAT + && !mp_obj_is_float(o1) + #endif + #if MICROPY_PY_BUILTINS_COMPLEX + && !MP_OBJ_IS_TYPE(o1, &mp_type_complex) + #endif + ) { return true; } if (o1 == mp_const_none || o2 == mp_const_none) { diff --git a/tests/float/complex1.py b/tests/float/complex1.py index a6038de04a..7f0b317b35 100644 --- a/tests/float/complex1.py +++ b/tests/float/complex1.py @@ -37,6 +37,11 @@ ans = 1j ** 2.5j; print("%.5g %.5g" % (ans.real, ans.imag)) print(1j == 1) print(1j == 1j) +# comparison of nan is special +nan = float('nan') * 1j +print(nan == 1j) +print(nan == nan) + # builtin abs print(abs(1j)) print("%.5g" % abs(1j + 2)) diff --git a/tests/float/float1.py b/tests/float/float1.py index 93f6f014c4..137dacc233 100644 --- a/tests/float/float1.py +++ b/tests/float/float1.py @@ -60,6 +60,11 @@ print(1.2 <= -3.4) print(1.2 >= 3.4) print(1.2 >= -3.4) +# comparison of nan is special +nan = float('nan') +print(nan == 1.2) +print(nan == nan) + try: 1.0 / 0 except ZeroDivisionError: