From 7310fd469a4dfdd38fda242f138f5671c2f82b61 Mon Sep 17 00:00:00 2001 From: Damien George Date: Sun, 24 Aug 2014 19:14:09 +0100 Subject: [PATCH] py: Consolidate min/max functions into one, and add key= argument. Addresses issue #811. --- py/builtin.c | 67 ++++++++++++++-------------------- tests/basics/builtin_minmax.py | 10 +++++ 2 files changed, 37 insertions(+), 40 deletions(-) diff --git a/py/builtin.c b/py/builtin.c index 1924e6080f..4fea1fdb21 100644 --- a/py/builtin.c +++ b/py/builtin.c @@ -284,63 +284,50 @@ STATIC mp_obj_t mp_builtin_iter(mp_obj_t o_in) { MP_DEFINE_CONST_FUN_OBJ_1(mp_builtin_iter_obj, mp_builtin_iter); -STATIC mp_obj_t mp_builtin_max(uint n_args, const mp_obj_t *args) { +STATIC mp_obj_t mp_builtin_min_max(uint n_args, const mp_obj_t *args, mp_map_t *kwargs, int op) { + mp_map_elem_t *key_elem = mp_map_lookup(kwargs, MP_OBJ_NEW_QSTR(MP_QSTR_key), MP_MAP_LOOKUP); + mp_obj_t key_fn = key_elem == NULL ? MP_OBJ_NULL : key_elem->value; if (n_args == 1) { // given an iterable mp_obj_t iterable = mp_getiter(args[0]); - mp_obj_t max_obj = NULL; + mp_obj_t best_key = MP_OBJ_NULL; + mp_obj_t best_obj = MP_OBJ_NULL; mp_obj_t item; while ((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) { - if (max_obj == NULL || (mp_binary_op(MP_BINARY_OP_LESS, max_obj, item) == mp_const_true)) { - max_obj = item; + mp_obj_t key = key_fn == MP_OBJ_NULL ? item : mp_call_function_1(key_fn, item); + if (best_obj == MP_OBJ_NULL || (mp_binary_op(op, key, best_key) == mp_const_true)) { + best_key = key; + best_obj = item; } } - if (max_obj == NULL) { - nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "max() arg is an empty sequence")); + if (best_obj == MP_OBJ_NULL) { + nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "arg is an empty sequence")); } - return max_obj; + return best_obj; } else { // given many args - mp_obj_t max_obj = args[0]; - for (int i = 1; i < n_args; i++) { - if (mp_binary_op(MP_BINARY_OP_LESS, max_obj, args[i]) == mp_const_true) { - max_obj = args[i]; + mp_obj_t best_key = MP_OBJ_NULL; + mp_obj_t best_obj = MP_OBJ_NULL; + for (mp_uint_t i = 0; i < n_args; i++) { + mp_obj_t key = key_fn == MP_OBJ_NULL ? args[i] : mp_call_function_1(key_fn, args[i]); + if (best_obj == MP_OBJ_NULL || (mp_binary_op(op, key, best_key) == mp_const_true)) { + best_key = key; + best_obj = args[i]; } } - return max_obj; + return best_obj; } } -MP_DEFINE_CONST_FUN_OBJ_VAR(mp_builtin_max_obj, 1, mp_builtin_max); - -STATIC mp_obj_t mp_builtin_min(uint n_args, const mp_obj_t *args) { - if (n_args == 1) { - // given an iterable - mp_obj_t iterable = mp_getiter(args[0]); - mp_obj_t min_obj = NULL; - mp_obj_t item; - while ((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) { - if (min_obj == NULL || (mp_binary_op(MP_BINARY_OP_LESS, item, min_obj) == mp_const_true)) { - min_obj = item; - } - } - if (min_obj == NULL) { - nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "min() arg is an empty sequence")); - } - return min_obj; - } else { - // given many args - mp_obj_t min_obj = args[0]; - for (int i = 1; i < n_args; i++) { - if (mp_binary_op(MP_BINARY_OP_LESS, args[i], min_obj) == mp_const_true) { - min_obj = args[i]; - } - } - return min_obj; - } +STATIC mp_obj_t mp_builtin_max(uint n_args, const mp_obj_t *args, mp_map_t *kwargs) { + return mp_builtin_min_max(n_args, args, kwargs, MP_BINARY_OP_MORE); } +MP_DEFINE_CONST_FUN_OBJ_KW(mp_builtin_max_obj, 1, mp_builtin_max); -MP_DEFINE_CONST_FUN_OBJ_VAR(mp_builtin_min_obj, 1, mp_builtin_min); +STATIC mp_obj_t mp_builtin_min(uint n_args, const mp_obj_t *args, mp_map_t *kwargs) { + return mp_builtin_min_max(n_args, args, kwargs, MP_BINARY_OP_LESS); +} +MP_DEFINE_CONST_FUN_OBJ_KW(mp_builtin_min_obj, 1, mp_builtin_min); STATIC mp_obj_t mp_builtin_next(mp_obj_t o) { mp_obj_t ret = mp_iternext_allow_raise(o); diff --git a/tests/basics/builtin_minmax.py b/tests/basics/builtin_minmax.py index 8ee4bbca7d..a5f035b909 100644 --- a/tests/basics/builtin_minmax.py +++ b/tests/basics/builtin_minmax.py @@ -13,3 +13,13 @@ print(max(-1,0)) print(min([1,2,4,0,-1,2])) print(max([1,2,4,0,-1,2])) +# test with key function +lst = [2, 1, 3, 4] +print(min(lst, key=lambda x:x)) +print(min(lst, key=lambda x:-x)) +print(min(1, 2, 3, 4, key=lambda x:-x)) +print(min(4, 3, 2, 1, key=lambda x:-x)) +print(max(lst, key=lambda x:x)) +print(max(lst, key=lambda x:-x)) +print(max(1, 2, 3, 4, key=lambda x:-x)) +print(max(4, 3, 2, 1, key=lambda x:-x))