From 25e140652b946e9a136a16a9831c0b18f1285702 Mon Sep 17 00:00:00 2001 From: Damien George Date: Wed, 4 Oct 2017 21:00:05 +1100 Subject: [PATCH] py/modmath: Add full checks for math domain errors. This patch changes how most of the plain math functions are implemented: there are now two generic math wrapper functions that take a pointer to a math function (like sin, cos) and perform the necessary conversion to and from MicroPython types. This helps to reduce code size. The generic functions can also check for math domain errors in a generic way, by testing if the result is NaN or infinity combined with finite inputs. The result is that, with this patch, all math functions now have full domain error checking (even gamma and lgamma) and code size has decreased for most ports. Code size changes in bytes for those with the math module are: unix x64: -432 unix nanbox: -792 stm32: -88 esp8266: +12 Tests are also added to check domain errors are handled correctly. --- py/modmath.c | 67 ++++++++++++++++++++---------- tests/float/math_domain.py | 51 +++++++++++++++++++++++ tests/float/math_domain_special.py | 36 ++++++++++++++++ 3 files changed, 133 insertions(+), 21 deletions(-) create mode 100644 tests/float/math_domain.py create mode 100644 tests/float/math_domain_special.py diff --git a/py/modmath.c b/py/modmath.c index 4d627c67f4..6c248844df 100644 --- a/py/modmath.c +++ b/py/modmath.c @@ -3,7 +3,7 @@ * * The MIT License (MIT) * - * Copyright (c) 2013, 2014 Damien P. George + * Copyright (c) 2013-2017 Damien P. George * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -39,13 +39,30 @@ STATIC NORETURN void math_error(void) { mp_raise_ValueError("math domain error"); } -#define MATH_FUN_1(py_name, c_name) \ - STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \ - STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); +STATIC mp_obj_t math_generic_1(mp_obj_t x_obj, mp_float_t (*f)(mp_float_t)) { + mp_float_t x = mp_obj_get_float(x_obj); + mp_float_t ans = f(x); + if ((isnan(ans) && !isnan(x)) || (isinf(ans) && !isinf(x))) { + math_error(); + } + return mp_obj_new_float(ans); +} -#define MATH_FUN_2(py_name, c_name) \ - STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj), mp_obj_get_float(y_obj))); } \ - STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name); +STATIC mp_obj_t math_generic_2(mp_obj_t x_obj, mp_obj_t y_obj, mp_float_t (*f)(mp_float_t, mp_float_t)) { + mp_float_t x = mp_obj_get_float(x_obj); + mp_float_t y = mp_obj_get_float(y_obj); + mp_float_t ans = f(x, y); + if ((isnan(ans) && !isnan(x) && !isnan(y)) || (isinf(ans) && !isinf(x))) { + math_error(); + } + return mp_obj_new_float(ans); +} + +#define MATH_FUN_1(py_name, c_name) \ + STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { \ + return math_generic_1(x_obj, MICROPY_FLOAT_C_FUN(c_name)); \ + } \ + STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); #define MATH_FUN_1_TO_BOOL(py_name, c_name) \ STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_bool(c_name(mp_obj_get_float(x_obj))); } \ @@ -55,15 +72,17 @@ STATIC NORETURN void math_error(void) { STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_int_from_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \ STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); -#define MATH_FUN_1_ERRCOND(py_name, c_name, error_condition) \ - STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { \ - mp_float_t x = mp_obj_get_float(x_obj); \ - if (error_condition) { \ - math_error(); \ - } \ - return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(x)); \ +#define MATH_FUN_2(py_name, c_name) \ + STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { \ + return math_generic_2(x_obj, y_obj, MICROPY_FLOAT_C_FUN(c_name)); \ } \ - STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); + STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name); + +#define MATH_FUN_2_FLT_INT(py_name, c_name) \ + STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { \ + return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj), mp_obj_get_int(y_obj))); \ + } \ + STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name); #if MP_NEED_LOG2 // 1.442695040888963407354163704 is 1/_M_LN2 @@ -71,7 +90,7 @@ STATIC NORETURN void math_error(void) { #endif // sqrt(x): returns the square root of x -MATH_FUN_1_ERRCOND(sqrt, sqrt, (x < (mp_float_t)0.0)) +MATH_FUN_1(sqrt, sqrt) // pow(x, y): returns x to the power of y MATH_FUN_2(pow, pow) // exp(x) @@ -80,9 +99,9 @@ MATH_FUN_1(exp, exp) // expm1(x) MATH_FUN_1(expm1, expm1) // log2(x) -MATH_FUN_1_ERRCOND(log2, log2, (x <= (mp_float_t)0.0)) +MATH_FUN_1(log2, log2) // log10(x) -MATH_FUN_1_ERRCOND(log10, log10, (x <= (mp_float_t)0.0)) +MATH_FUN_1(log10, log10) // cosh(x) MATH_FUN_1(cosh, cosh) // sinh(x) @@ -113,9 +132,15 @@ MATH_FUN_2(atan2, atan2) // ceil(x) MATH_FUN_1_TO_INT(ceil, ceil) // copysign(x, y) -MATH_FUN_2(copysign, copysign) +STATIC mp_float_t MICROPY_FLOAT_C_FUN(copysign_func)(mp_float_t x, mp_float_t y) { + return MICROPY_FLOAT_C_FUN(copysign)(x, y); +} +MATH_FUN_2(copysign, copysign_func) // fabs(x) -MATH_FUN_1(fabs, fabs) +STATIC mp_float_t MICROPY_FLOAT_C_FUN(fabs_func)(mp_float_t x) { + return MICROPY_FLOAT_C_FUN(fabs)(x); +} +MATH_FUN_1(fabs, fabs_func) // floor(x) MATH_FUN_1_TO_INT(floor, floor) //TODO: delegate to x.__floor__() if x is not a float // fmod(x, y) @@ -129,7 +154,7 @@ MATH_FUN_1_TO_BOOL(isnan, isnan) // trunc(x) MATH_FUN_1_TO_INT(trunc, trunc) // ldexp(x, exp) -MATH_FUN_2(ldexp, ldexp) +MATH_FUN_2_FLT_INT(ldexp, ldexp) #if MICROPY_PY_MATH_SPECIAL_FUNCTIONS // erf(x): return the error function of x MATH_FUN_1(erf, erf) diff --git a/tests/float/math_domain.py b/tests/float/math_domain.py new file mode 100644 index 0000000000..0cf10fb2ad --- /dev/null +++ b/tests/float/math_domain.py @@ -0,0 +1,51 @@ +# Tests domain errors in math functions + +try: + import math +except ImportError: + print("SKIP") + raise SystemExit + +inf = float('inf') +nan = float('nan') + +# single argument functions +for name, f, args in ( + ('fabs', math.fabs, ()), + ('ceil', math.ceil, ()), + ('floor', math.floor, ()), + ('trunc', math.trunc, ()), + ('sqrt', math.sqrt, (-1, 0)), + ('exp', math.exp, ()), + ('sin', math.sin, ()), + ('cos', math.cos, ()), + ('tan', math.tan, ()), + ('asin', math.asin, (-1.1, 1, 1.1)), + ('acos', math.acos, (-1.1, 1, 1.1)), + ('atan', math.atan, ()), + ('ldexp', lambda x: math.ldexp(x, 0), ()), + ('radians', math.radians, ()), + ('degrees', math.degrees, ()), + ): + for x in args + (inf, nan): + try: + ans = f(x) + print('%.4f' % ans) + except ValueError: + print(name, 'ValueError') + except OverflowError: + print(name, 'OverflowError') + +# double argument functions +for name, f, args in ( + ('pow', math.pow, ((0, 2), (-1, 2), (0, -1), (-1, 2.3))), + ('fmod', math.fmod, ((1.2, inf), (1.2, 0), (inf, 1.2))), + ('atan2', math.atan2, ((0, 0),)), + ('copysign', math.copysign, ()), + ): + for x in args + ((0, inf), (inf, 0), (inf, inf), (inf, nan), (nan, inf), (nan, nan)): + try: + ans = f(*x) + print('%.4f' % ans) + except ValueError: + print(name, 'ValueError') diff --git a/tests/float/math_domain_special.py b/tests/float/math_domain_special.py new file mode 100644 index 0000000000..388920350e --- /dev/null +++ b/tests/float/math_domain_special.py @@ -0,0 +1,36 @@ +# Tests domain errors in special math functions + +try: + import math + math.erf +except (ImportError, AttributeError): + print("SKIP") + raise SystemExit + +inf = float('inf') +nan = float('nan') + +# single argument functions +for name, f, args in ( + ('expm1', math.exp, ()), + ('log2', math.log2, (-1, 0)), + ('log10', math.log10, (-1, 0)), + ('sinh', math.sinh, ()), + ('cosh', math.cosh, ()), + ('tanh', math.tanh, ()), + ('asinh', math.asinh, ()), + ('acosh', math.acosh, (-1, 0.9, 1)), + ('atanh', math.atanh, (-1, 1)), + ('erf', math.erf, ()), + ('erfc', math.erfc, ()), + ('gamma', math.gamma, (-2, -1, 0, 1)), + ('lgamma', math.lgamma, (-2, -1, 0, 1)), + ): + for x in args + (inf, nan): + try: + ans = f(x) + print('%.4f' % ans) + except ValueError: + print(name, 'ValueError') + except OverflowError: + print(name, 'OverflowError')