From 2b06fc2758e92e3bea000676500e4d827ecf9d41 Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Wed, 18 Aug 2021 22:59:04 -0400 Subject: [PATCH] Relay and chat demo; not quite there yet --- chat.py | 112 +++++------ syndicate/__init__.py | 2 + syndicate/actor.py | 92 ++++++--- syndicate/during.py | 9 +- syndicate/patterns.py | 32 ++++ syndicate/relay.py | 415 +++++++++++++++++++++++++++++++++++++++++ syndicate/schema.py | 13 +- syndicate/transport.py | 20 ++ 8 files changed, 607 insertions(+), 88 deletions(-) create mode 100644 syndicate/patterns.py create mode 100644 syndicate/relay.py create mode 100644 syndicate/transport.py diff --git a/chat.py b/chat.py index 648c35f..64f30f8 100644 --- a/chat.py +++ b/chat.py @@ -2,69 +2,75 @@ import sys import asyncio import random import threading -import syndicate.mini.core as S +import syndicate +from syndicate import patterns as P, actor +from syndicate.schema import simpleChatProtocol, gatekeeper, sturdy, dataspace +from syndicate.during import During -Present = S.Record.makeConstructor('Present', 'who') -Says = S.Record.makeConstructor('Says', 'who what') +Present = simpleChatProtocol.Present +Says = simpleChatProtocol.Says -if len(sys.argv) == 1: - conn_url = 'ws://localhost:8000/#chat' -elif len(sys.argv) == 2: - conn_url = sys.argv[1] -else: - sys.stderr.write( - 'Usage: chat.py [ tcp://HOST[:PORT]#SCOPE | ws://HOST[:PORT]#SCOPE | unix:PATH#SCOPE ]\n') - sys.exit(1) -conn = S.Connection.from_url(conn_url) +conn_str = '' +cap_str = '' +cap = sturdy.SturdyRef.decode(syndicate.parse(cap_str)) + +# sys.stderr.write( +# 'Usage: chat.py [ | | ]\n') +# sys.exit(1) + +me = 'user_' + str(random.randint(10, 1000)) _print = print def print(*items): _print(*items) sys.stdout.flush() -## Courtesy of http://listofrandomnames.com/ :-) -names = ['Daria', 'Kendra', 'Danny', 'Rufus', 'Diana', 'Arnetta', 'Dominick', 'Melonie', 'Regan', - 'Glenda', 'Janet', 'Luci', 'Ronnie', 'Vita', 'Amie', 'Stefani', 'Catherine', 'Grady', - 'Terrance', 'Rey', 'Fay', 'Shantae', 'Carlota', 'Judi', 'Crissy', 'Tasha', 'Jordan', - 'Rolande', 'Buster', 'Diamond', 'Dallas', 'Lissa', 'Yang', 'Charlena', 'Brooke', 'Haydee', - 'Griselda', 'Kasie', 'Clara', 'Claudie', 'Darell', 'Emery', 'Barbera', 'Chong', 'Karin', - 'Veronica', 'Karly', 'Shaunda', 'Nigel', 'Cleo'] +def on_presence(turn, who): + print('%s joined' % (who,)) + return lambda turn: print('%s left' % (who,)) -me = random.choice(names) + '_' + str(random.randint(10, 1000)) +def main_facet(turn, root_facet, ds): + print('main_facet', ds) + f = turn._facet + turn.publish(ds, Present(me)) + turn.publish(ds, dataspace.Observe(P.rec('Present', P.CAPTURE), + During(turn, on_add = on_presence).ref)) + turn.publish(ds, dataspace.Observe(P.rec('Says', P.CAPTURE, P.CAPTURE), During( + turn, + on_msg = lambda turn, who, what: print('%s says %r' % (who, what))).ref)) -with conn.turn() as t: - with conn.actor().react(t) as facet: - facet.add(Present(me)) - facet.add(S.Observe(Present(S.CAPTURE)), - on_add=lambda t, who: print(who, 'joined'), - on_del=lambda t, who: print(who, 'left')) - facet.add(S.Observe(Says(S.CAPTURE, S.CAPTURE)), - on_msg=lambda t, who, what: print(who, 'said', repr(what))) + loop = asyncio.get_running_loop() + def accept_input(): + while True: + line = sys.stdin.readline() + if not line: + actor.Turn.external(loop, f, lambda turn: turn.stop(root_facet)) + break + actor.Turn.external(loop, f, lambda turn: turn.send(ds, Says(me, line.strip()))) + threading.Thread(target=accept_input, daemon=True).start() -async def on_connected(): - print('-'*50, 'Connected') -async def on_disconnected(did_connect): - if did_connect: - print('-'*50, 'Disconnected') - else: - await asyncio.sleep(2) - return bool(conn) +def main(turn): + root_facet = turn._facet + gk_receiver = During(turn, on_add = lambda turn, gk: turn.publish( + gk.embeddedValue, gatekeeper.Resolve(cap, ds_receiver))).ref + ds_receiver = During(turn, on_add = lambda turn, ds: turn.facet( + lambda turn: main_facet(turn, root_facet, ds.embeddedValue))).ref -def accept_input(): - global conn - while True: - line = sys.stdin.readline() - if not line: - conn.destroy() - conn = None - break - with conn.turn() as t: - t.send(Says(me, line.strip())) + disarm = turn.prevent_inert_check() + async def on_connected(tr): + disarm() + print('-'*50, 'Connected') + async def on_disconnected(tr, did_connect): + if did_connect: + print('-'*50, 'Disconnected') + else: + await asyncio.sleep(2) + return True -loop = asyncio.get_event_loop() -loop.set_debug(True) -threading.Thread(target=accept_input, daemon=True).start() -loop.run_until_complete(conn.reconnecting_main(loop, on_connected, on_disconnected)) -loop.stop() -loop.run_forever() -loop.close() + conn = syndicate.relay.TunnelRelay.from_str(turn, + conn_str, + gatekeeper_peer = gk_receiver, + on_connected = on_connected, + on_disconnected = on_disconnected) + +actor.start_actor_system(main, name = 'chat', debug = False) diff --git a/syndicate/__init__.py b/syndicate/__init__.py index 9afe956..a548f3c 100644 --- a/syndicate/__init__.py +++ b/syndicate/__init__.py @@ -2,3 +2,5 @@ __path__ = __import__('pkgutil').extend_path(__path__, __name__) # This is 'import *' in order to effectively re-export preserves as part of this module's API. from preserves import * + +from . import relay diff --git a/syndicate/actor.py b/syndicate/actor.py index a64bdcb..4cc89a3 100644 --- a/syndicate/actor.py +++ b/syndicate/actor.py @@ -4,6 +4,8 @@ import logging import sys import traceback +from preserves import Embedded, preserve + from .idgen import IdGenerator log = logging.getLogger(__name__) @@ -12,11 +14,17 @@ _next_actor_number = IdGenerator() _next_handle = IdGenerator() _next_facet_id = IdGenerator() -def start_actor_system(boot_proc): +def start_actor_system(boot_proc, debug = False, name = None, configure_logging = True): + if configure_logging: + logging.basicConfig(level = logging.DEBUG if debug else logging.INFO) loop = asyncio.get_event_loop() - loop.set_debug(True) - queue_task(lambda: Actor(boot_proc), loop = loop) + if debug: + loop.set_debug(True) + queue_task(lambda: Actor(boot_proc, name = name), loop = loop) loop.run_forever() + while asyncio.all_tasks(loop): + loop.stop() + loop.run_forever() loop.close() def adjust_engine_inhabitant_count(delta): @@ -69,13 +77,14 @@ class Actor: def terminate(self, turn, exit_reason): if self.exit_reason is not None: return + self.log.debug('Terminating %r with exit_reason %r', self, exit_reason) self.exit_reason = exit_reason if exit_reason != True: self.log.error('crashed: %s' % (exit_reason,)) for h in self.exit_hooks: h(turn) def finish_termination(): - Turn.run(self, + Turn.run(self.root, lambda turn: self.root._terminate(turn, exit_reason == True), zombie_turn = True) if not self._daemon: @@ -130,6 +139,7 @@ class Facet: def _terminate(self, turn, orderly): if not self.alive: return + self.log.debug('%s terminating %r', 'orderly' if orderly else 'disorderly', self) self.alive = False parent = self.parent @@ -144,6 +154,7 @@ class Facet: h(turn) for e in self.outbound.values(): turn._retract(e) + self.outbound.clear() if orderly: if parent: @@ -161,12 +172,12 @@ class ActiveFacet: self.inner_facet = facet def __enter__(self): - self.outer_facet = self.turn.facet - self.turn.facet = self.inner_facet + self.outer_facet = self.turn._facet + self.turn._facet = self.inner_facet return None def __exit__(self, t, v, tb): - self.turn.facet = self.outer_facet + self.turn._facet = self.outer_facet self.outer_facet = None async def ensure_awaitable(value): @@ -180,6 +191,11 @@ def queue_task(thunk, loop = asyncio): await ensure_awaitable(thunk()) return loop.create_task(task()) +def queue_task_threadsafe(thunk, loop): + async def task(): + await ensure_awaitable(thunk()) + return asyncio.run_coroutine_threadsafe(task(), loop) + class Turn: @classmethod def run(cls, facet, action, zombie_turn = False): @@ -191,31 +207,38 @@ class Turn: action(turn) except: ei = sys.exc_info() - self.log.error('%s', ''.join(traceback.format_exception(*ei))) + facet.log.error('%s', ''.join(traceback.format_exception(*ei))) Turn.run(facet.actor.root, lambda turn: facet.actor.terminate(turn, ei[1])) else: turn._deliver() + @classmethod + def external(cls, loop, facet, action): + return queue_task_threadsafe(lambda: cls.run(facet, action), loop) + def __init__(self, facet): - self.facet = facet + self._facet = facet self.queues = {} @property def log(self): - return self.facet.actor.log + return self._facet.log def ref(self, entity): - return Ref(self.facet, entity) + return Ref(self._facet, entity) def facet(self, boot_proc): - new_facet = Facet(self.facet.actor, self.facet) + new_facet = Facet(self._facet.actor, self._facet) with ActiveFacet(self, new_facet): stop_if_inert_after(boot_proc)(self) return new_facet + def prevent_inert_check(self): + return self._facet.prevent_inert_check() + def stop(self, facet = None, continuation = None): if facet is None: - facet = self.facet + facet = self._facet def action(turn): facet._terminate(turn, True) if continuation is not None: @@ -227,19 +250,18 @@ class Turn: new_outbound = {} if initial_assertions is not None: for handle in initial_assertions: - new_outbound[handle] = self.facet.outbound[handle] - del self.facet.outbound[handle] + new_outbound[handle] = self._facet.outbound.pop(handle) queue_task(lambda: Actor(boot_proc, name = name, initial_assertions = new_outbound, daemon = daemon)) - self._enqueue(self.facet, action) + self._enqueue(self._facet, action) def stop_actor(self): - self._enqueue(self.facet.actor.root, lambda turn: self.facet.actor.terminate(turn, True)) + self._enqueue(self._facet.actor.root, lambda turn: self._facet.actor.terminate(turn, True)) def crash(self, exn): - self._enqueue(self.facet.actor.root, lambda turn: self.facet.actor.terminate(turn, exn)) + self._enqueue(self._facet.actor.root, lambda turn: self._facet.actor.terminate(turn, exn)) def publish(self, ref, assertion): handle = next(_next_handle) @@ -248,16 +270,18 @@ class Turn: def _publish(self, ref, assertion, handle): # TODO: attenuation + assertion = preserve(assertion) e = OutboundAssertion(handle, ref) - self.facet.outbound[handle] = e + self._facet.outbound[handle] = e def action(turn): e.established = True + self.log.debug('%r <-- publish %r handle %r', ref, assertion, handle) ref.entity.on_publish(turn, assertion, handle) self._enqueue(ref.facet, action) def retract(self, handle): if handle is not None: - e = self.facet.outbound.get(handle, None) + e = self._facet.outbound.pop(handle, None) if e is not None: self._retract(e) @@ -267,10 +291,11 @@ class Turn: return new_handle def _retract(self, e): - del self.facet.outbound[e.handle] + # Assumes e has already been removed from self._facet.outbound def action(turn): if e.established: e.established = False + self.log.debug('%r <-- retract handle %r', e.ref, e.handle) e.ref.entity.on_retract(turn, e.handle) self._enqueue(e.ref.facet, action) @@ -281,11 +306,17 @@ class Turn: self._sync(ref, self.ref(SyncContinuation())) def _sync(self, ref, peer): - self._enqueue(ref.facet, lambda turn: ref.entity.on_sync(turn, peer)) + peer = preserve(peer) + def action(turn): + self.log.debug('%r <-- sync peer %r', ref, peer) + ref.entity.on_sync(turn, peer) + self._enqueue(ref.facet, action) def send(self, ref, message): # TODO: attenuation + message = preserve(message) def action(turn): + self.log.debug('%r <-- message %r', ref, message) ref.entity.on_message(turn, message) self._enqueue(ref.facet, action) @@ -309,10 +340,10 @@ def stop_if_inert_after(action): def wrapped_action(turn): action(turn) def check_action(turn): - if (turn.facet.parent is not None and not turn.facet.parent.alive) \ - or turn.facet.isinert(): + if (turn._facet.parent is not None and not turn._facet.parent.alive) \ + or turn._facet.isinert(): turn.stop() - turn._enqueue(turn.facet, check_action) + turn._enqueue(turn._facet, check_action) return wrapped_action class Ref: @@ -343,4 +374,15 @@ class Entity: def on_sync(self, turn, peer): turn.send(peer, True) +_inert_actor = None +_inert_facet = None +_inert_ref = None _inert_entity = Entity() +def __boot_inert(turn): + global _inert_actor, _inert_facet, _inert_ref + _inert_actor = turn._facet.actor + _inert_facet = turn._facet + _inert_ref = turn.ref(_inert_entity) +async def __run_inert(): + Actor(__boot_inert, name = '_inert_actor') +asyncio.get_event_loop().run_until_complete(__run_inert()) diff --git a/syndicate/during.py b/syndicate/during.py index 215f880..c404580 100644 --- a/syndicate/during.py +++ b/syndicate/during.py @@ -27,12 +27,13 @@ class During(actor.Entity): def on_publish(self, turn, v, handle): retract_handler = self._on_add(turn, *self._wrap(v)) if retract_handler is not None: - self.retract_handlers[handle] = retract_handler + if isinstance(retract_handler, actor.Facet): + self.retract_handlers[handle] = lambda turn: turn.stop(retract_handler) + else: + self.retract_handlers[handle] = retract_handler def on_retract(self, turn, handle): - if handle in self.retract_handlers: - self.retract_handlers[handle](turn) - del self.retract_handlers[handle] + self.retract_handlers.pop(handle, lambda turn: ())(turn) def on_message(self, turn, v): self._on_msg(turn, *self._wrap(v)) diff --git a/syndicate/patterns.py b/syndicate/patterns.py new file mode 100644 index 0000000..743c644 --- /dev/null +++ b/syndicate/patterns.py @@ -0,0 +1,32 @@ +from .schema import dataspacePatterns as P +from . import Symbol + +_dict = dict ## we're about to shadow the builtin + +_ = P.Pattern.DDiscard(P.DDiscard()) + +def bind(p): + return P.Pattern.DBind(P.DBind(p)) + +CAPTURE = bind(_) + +def lit(v): + return P.Pattern.DLit(P.DLit(v)) + +def rec(labelstr, *members): + return _rec(Symbol(labelstr), *members) + +def _rec(label, *members): + return P.Pattern.DCompound(P.DCompound.rec( + P.CRec(label, len(members)), + _dict(enumerate(members)))) + +def arr(*members): + return P.Pattern.DCompound(P.DCompound.arr( + P.CArr(len(members)), + _dict(enumerate(members)))) + +def dict(*kvs): + return P.Pattern.DCompound(P.DCompound.dict( + P.CDict(), + _dict(kvs))) diff --git a/syndicate/relay.py b/syndicate/relay.py new file mode 100644 index 0000000..b7f60f4 --- /dev/null +++ b/syndicate/relay.py @@ -0,0 +1,415 @@ +import asyncio +import websockets +import logging + +from preserves import Embedded, stringify +from preserves.fold import map_embeddeds + +from . import actor, encode, transport, Decoder +from .actor import _inert_ref, Turn, adjust_engine_inhabitant_count +from .idgen import IdGenerator +from .schema import externalProtocol as protocol, sturdy, transportAddress + +class InboundAssertion: + def __init__(self, remote_handle, local_handle, wire_symbols): + self.remote_handle = remote_handle + self.local_handle = local_handle + self.wire_symbols = wire_symbols + +_next_local_oid = IdGenerator() + +class WireSymbol: + def __init__(self, oid, ref): + self.oid = oid + self.ref = ref + self.count = 0 + + def __repr__(self): + return '' % (self.oid, self.count, self.ref) + +class Membrane: + def __init__(self): + self.oid_map = {} + self.ref_map = {} + + def _get(self, map, key, is_transient, ws_maker): + ws = map.get(key, None) + if ws is None: + ws = ws_maker() + self.oid_map[ws.oid] = ws + self.ref_map[ws.ref] = ws + if not is_transient: + ws.count = ws.count + 1 + return ws + + def get_ref(self, local_ref, is_transient, ws_maker): + return self._get(self.ref_map, local_ref, is_transient, ws_maker) + + def get_oid(self, remote_oid, ws_maker): + return self._get(self.oid_map, remote_oid, False, ws_maker) + + def drop(self, ws): + ws.count = ws.count - 1 + if ws.count == 0: + del self.oid_map[ws.oid] + del self.ref_map[ws.ref] + +# There are other kinds of relay. This one has exactly two participants connected to each other. +class TunnelRelay: + def __init__(self, + turn, + address, + gatekeeper_peer = None, + gatekeeper_oid = 0, + on_connected = None, + on_disconnected = None, + ): + self.ref = turn.ref(self) + self.facet = turn._facet + self.facet.on_stop(self._shutdown) + self.address = address + self.gatekeeper_peer = gatekeeper_peer + self.gatekeeper_oid = gatekeeper_oid + self._reset() + actor.queue_task(lambda: self._reconnecting_main(asyncio.get_running_loop(), + on_connected = on_connected, + on_disconnected = on_disconnected)) + + def _reset(self): + self.inbound_assertions = {} # map remote handle to InboundAssertion + self.outbound_assertions = {} # map local handle to wire_symbols + self.exported_references = Membrane() + self.imported_references = Membrane() + self.pending_turn = [] + self._connected = False + self.gatekeeper_handle = None + + @property + def connected(self): + return self._connected + + def _shutdown(self, turn): + self._disconnect() + + def deregister(self, handle): + for ws in self.outbound_assertions.pop(handle, ()): + self.exported_references.drop(ws) + + def _lookup(self, local_oid): + ws = self.exported_references.oid_map.get(local_oid, None) + return _inert_ref if ws is None else ws.ref + + def register(self, assertion, maybe_handle): + exported = [] + rewritten = map_embeddeds( + lambda r: Embedded(self.rewrite_ref_out(r, maybe_handle is None, exported)), + assertion) + if maybe_handle is not None: + self.outbound_assertions[maybe_handle] = exported + return rewritten + + def rewrite_ref_out(self, r, is_transient, exported): + if isinstance(r.entity, RelayEntity) and r.entity.relay == self: + # TODO attenuation + return sturdy.WireRef.yours(sturdy.Oid(r.entity.oid), ()) + else: + ws = self.exported_references.get_ref( + r, is_transient, lambda: WireSymbol(next(_next_local_oid), r)) + exported.append(ws) + return sturdy.WireRef.mine(sturdy.Oid(ws.oid)) + + def rewrite_in(self, turn, assertion): + imported = [] + rewritten = map_embeddeds( + lambda wire_ref: Embedded(self.rewrite_ref_in(turn, wire_ref, imported)), + assertion) + return (rewritten, imported) + + def rewrite_ref_in(self, turn, wire_ref, imported): + if wire_ref.VARIANT.name == 'mine': + oid = wire_ref.oid.value + ws = self.imported_references.get_oid( + oid, lambda: WireSymbol(oid, turn.ref(RelayEntity(self, oid)))) + imported.append(ws) + return ws.ref + else: + oid = wire_ref.oid.value + local_ref = self._lookup(oid) + attenuation = wire_ref.attenuation + if len(attenuation) > 0: + raise NotImplementedError('Non-empty attenuations not yet implemented') # TODO + return local_ref + + def _on_disconnected(self): + self._connected = False + def retract_inbound(turn): + for ia in self.inbound_assertions.values(): + turn.retract(ia.local_handle) + if self.gatekeeper_handle is not None: + turn.retract(self.gatekeeper_handle) + self._reset() + Turn.run(self.facet, retract_inbound) + self._disconnect() + + def _on_connected(self): + self._connected = True + if self.gatekeeper_peer is not None: + def connected_action(turn): + gk = self.rewrite_ref_in(turn, + sturdy.WireRef.mine(sturdy.Oid(self.gatekeeper_oid)), + []) + self.gatekeeper_handle = turn.publish(self.gatekeeper_peer, Embedded(gk)) + Turn.run(self.facet, connected_action) + + def _on_event(self, v): + Turn.run(self.facet, lambda turn: self._handle_event(turn, v)) + + def _handle_event(self, turn, v): + packet = protocol.Packet.decode(v) + variant = packet.VARIANT.name + if variant == 'Turn': self._handle_turn_events(turn, packet.value.value) + elif variant == 'Error': self._on_error(turn, packet.value.message, packet.value.detail) + + def _on_error(self, turn, message, detail): + self.facet.log.error('Error from server: %r (detail: %r)', message, detail) + self._disconnect() + + def _handle_turn_events(self, turn, events): + for e in events: + ref = self._lookup(e.oid.value) + event = e.event + variant = event.VARIANT.name + if variant == 'Assert': + self._handle_publish(turn, ref, event.value.assertion.value, event.value.handle.value) + elif variant == 'Retract': + self._handle_retract(turn, ref, event.value.handle.value) + elif variant == 'Message': + self._handle_message(turn, ref, event.value.body.value) + elif variant == 'Sync': + self._handle_sync(turn, ref, event.value.peer) + + def _handle_publish(self, turn, ref, assertion, remote_handle): + (assertion, imported) = self.rewrite_in(turn, assertion) + self.inbound_assertions[remote_handle] = \ + InboundAssertion(remote_handle, turn.publish(ref, assertion), imported) + + def _handle_retract(self, turn, ref, remote_handle): + ia = self.inbound_assertions.pop(remote_handle, None) + if ia is None: + raise ValueError('Peer retracted invalid handle %s' % (remote_handle,)) + for ws in ia.wire_symbols: + self.imported_references.drop(ws) + turn.retract(ia.local_handle) + + def _handle_message(self, turn, ref, message): + (message, imported) = self.rewrite_in(turn, message) + if len(imported) > 0: + raise ValueError('Cannot receive transient reference') + turn.send(ref, message) + + def _handle_sync(self, turn, ref, wire_peer): + imported = [] + peer = self.rewrite_ref_in(turn, wire_peer, imported) + def done(turn): + turn.send(peer, True) + for ws in imported: + self.imported_references.drop(ws) + turn.sync(ref, done) + + def _send(self, remote_oid, turn_event): + if len(self.pending_turn) == 0: + def flush_pending(turn): + packet = protocol.Packet.Turn(protocol.Turn(self.pending_turn)) + self.pending_turn = [] + self._send_bytes(encode(packet)) + actor.queue_task(lambda: Turn.run(self.facet, flush_pending)) + self.pending_turn.append(protocol.TurnEvent(protocol.Oid(remote_oid), turn_event)) + + def _send_bytes(self, bs): + raise Exception('subclassresponsibility') + + def _disconnect(self): + raise Exception('subclassresponsibility') + + async def _reconnecting_main(self, loop, on_connected=None, on_disconnected=None): + adjust_engine_inhabitant_count(1) + should_run = True + while should_run and self.facet.alive: + did_connect = await self.main(loop, on_connected=(on_connected or _default_on_connected)) + should_run = await (on_disconnected or _default_on_disconnected)(self, did_connect) + adjust_engine_inhabitant_count(-1) + + @staticmethod + def from_str(turn, s, **kwargs): + return transport.connection_from_str(turn, s, **kwargs) + +class RelayEntity(actor.Entity): + def __init__(self, relay, oid): + self.relay = relay + self.oid = oid + + def __repr__(self): + return '' % (stringify(self.relay.address), self.oid) + + def _send(self, e): + self.relay._send(self.oid, e) + + def on_publish(self, turn, assertion, handle): + self._send(protocol.Event.Assert(protocol.Assert( + protocol.Assertion(self.relay.register(assertion, handle)), + protocol.Handle(handle)))) + + def on_retract(self, turn, handle): + self.relay.deregister(handle) + self._send(protocol.Event.Retract(protocol.Retract(protocol.Handle(handle)))) + + def on_message(self, turn, message): + self._send(protocol.Event.Message(protocol.Message( + protocol.Assertion(self.relay.register(message, None))))) + + def on_sync(self, turn, peer): + exported = [] + entity = SyncPeerEntity(self.relay, peer, exported) + rewritten = Embedded(self.relay.rewrite_ref_out(turn.ref(entity), False, exported)) + self._send(protocol.Event.Sync(protocol.Sync(rewritten))) + +class SyncPeerEntity(actor.Entity): + def __init__(self, relay, peer, exported): + self.relay = relay + self.peer = peer + self.exported = exported + + def on_message(self, turn, body): + self.relay.exported_references.drop(self.exported[0]) + turn.send(self.peer, body) + +async def _default_on_connected(relay): + relay.facet.log.info('Connected') + +async def _default_on_disconnected(relay, did_connect): + if did_connect: + # Reconnect immediately + relay.facet.log.info('Disconnected') + else: + await asyncio.sleep(2) + return True + +class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol): + def __init__(self, turn, address, **kwargs): + super().__init__(turn, address, **kwargs) + self.decoder = None + self.stop_signal = None + self.transport = None + + def connection_lost(self, exc): + self._on_disconnected() + + def connection_made(self, transport): + self.transport = transport + self._on_connected() + + def data_received(self, chunk): + self.decoder.extend(chunk) + while True: + v = self.decoder.try_next() + if v is None: break + self._on_event(v) + + def _send_bytes(self, bs): + if self.transport: + self.transport.write(bs) + + def _disconnect(self): + if self.stop_signal: + self.stop_signal.get_loop().call_soon_threadsafe( + lambda: self.stop_signal.set_result(True)) + + async def _create_connection(self, loop): + raise Exception('subclassresponsibility') + + async def main(self, loop, on_connected=None): + if self.transport is not None: + raise Exception('Cannot run connection twice!') + + self.decoder = Decoder(decode_embedded = sturdy.WireRef.decode) + self.stop_signal = loop.create_future() + try: + _transport, _protocol = await self._create_connection(loop) + except OSError as e: + log.error('%s: Could not connect to server: %s' % (self.__class__.__qualname__, e)) + return False + + try: + if on_connected: await on_connected(self) + await self.stop_signal + return True + finally: + self.transport.close() + self.transport = None + self.stop_signal = None + self.decoder = None + +@transport.address(transportAddress.Tcp) +class TcpTunnelRelay(_StreamTunnelRelay): + async def _create_connection(self, loop): + return await loop.create_connection(lambda: self, self.address.host, self.address.port) + +@transport.address(transportAddress.Unix) +class UnixSocketTunnelRelay(_StreamTunnelRelay): + async def _create_connection(self, loop): + return await loop.create_unix_connection(lambda: self, self.address.path) + +@transport.address(transportAddress.WebSocket) +class WebsocketTunnelRelay(TunnelRelay): + def __init__(self, turn, address, **kwargs): + super().__init__(turn, address, **kwargs) + self.loop = None + self.ws = None + + def _send_bytes(self, bs): + if self.loop: + def _do_send(): + if self.ws: + self.loop.create_task(self.ws.send(bs)) + self.loop.call_soon_threadsafe(_do_send) + + def _disconnect(self): + if self.loop: + def _do_disconnect(): + if self.ws: + self.loop.create_task(self.ws.close()) + self.loop.call_soon_threadsafe(_do_disconnect) + + def __connection_error(self, e): + self.facet.log.error('Could not connect to server: %s' % (e,)) + return False + + async def main(self, loop, on_connected=None): + if self.ws is not None: + raise Exception('Cannot run connection twice!') + + self.loop = loop + + try: + self.ws = await websockets.connect(self.address.url) + except OSError as e: + return self.__connection_error(e) + except websockets.exceptions.InvalidHandshake as e: + return self.__connection_error(e) + + try: + if on_connected: await on_connected(self) + self._on_connected() + while True: + chunk = await self.ws.recv() + self._on_event(Decoder(chunk, decode_embedded = sturdy.WireRef.decode).next()) + except websockets.exceptions.WebSocketException: + pass + finally: + self._on_disconnected() + + if self.ws: + await self.ws.close() + self.loop = None + self.ws = None + return True diff --git a/syndicate/schema.py b/syndicate/schema.py index d636d34..463ebca 100644 --- a/syndicate/schema.py +++ b/syndicate/schema.py @@ -1,6 +1,7 @@ -from preserves.schema import load_schema_file -import pathlib - -for (n, ns) in load_schema_file(pathlib.Path(__file__).parent / - '../../syndicate-protocols/schema-bundle.bin')._items().items(): - globals()[n] = ns +def __load(): + from preserves.schema import load_schema_file + import pathlib + for (n, ns) in load_schema_file(pathlib.Path(__file__).parent / + '../../../syndicate-protocols/schema-bundle.bin')._items().items(): + globals()[n] = ns +__load() diff --git a/syndicate/transport.py b/syndicate/transport.py new file mode 100644 index 0000000..3373ee2 --- /dev/null +++ b/syndicate/transport.py @@ -0,0 +1,20 @@ +from preserves import parse + +constructors = {} + +class InvalidTransportAddress(ValueError): pass + +# decorator +def address(address_class): + def k(connection_factory_class): + constructors[address_class] = connection_factory_class + return connection_factory_class + return k + +def connection_from_str(turn, s, **kwargs): + address = parse(s) + for (address_class, factory_class) in constructors.items(): + decoded_address = address_class.try_decode(address) + if decoded_address is not None: + return factory_class(turn, decoded_address, **kwargs) + raise InvalidTransportAddress('Invalid transport address', address)