diff --git a/extmod/moduasyncio.c b/extmod/moduasyncio.c index dd2d1e7475..b0921b6eb1 100644 --- a/extmod/moduasyncio.c +++ b/extmod/moduasyncio.c @@ -265,6 +265,9 @@ STATIC mp_obj_t task_getiter(mp_obj_t self_in, mp_obj_iter_buf_t *iter_buf) { } else if (self->state == TASK_STATE_RUNNING_NOT_WAITED_ON) { // Allocate the waiting queue. self->state = task_queue_make_new(&task_queue_type, 0, 0, NULL); + } else if (mp_obj_get_type(self->state) != &task_queue_type) { + // Task has state used for another purpose, so can't also wait on it. + mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("can't wait")); } return self_in; } diff --git a/extmod/uasyncio/core.py b/extmod/uasyncio/core.py index 12833cf0cd..28b5e960ac 100644 --- a/extmod/uasyncio/core.py +++ b/extmod/uasyncio/core.py @@ -195,6 +195,11 @@ def run_until_complete(main_task=None): if t.state is True: # "None" indicates that the task is complete and not await'ed on (yet). t.state = None + elif callable(t.state): + # The task has a callback registered to be called on completion. + t.state(t, er) + t.state = False + waiting = True else: # Schedule any other tasks waiting on the completion of this task. while t.state.peek(): diff --git a/extmod/uasyncio/task.py b/extmod/uasyncio/task.py index 26df7b1725..d775164909 100644 --- a/extmod/uasyncio/task.py +++ b/extmod/uasyncio/task.py @@ -123,7 +123,7 @@ class Task: def __init__(self, coro, globals=None): self.coro = coro # Coroutine of this Task self.data = None # General data for queue it is waiting on - self.state = True # None, False, True or a TaskQueue instance + self.state = True # None, False, True, a callable, or a TaskQueue instance self.ph_key = 0 # Pairing heap self.ph_child = None # Paring heap self.ph_child_last = None # Paring heap @@ -137,6 +137,9 @@ class Task: elif self.state is True: # Allocated head of linked list of Tasks waiting on completion of this task. self.state = TaskQueue() + elif type(self.state) is not TaskQueue: + # Task has state used for another purpose, so can't also wait on it. + raise RuntimeError("can't wait") return self def __next__(self):