diff --git a/py/mpqstrraw.h b/py/mpqstrraw.h index c3cda84b4d..8d681c9a49 100644 --- a/py/mpqstrraw.h +++ b/py/mpqstrraw.h @@ -41,6 +41,7 @@ Q(chr) Q(complex) Q(dict) Q(divmod) +Q(enumerate) Q(float) Q(hash) Q(int) diff --git a/py/obj.h b/py/obj.h index dbf9efe20e..6a6f98d0cd 100644 --- a/py/obj.h +++ b/py/obj.h @@ -294,6 +294,9 @@ void mp_obj_list_get(mp_obj_t self_in, uint *len, mp_obj_t **items); void mp_obj_list_store(mp_obj_t self_in, mp_obj_t index, mp_obj_t value); mp_obj_t list_sort(mp_obj_t args, struct _mp_map_t *kwargs); +// enumerate +extern const mp_obj_type_t enumerate_type; + // dict extern const mp_obj_type_t dict_type; uint mp_obj_dict_len(mp_obj_t self_in); diff --git a/py/objenumerate.c b/py/objenumerate.c new file mode 100644 index 0000000000..5bfd8a3370 --- /dev/null +++ b/py/objenumerate.c @@ -0,0 +1,53 @@ +#include +#include + +#include "misc.h" +#include "mpconfig.h" +#include "obj.h" +#include "runtime.h" + +typedef struct _mp_obj_enumerate_t { + mp_obj_base_t base; + mp_obj_t iter; + machine_int_t cur; +} mp_obj_enumerate_t; + +static mp_obj_t enumerate_getiter(mp_obj_t self_in) { + return self_in; +} + +static mp_obj_t enumerate_iternext(mp_obj_t self_in); + +/* TODO: enumerate is one of the ones that can take args or kwargs. + Sticking to args for now */ +static mp_obj_t enumerate_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args) { + /* NOTE: args are backwards */ + assert(n_args > 0); + args += n_args - 1; + mp_obj_enumerate_t *o = m_new_obj(mp_obj_enumerate_t); + o->base.type = &enumerate_type; + o->iter = rt_getiter(args[0]); + o->cur = n_args > 1 ? mp_obj_get_int(args[-1]) : 0; + + return o; +} + +const mp_obj_type_t enumerate_type = { + { &mp_const_type }, + "enumerate", + .make_new = enumerate_make_new, + .iternext = enumerate_iternext, + .getiter = enumerate_getiter, +}; + +static mp_obj_t enumerate_iternext(mp_obj_t self_in) { + assert(MP_OBJ_IS_TYPE(self_in, &enumerate_type)); + mp_obj_enumerate_t *self = self_in; + mp_obj_t next = rt_iternext(self->iter); + if (next == mp_const_stop_iteration) { + return mp_const_stop_iteration; + } else { + mp_obj_t items[] = {MP_OBJ_NEW_SMALL_INT(self->cur++), next}; + return mp_obj_new_tuple(2, items); + } +} diff --git a/py/py.mk b/py/py.mk index 95f9c07671..275c92d126 100644 --- a/py/py.mk +++ b/py/py.mk @@ -77,6 +77,7 @@ PY_O_BASENAME = \ objclosure.o \ objcomplex.o \ objdict.o \ + objenumerate.o \ objexcept.o \ objfloat.o \ objfun.o \ diff --git a/py/runtime.c b/py/runtime.c index dc3d15657c..6047b90b40 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -105,6 +105,7 @@ void rt_init(void) { mp_map_add_qstr(&map_builtins, MP_QSTR_complex, (mp_obj_t)&complex_type); #endif mp_map_add_qstr(&map_builtins, MP_QSTR_dict, (mp_obj_t)&dict_type); + mp_map_add_qstr(&map_builtins, MP_QSTR_enumerate, (mp_obj_t)&enumerate_type); #if MICROPY_ENABLE_FLOAT mp_map_add_qstr(&map_builtins, MP_QSTR_float, (mp_obj_t)&float_type); #endif diff --git a/tests/basics/tests/enumerate.py b/tests/basics/tests/enumerate.py new file mode 100644 index 0000000000..f2bdf4f326 --- /dev/null +++ b/tests/basics/tests/enumerate.py @@ -0,0 +1,6 @@ +print(list(enumerate([]))) +print(list(enumerate([1, 2, 3]))) +print(list(enumerate([1, 2, 3], 5))) +print(list(enumerate([1, 2, 3], -5))) +print(list(enumerate(range(10000)))) +