Introduce actor System to keep track of outstanding tasks

This commit is contained in:
Tony Garnock-Jones 2023-02-12 22:02:08 +01:00
parent b957490d78
commit 17f9833708
3 changed files with 86 additions and 75 deletions

View File

@ -5,7 +5,7 @@ except ImportError:
setup( setup(
name="syndicate-py", name="syndicate-py",
version="0.11.2", version="0.12.0",
author="Tony Garnock-Jones", author="Tony Garnock-Jones",
author_email="tonyg@leastfixedpoint.com", author_email="tonyg@leastfixedpoint.com",
license="GNU General Public License v3 or later (GPLv3+)", license="GNU General Public License v3 or later (GPLv3+)",

View File

@ -22,29 +22,52 @@ _active.turn = None
# decorator # decorator
def run_system(**kwargs): def run_system(**kwargs):
return lambda boot_proc: start_actor_system(boot_proc, **kwargs) return lambda boot_proc: System().run(boot_proc, **kwargs)
def start_actor_system(boot_proc, debug = False, name = None, configure_logging = True): class System:
if configure_logging: def __init__(self, loop = None):
logging.basicConfig(level = logging.DEBUG if debug else logging.INFO) self.tasks = set()
loop = asyncio.get_event_loop() self.loop = loop or asyncio.get_event_loop()
if debug: self.inhabitant_count = 0
loop.set_debug(True)
queue_task(lambda: Actor(boot_proc, name = name), loop = loop)
loop.run_forever()
while asyncio.all_tasks(loop):
loop.stop()
loop.run_forever()
loop.close()
def adjust_engine_inhabitant_count(delta): def run(self, boot_proc, debug = False, name = None, configure_logging = True):
loop = asyncio.get_running_loop() if configure_logging:
if not hasattr(loop, '__syndicate_inhabitant_count'): logging.basicConfig(level = logging.DEBUG if debug else logging.INFO)
loop.__syndicate_inhabitant_count = 0 if debug:
loop.__syndicate_inhabitant_count = loop.__syndicate_inhabitant_count + delta self.loop.set_debug(True)
if loop.__syndicate_inhabitant_count == 0: self.queue_task(lambda: Actor(boot_proc, system = self, name = name))
log.debug('Inhabitant count reached zero') self.loop.run_forever()
loop.stop() while asyncio.all_tasks(self.loop):
self.loop.stop()
self.loop.run_forever()
self.loop.close()
def adjust_engine_inhabitant_count(self, delta):
self.inhabitant_count = self.inhabitant_count + delta
if self.inhabitant_count == 0:
log.debug('Inhabitant count reached zero')
self.loop.stop()
def queue_task(self, thunk):
async def task():
try:
await ensure_awaitable(thunk())
finally:
self.tasks.remove(t)
t = self.loop.create_task(task())
self.tasks.add(t)
return t
def queue_task_threadsafe(self, thunk):
async def task():
await ensure_awaitable(thunk())
return asyncio.run_coroutine_threadsafe(task(), self.loop)
async def ensure_awaitable(value):
if inspect.isawaitable(value):
return await value
else:
return value
def remove_noerror(collection, item): def remove_noerror(collection, item):
try: try:
@ -53,11 +76,12 @@ def remove_noerror(collection, item):
pass pass
class Actor: class Actor:
def __init__(self, boot_proc, name = None, initial_assertions = {}, daemon = False): def __init__(self, boot_proc, system, name = None, initial_assertions = {}, daemon = False):
self.name = name or 'a' + str(next(_next_actor_number)) self.name = name or 'a' + str(next(_next_actor_number))
self._system = system
self._daemon = daemon self._daemon = daemon
if not daemon: if not daemon:
adjust_engine_inhabitant_count(1) system.adjust_engine_inhabitant_count(1)
self.root = Facet(self, None) self.root = Facet(self, None)
self.outbound = initial_assertions or {} self.outbound = initial_assertions or {}
self.exit_reason = None # None -> running, True -> terminated OK, exn -> error self.exit_reason = None # None -> running, True -> terminated OK, exn -> error
@ -77,7 +101,7 @@ class Actor:
def daemon(self, value): def daemon(self, value):
if self._daemon != value: if self._daemon != value:
self._daemon = value self._daemon = value
adjust_engine_inhabitant_count(-1 if value else 1) self._system.adjust_engine_inhabitant_count(-1 if value else 1)
@property @property
def alive(self): def alive(self):
@ -115,7 +139,7 @@ class Actor:
h() h()
self.root._terminate(exit_reason == True) self.root._terminate(exit_reason == True)
if not self._daemon: if not self._daemon:
adjust_engine_inhabitant_count(-1) self._system.adjust_engine_inhabitant_count(-1)
def _pop_outbound(self, handle, clear_from_source_facet): def _pop_outbound(self, handle, clear_from_source_facet):
e = self.outbound.pop(handle) e = self.outbound.pop(handle)
@ -214,7 +238,7 @@ class Facet:
await coro_fn(self) await coro_fn(self)
finally: finally:
Turn.external(self, cancel_linked_task) Turn.external(self, cancel_linked_task)
task = find_loop(loop).create_task(guarded_task()) task = self.actor._system.loop.create_task(guarded_task())
self.linked_tasks.append(task) self.linked_tasks.append(task)
def _terminate(self, orderly): def _terminate(self, orderly):
@ -262,25 +286,9 @@ class ActiveFacet:
self.turn._facet = self.outer_facet self.turn._facet = self.outer_facet
self.outer_facet = None self.outer_facet = None
async def ensure_awaitable(value):
if inspect.isawaitable(value):
return await value
else:
return value
def find_loop(loop = None): def find_loop(loop = None):
return asyncio.get_running_loop() if loop is None else loop return asyncio.get_running_loop() if loop is None else loop
def queue_task(thunk, loop = None):
async def task():
await ensure_awaitable(thunk())
return find_loop(loop).create_task(task())
def queue_task_threadsafe(thunk, loop = None):
async def task():
await ensure_awaitable(thunk())
return asyncio.run_coroutine_threadsafe(task(), find_loop(loop))
class Turn: class Turn:
@staticproperty @staticproperty
def active(): def active():
@ -312,10 +320,11 @@ class Turn:
@classmethod @classmethod
def external(cls, facet, action, loop = None): def external(cls, facet, action, loop = None):
return queue_task_threadsafe(lambda: cls.run(facet, action), loop) return facet.actor._system.queue_task_threadsafe(lambda: cls.run(facet, action))
def __init__(self, facet): def __init__(self, facet):
self._facet = facet self._facet = facet
self._system = facet.actor._system
self.queues = {} self.queues = {}
@property @property
@ -361,10 +370,11 @@ class Turn:
for handle in initial_handles: for handle in initial_handles:
new_outbound[handle] = \ new_outbound[handle] = \
self._facet.actor._pop_outbound(handle, clear_from_source_facet=True) self._facet.actor._pop_outbound(handle, clear_from_source_facet=True)
queue_task(lambda: Actor(boot_proc, self._system.queue_task(lambda: Actor(boot_proc,
name = name, system = self._system,
initial_assertions = new_outbound, name = name,
daemon = daemon)) initial_assertions = new_outbound,
daemon = daemon))
self._enqueue(self._facet, action) self._enqueue(self._facet, action)
def stop_actor(self): def stop_actor(self):
@ -481,7 +491,7 @@ class Turn:
action() action()
turn._facet = saved_facet turn._facet = saved_facet
return lambda: Turn.run(actor.root, deliver_q) return lambda: Turn.run(actor.root, deliver_q)
queue_task(make_deliver_q(actor, q)) self._system.queue_task(make_deliver_q(actor, q))
self.queues = {} self.queues = {}
def stop_if_inert_after(action): def stop_if_inert_after(action):
@ -552,7 +562,7 @@ def __boot_inert():
_inert_facet = Turn.active._facet _inert_facet = Turn.active._facet
_inert_ref = Turn.active.ref(_inert_entity) _inert_ref = Turn.active.ref(_inert_entity)
async def __run_inert(): async def __run_inert():
Actor(__boot_inert, name = '_inert_actor') Actor(__boot_inert, system = System(), name = '_inert_actor')
def __setup_inert(): def __setup_inert():
def setup_main(): def setup_main():
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()

View File

@ -1,7 +1,6 @@
import sys import sys
import asyncio import asyncio
import websockets import websockets
import logging
from preserves import Embedded, stringify from preserves import Embedded, stringify
from preserves.fold import map_embeddeds from preserves.fold import map_embeddeds
@ -83,7 +82,7 @@ class TunnelRelay:
self.publish_oid = publish_oid self.publish_oid = publish_oid
self._reset() self._reset()
self.facet.linked_task( self.facet.linked_task(
lambda facet: self._reconnecting_main(asyncio.get_running_loop(), lambda facet: self._reconnecting_main(facet.actor._system,
on_connected = on_connected, on_connected = on_connected,
on_disconnected = on_disconnected)) on_disconnected = on_disconnected))
@ -187,6 +186,7 @@ class TunnelRelay:
def _handle_event(self, v): def _handle_event(self, v):
packet = protocol.Packet.decode(v) packet = protocol.Packet.decode(v)
# self.facet.log.info('IN: %r', packet)
variant = packet.VARIANT.name variant = packet.VARIANT.name
if variant == 'Turn': self._handle_turn_events(packet.value.value) if variant == 'Turn': self._handle_turn_events(packet.value.value)
elif variant == 'Error': self._on_error(packet.value.message, packet.value.detail) elif variant == 'Error': self._on_error(packet.value.message, packet.value.detail)
@ -244,8 +244,9 @@ class TunnelRelay:
def flush_pending(): def flush_pending():
packet = protocol.Packet.Turn(protocol.Turn(self.pending_turn)) packet = protocol.Packet.Turn(protocol.Turn(self.pending_turn))
self.pending_turn = [] self.pending_turn = []
# self.facet.log.info('OUT: %r', packet)
self._send_bytes(encode(packet)) self._send_bytes(encode(packet))
actor.queue_task(lambda: turn.run(self.facet, flush_pending)) self.facet.actor._system.queue_task(lambda: turn.run(self.facet, flush_pending))
self.pending_turn.append(protocol.TurnEvent(protocol.Oid(remote_oid), turn_event)) self.pending_turn.append(protocol.TurnEvent(protocol.Oid(remote_oid), turn_event))
def _send_bytes(self, bs): def _send_bytes(self, bs):
@ -254,10 +255,10 @@ class TunnelRelay:
def _disconnect(self): def _disconnect(self):
raise Exception('subclassresponsibility') raise Exception('subclassresponsibility')
async def _reconnecting_main(self, loop, on_connected=None, on_disconnected=None): async def _reconnecting_main(self, system, on_connected=None, on_disconnected=None):
should_run = True should_run = True
while should_run and self.facet.alive: while should_run and self.facet.alive:
did_connect = await self.main(loop, on_connected=(on_connected or _default_on_connected)) did_connect = await self.main(system, on_connected=(on_connected or _default_on_connected))
should_run = await (on_disconnected or _default_on_disconnected)(self, did_connect) should_run = await (on_disconnected or _default_on_disconnected)(self, did_connect)
@staticmethod @staticmethod
@ -362,17 +363,17 @@ class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol):
pass pass
self.stop_signal.get_loop().call_soon_threadsafe(set_stop_signal) self.stop_signal.get_loop().call_soon_threadsafe(set_stop_signal)
async def _create_connection(self, loop): async def _create_connection(self, system):
raise Exception('subclassresponsibility') raise Exception('subclassresponsibility')
async def main(self, loop, on_connected=None): async def main(self, system, on_connected=None):
if self.transport is not None: if self.transport is not None:
raise Exception('Cannot run connection twice!') raise Exception('Cannot run connection twice!')
self.decoder = Decoder(decode_embedded = sturdy.WireRef.decode) self.decoder = Decoder(decode_embedded = sturdy.WireRef.decode)
self.stop_signal = loop.create_future() self.stop_signal = system.loop.create_future()
try: try:
_transport, _protocol = await self._create_connection(loop) _transport, _protocol = await self._create_connection(system)
except OSError as e: except OSError as e:
log.error('%s: Could not connect to server: %s' % (self.__class__.__qualname__, e)) log.error('%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
return False return False
@ -389,44 +390,44 @@ class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol):
@transport.address(transportAddress.Tcp) @transport.address(transportAddress.Tcp)
class TcpTunnelRelay(_StreamTunnelRelay): class TcpTunnelRelay(_StreamTunnelRelay):
async def _create_connection(self, loop): async def _create_connection(self, system):
return await loop.create_connection(lambda: self, self.address.host, self.address.port) return await system.loop.create_connection(lambda: self, self.address.host, self.address.port)
@transport.address(transportAddress.Unix) @transport.address(transportAddress.Unix)
class UnixSocketTunnelRelay(_StreamTunnelRelay): class UnixSocketTunnelRelay(_StreamTunnelRelay):
async def _create_connection(self, loop): async def _create_connection(self, system):
return await loop.create_unix_connection(lambda: self, self.address.path) return await system.loop.create_unix_connection(lambda: self, self.address.path)
@transport.address(transportAddress.WebSocket) @transport.address(transportAddress.WebSocket)
class WebsocketTunnelRelay(TunnelRelay): class WebsocketTunnelRelay(TunnelRelay):
def __init__(self, address, **kwargs): def __init__(self, address, **kwargs):
super().__init__(address, **kwargs) super().__init__(address, **kwargs)
self.loop = None self.system = None
self.ws = None self.ws = None
def _send_bytes(self, bs): def _send_bytes(self, bs):
if self.loop: if self.system:
def _do_send(): def _do_send():
if self.ws: if self.ws:
self.loop.create_task(self.ws.send(bs)) self.system.queue_task(lambda: self.ws.send(bs))
self.loop.call_soon_threadsafe(_do_send) self.system.loop.call_soon_threadsafe(_do_send)
def _disconnect(self): def _disconnect(self):
if self.loop: if self.system:
def _do_disconnect(): def _do_disconnect():
if self.ws: if self.ws:
self.loop.create_task(self.ws.close()) self.system.queue_task(lambda: self.ws.close())
self.loop.call_soon_threadsafe(_do_disconnect) self.system.loop.call_soon_threadsafe(_do_disconnect)
def __connection_error(self, e): def __connection_error(self, e):
self.facet.log.error('Could not connect to server: %s' % (e,)) self.facet.log.error('Could not connect to server: %s' % (e,))
return False return False
async def main(self, loop, on_connected=None): async def main(self, system, on_connected=None):
if self.ws is not None: if self.ws is not None:
raise Exception('Cannot run connection twice!') raise Exception('Cannot run connection twice!')
self.loop = loop self.system = system
try: try:
self.ws = await websockets.connect(self.address.url) self.ws = await websockets.connect(self.address.url)
@ -448,7 +449,7 @@ class WebsocketTunnelRelay(TunnelRelay):
if self.ws: if self.ws:
await self.ws.close() await self.ws.close()
self.loop = None self.system = None
self.ws = None self.ws = None
return True return True
@ -460,8 +461,8 @@ class PipeTunnelRelay(_StreamTunnelRelay):
self.output_fileobj = output_fileobj self.output_fileobj = output_fileobj
self.reader = asyncio.StreamReader() self.reader = asyncio.StreamReader()
async def _create_connection(self, loop): async def _create_connection(self, system):
return await loop.connect_read_pipe(lambda: self, self.input_fileobj) return await system.loop.connect_read_pipe(lambda: self, self.input_fileobj)
def _send_bytes(self, bs): def _send_bytes(self, bs):
self.output_fileobj.buffer.write(bs) self.output_fileobj.buffer.write(bs)