diff --git a/py/bc.h b/py/bc.h index 62151a3d39..fb672ea03b 100644 --- a/py/bc.h +++ b/py/bc.h @@ -1,9 +1,3 @@ -typedef enum { - MP_VM_RETURN_NORMAL, - MP_VM_RETURN_YIELD, - MP_VM_RETURN_EXCEPTION, -} mp_vm_return_kind_t; - // Exception stack entry typedef struct _mp_exc_stack { const byte *handler; diff --git a/py/objgenerator.c b/py/objgenerator.c index d1bae30de1..c3d7f12d70 100644 --- a/py/objgenerator.c +++ b/py/objgenerator.c @@ -135,7 +135,14 @@ STATIC mp_obj_t gen_resume_and_raise(mp_obj_t self_in, mp_obj_t send_value, mp_o return ret; case MP_VM_RETURN_EXCEPTION: - nlr_jump(ret); + // TODO: Optimization of returning MP_OBJ_NULL is really part + // of mp_iternext() protocol, but this function is called by other methods + // too, which may not handled MP_OBJ_NULL. + if (mp_obj_is_subclass_fast(mp_obj_get_type(ret), &mp_type_StopIteration)) { + return MP_OBJ_NULL; + } else { + nlr_jump(ret); + } default: assert(0); diff --git a/py/runtime.c b/py/runtime.c index 4012506627..4898864962 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -17,6 +17,7 @@ #include "builtintables.h" #include "bc.h" #include "intdivmod.h" +#include "objgenerator.h" #if 0 // print debugging info #define DEBUG_PRINT (1) @@ -903,6 +904,62 @@ mp_obj_t mp_iternext(mp_obj_t o_in) { } } +// TODO: Unclear what to do with StopIterarion exception here. +mp_vm_return_kind_t mp_resume(mp_obj_t self_in, mp_obj_t send_value, mp_obj_t throw_value, mp_obj_t *ret_val) { + mp_obj_type_t *type = mp_obj_get_type(self_in); + + if (type == &mp_type_gen_instance) { + return mp_obj_gen_resume(self_in, send_value, throw_value, ret_val); + } + + if (type->iternext != NULL && send_value == mp_const_none) { + mp_obj_t ret = type->iternext(self_in); + if (ret != MP_OBJ_NULL) { + *ret_val = ret; + return MP_VM_RETURN_YIELD; + } else { + // Emulate raise StopIteration() + // Special case, handled in vm.c + *ret_val = MP_OBJ_NULL; + return MP_VM_RETURN_NORMAL; + } + } + + mp_obj_t dest[3]; // Reserve slot for send() arg + + if (send_value == mp_const_none) { + mp_load_method_maybe(self_in, MP_QSTR___next__, dest); + if (dest[0] != MP_OBJ_NULL) { + *ret_val = mp_call_method_n_kw(0, 0, dest); + return MP_VM_RETURN_YIELD; + } + } + + if (send_value != MP_OBJ_NULL) { + mp_load_method(self_in, MP_QSTR_send, dest); + dest[2] = send_value; + *ret_val = mp_call_method_n_kw(1, 0, dest); + return MP_VM_RETURN_YIELD; + } + + if (throw_value != MP_OBJ_NULL) { + if (mp_obj_is_subclass_fast(mp_obj_get_type(throw_value), &mp_type_GeneratorExit)) { + mp_load_method_maybe(self_in, MP_QSTR_close, dest); + if (dest[0] != MP_OBJ_NULL) { + *ret_val = mp_call_method_n_kw(0, 0, dest); + // We assume one can't "yield" from close() + return MP_VM_RETURN_NORMAL; + } + } + mp_load_method(self_in, MP_QSTR_throw, dest); + *ret_val = mp_call_method_n_kw(1, 0, &throw_value); + return MP_VM_RETURN_YIELD; + } + + assert(0); + return MP_VM_RETURN_NORMAL; // Should be unreachable +} + mp_obj_t mp_make_raise_obj(mp_obj_t o) { DEBUG_printf("raise %p\n", o); if (mp_obj_is_exception_type(o)) { diff --git a/py/runtime.h b/py/runtime.h index 8487309a13..f79cb2e306 100644 --- a/py/runtime.h +++ b/py/runtime.h @@ -1,3 +1,9 @@ +typedef enum { + MP_VM_RETURN_NORMAL, + MP_VM_RETURN_YIELD, + MP_VM_RETURN_EXCEPTION, +} mp_vm_return_kind_t; + void mp_init(void); void mp_deinit(void); @@ -55,6 +61,7 @@ void mp_store_subscr(mp_obj_t base, mp_obj_t index, mp_obj_t val); mp_obj_t mp_getiter(mp_obj_t o); mp_obj_t mp_iternext_allow_raise(mp_obj_t o); // may return MP_OBJ_NULL instead of raising StopIteration() mp_obj_t mp_iternext(mp_obj_t o); // will always return MP_OBJ_NULL instead of raising StopIteration(...) +mp_vm_return_kind_t mp_resume(mp_obj_t self_in, mp_obj_t send_value, mp_obj_t throw_value, mp_obj_t *ret_val); mp_obj_t mp_make_raise_obj(mp_obj_t o); diff --git a/py/vm.c b/py/vm.c index 52d9268184..edcad39565 100644 --- a/py/vm.c +++ b/py/vm.c @@ -808,9 +808,9 @@ yield: if (inject_exc != MP_OBJ_NULL) { t_exc = inject_exc; inject_exc = MP_OBJ_NULL; - ret_kind = mp_obj_gen_resume(TOP(), mp_const_none, t_exc, &obj2); + ret_kind = mp_resume(TOP(), mp_const_none, t_exc, &obj2); } else { - ret_kind = mp_obj_gen_resume(TOP(), obj1, MP_OBJ_NULL, &obj2); + ret_kind = mp_resume(TOP(), obj1, MP_OBJ_NULL, &obj2); } if (ret_kind == MP_VM_RETURN_YIELD) { diff --git a/tests/basics/gen-yield-from-ducktype.py b/tests/basics/gen-yield-from-ducktype.py new file mode 100644 index 0000000000..aa0109c914 --- /dev/null +++ b/tests/basics/gen-yield-from-ducktype.py @@ -0,0 +1,44 @@ +class MyGen: + + def __init__(self): + self.v = 0 + + def __iter__(self): + return self + + def __next__(self): + self.v += 1 + if self.v > 5: + raise StopIteration + return self.v + +def gen(): + yield from MyGen() + +def gen2(): + yield from gen() + +print(list(gen())) +print(list(gen2())) + + +class Incrementer: + + def __iter__(self): + return self + + def __next__(self): + return self.send(None) + + def send(self, val): + if val is None: + return "Incrementer initialized" + return val + 1 + +def gen3(): + yield from Incrementer() + +g = gen3() +print(next(g)) +print(g.send(5)) +print(g.send(100)) diff --git a/tests/basics/gen-yield-from-iter.py b/tests/basics/gen-yield-from-iter.py new file mode 100644 index 0000000000..2d06328fbb --- /dev/null +++ b/tests/basics/gen-yield-from-iter.py @@ -0,0 +1,8 @@ +def gen(): + yield from (1, 2, 3) + +def gen2(): + yield from gen() + +print(list(gen())) +print(list(gen2()))