diff --git a/chat.py b/chat.py index 00b40bb..7100c75 100644 --- a/chat.py +++ b/chat.py @@ -40,19 +40,18 @@ def main_facet(turn, root_facet, ds): turn.publish(ds, dataspace.Observe(P.rec('Says', P.CAPTURE, P.CAPTURE), During(turn, on_msg = on_says).ref)) - loop = asyncio.get_running_loop() async def accept_input(): reader = asyncio.StreamReader() - print(await loop.connect_read_pipe(lambda: asyncio.StreamReaderProtocol(reader), sys.stdin)) + await actor.find_loop().connect_read_pipe( + lambda: asyncio.StreamReaderProtocol(reader), + sys.stdin) while True: - line = await reader.readline() - line = line.decode('utf-8') + line = (await reader.readline()).decode('utf-8') if not line: - actor.Turn.external(loop, f, lambda turn: turn.stop(root_facet)) + actor.Turn.external(f, lambda turn: turn.stop(root_facet)) break - actor.Turn.external(loop, f, lambda turn: turn.send(ds, Says(me, line.strip()))) - input_task = loop.create_task(accept_input()) - turn._facet.on_stop(lambda turn: input_task.cancel()) + actor.Turn.external(f, lambda turn: turn.send(ds, Says(me, line.strip()))) + turn.linked_task(accept_input()) def main(turn): root_facet = turn._facet diff --git a/syndicate/actor.py b/syndicate/actor.py index 4cc89a3..4452078 100644 --- a/syndicate/actor.py +++ b/syndicate/actor.py @@ -36,6 +36,12 @@ def adjust_engine_inhabitant_count(delta): log.debug('Inhabitant count reached zero') loop.stop() +def remove_noerror(collection, item): + try: + collection.remove(item) + except ValueError: + pass + class Actor: def __init__(self, boot_proc, name = None, initial_assertions = {}, daemon = False): self.name = name or 'a' + str(next(_next_actor_number)) @@ -75,6 +81,9 @@ class Actor: def at_exit(self, hook): self.exit_hooks.append(hook) + def cancel_at_exit(self, hook): + remove_noerror(self.exit_hooks, hook) + def terminate(self, turn, exit_reason): if self.exit_reason is not None: return self.log.debug('Terminating %r with exit_reason %r', self, exit_reason) @@ -124,6 +133,9 @@ class Facet: def on_stop(self, a): self.shutdown_actions.append(a) + def cancel_on_stop(self, a): + remove_noerror(self.shutdown_actions, a) + def isinert(self): return len(self.children) == 0 and len(self.outbound) == 0 and self.inert_check_preventers == 0 @@ -186,15 +198,18 @@ async def ensure_awaitable(value): else: return value -def queue_task(thunk, loop = asyncio): - async def task(): - await ensure_awaitable(thunk()) - return loop.create_task(task()) +def find_loop(loop = None): + return asyncio.get_running_loop() if loop is None else loop -def queue_task_threadsafe(thunk, loop): +def queue_task(thunk, loop = None): async def task(): await ensure_awaitable(thunk()) - return asyncio.run_coroutine_threadsafe(task(), loop) + return find_loop(loop).create_task(task()) + +def queue_task_threadsafe(thunk, loop = None): + async def task(): + await ensure_awaitable(thunk()) + return asyncio.run_coroutine_threadsafe(task(), find_loop(loop)) class Turn: @classmethod @@ -213,7 +228,7 @@ class Turn: turn._deliver() @classmethod - def external(cls, loop, facet, action): + def external(cls, facet, action, loop = None): return queue_task_threadsafe(lambda: cls.run(facet, action), loop) def __init__(self, facet): @@ -236,6 +251,24 @@ class Turn: def prevent_inert_check(self): return self._facet.prevent_inert_check() + def linked_task(self, coro, loop = None): + task = None + def cancel_linked_task(turn): + nonlocal task + if task is not None: + task.cancel() + task = None + self._facet.cancel_on_stop(cancel_linked_task) + self._facet.actor.cancel_at_exit(cancel_linked_task) + async def guarded_task(): + try: + await coro + finally: + Turn.external(self._facet, cancel_linked_task) + task = find_loop(loop).create_task(guarded_task()) + self._facet.on_stop(cancel_linked_task) + self._facet.actor.at_exit(cancel_linked_task) + def stop(self, facet = None, continuation = None): if facet is None: facet = self._facet