Relay and chat demo; not quite there yet

This commit is contained in:
Tony Garnock-Jones 2021-08-18 22:59:04 -04:00
parent 10f4bc9e34
commit 2b06fc2758
8 changed files with 607 additions and 88 deletions

112
chat.py
View File

@ -2,69 +2,75 @@ import sys
import asyncio
import random
import threading
import syndicate.mini.core as S
import syndicate
from syndicate import patterns as P, actor
from syndicate.schema import simpleChatProtocol, gatekeeper, sturdy, dataspace
from syndicate.during import During
Present = S.Record.makeConstructor('Present', 'who')
Says = S.Record.makeConstructor('Says', 'who what')
Present = simpleChatProtocol.Present
Says = simpleChatProtocol.Says
if len(sys.argv) == 1:
conn_url = 'ws://localhost:8000/#chat'
elif len(sys.argv) == 2:
conn_url = sys.argv[1]
else:
sys.stderr.write(
'Usage: chat.py [ tcp://HOST[:PORT]#SCOPE | ws://HOST[:PORT]#SCOPE | unix:PATH#SCOPE ]\n')
sys.exit(1)
conn = S.Connection.from_url(conn_url)
conn_str = '<ws "ws://localhost:8001/">'
cap_str = '<ref "syndicate" [] #[pkgN9TBmEd3Q04grVG4Zdw==]>'
cap = sturdy.SturdyRef.decode(syndicate.parse(cap_str))
# sys.stderr.write(
# 'Usage: chat.py [ <tcp "HOST" PORT> | <ws "ws://HOST[:PORT]/"> | <unix "PATH"> ]\n')
# sys.exit(1)
me = 'user_' + str(random.randint(10, 1000))
_print = print
def print(*items):
_print(*items)
sys.stdout.flush()
## Courtesy of http://listofrandomnames.com/ :-)
names = ['Daria', 'Kendra', 'Danny', 'Rufus', 'Diana', 'Arnetta', 'Dominick', 'Melonie', 'Regan',
'Glenda', 'Janet', 'Luci', 'Ronnie', 'Vita', 'Amie', 'Stefani', 'Catherine', 'Grady',
'Terrance', 'Rey', 'Fay', 'Shantae', 'Carlota', 'Judi', 'Crissy', 'Tasha', 'Jordan',
'Rolande', 'Buster', 'Diamond', 'Dallas', 'Lissa', 'Yang', 'Charlena', 'Brooke', 'Haydee',
'Griselda', 'Kasie', 'Clara', 'Claudie', 'Darell', 'Emery', 'Barbera', 'Chong', 'Karin',
'Veronica', 'Karly', 'Shaunda', 'Nigel', 'Cleo']
def on_presence(turn, who):
print('%s joined' % (who,))
return lambda turn: print('%s left' % (who,))
me = random.choice(names) + '_' + str(random.randint(10, 1000))
def main_facet(turn, root_facet, ds):
print('main_facet', ds)
f = turn._facet
turn.publish(ds, Present(me))
turn.publish(ds, dataspace.Observe(P.rec('Present', P.CAPTURE),
During(turn, on_add = on_presence).ref))
turn.publish(ds, dataspace.Observe(P.rec('Says', P.CAPTURE, P.CAPTURE), During(
turn,
on_msg = lambda turn, who, what: print('%s says %r' % (who, what))).ref))
with conn.turn() as t:
with conn.actor().react(t) as facet:
facet.add(Present(me))
facet.add(S.Observe(Present(S.CAPTURE)),
on_add=lambda t, who: print(who, 'joined'),
on_del=lambda t, who: print(who, 'left'))
facet.add(S.Observe(Says(S.CAPTURE, S.CAPTURE)),
on_msg=lambda t, who, what: print(who, 'said', repr(what)))
loop = asyncio.get_running_loop()
def accept_input():
while True:
line = sys.stdin.readline()
if not line:
actor.Turn.external(loop, f, lambda turn: turn.stop(root_facet))
break
actor.Turn.external(loop, f, lambda turn: turn.send(ds, Says(me, line.strip())))
threading.Thread(target=accept_input, daemon=True).start()
async def on_connected():
print('-'*50, 'Connected')
async def on_disconnected(did_connect):
if did_connect:
print('-'*50, 'Disconnected')
else:
await asyncio.sleep(2)
return bool(conn)
def main(turn):
root_facet = turn._facet
gk_receiver = During(turn, on_add = lambda turn, gk: turn.publish(
gk.embeddedValue, gatekeeper.Resolve(cap, ds_receiver))).ref
ds_receiver = During(turn, on_add = lambda turn, ds: turn.facet(
lambda turn: main_facet(turn, root_facet, ds.embeddedValue))).ref
def accept_input():
global conn
while True:
line = sys.stdin.readline()
if not line:
conn.destroy()
conn = None
break
with conn.turn() as t:
t.send(Says(me, line.strip()))
disarm = turn.prevent_inert_check()
async def on_connected(tr):
disarm()
print('-'*50, 'Connected')
async def on_disconnected(tr, did_connect):
if did_connect:
print('-'*50, 'Disconnected')
else:
await asyncio.sleep(2)
return True
loop = asyncio.get_event_loop()
loop.set_debug(True)
threading.Thread(target=accept_input, daemon=True).start()
loop.run_until_complete(conn.reconnecting_main(loop, on_connected, on_disconnected))
loop.stop()
loop.run_forever()
loop.close()
conn = syndicate.relay.TunnelRelay.from_str(turn,
conn_str,
gatekeeper_peer = gk_receiver,
on_connected = on_connected,
on_disconnected = on_disconnected)
actor.start_actor_system(main, name = 'chat', debug = False)

View File

@ -2,3 +2,5 @@ __path__ = __import__('pkgutil').extend_path(__path__, __name__)
# This is 'import *' in order to effectively re-export preserves as part of this module's API.
from preserves import *
from . import relay

View File

@ -4,6 +4,8 @@ import logging
import sys
import traceback
from preserves import Embedded, preserve
from .idgen import IdGenerator
log = logging.getLogger(__name__)
@ -12,11 +14,17 @@ _next_actor_number = IdGenerator()
_next_handle = IdGenerator()
_next_facet_id = IdGenerator()
def start_actor_system(boot_proc):
def start_actor_system(boot_proc, debug = False, name = None, configure_logging = True):
if configure_logging:
logging.basicConfig(level = logging.DEBUG if debug else logging.INFO)
loop = asyncio.get_event_loop()
loop.set_debug(True)
queue_task(lambda: Actor(boot_proc), loop = loop)
if debug:
loop.set_debug(True)
queue_task(lambda: Actor(boot_proc, name = name), loop = loop)
loop.run_forever()
while asyncio.all_tasks(loop):
loop.stop()
loop.run_forever()
loop.close()
def adjust_engine_inhabitant_count(delta):
@ -69,13 +77,14 @@ class Actor:
def terminate(self, turn, exit_reason):
if self.exit_reason is not None: return
self.log.debug('Terminating %r with exit_reason %r', self, exit_reason)
self.exit_reason = exit_reason
if exit_reason != True:
self.log.error('crashed: %s' % (exit_reason,))
for h in self.exit_hooks:
h(turn)
def finish_termination():
Turn.run(self,
Turn.run(self.root,
lambda turn: self.root._terminate(turn, exit_reason == True),
zombie_turn = True)
if not self._daemon:
@ -130,6 +139,7 @@ class Facet:
def _terminate(self, turn, orderly):
if not self.alive: return
self.log.debug('%s terminating %r', 'orderly' if orderly else 'disorderly', self)
self.alive = False
parent = self.parent
@ -144,6 +154,7 @@ class Facet:
h(turn)
for e in self.outbound.values():
turn._retract(e)
self.outbound.clear()
if orderly:
if parent:
@ -161,12 +172,12 @@ class ActiveFacet:
self.inner_facet = facet
def __enter__(self):
self.outer_facet = self.turn.facet
self.turn.facet = self.inner_facet
self.outer_facet = self.turn._facet
self.turn._facet = self.inner_facet
return None
def __exit__(self, t, v, tb):
self.turn.facet = self.outer_facet
self.turn._facet = self.outer_facet
self.outer_facet = None
async def ensure_awaitable(value):
@ -180,6 +191,11 @@ def queue_task(thunk, loop = asyncio):
await ensure_awaitable(thunk())
return loop.create_task(task())
def queue_task_threadsafe(thunk, loop):
async def task():
await ensure_awaitable(thunk())
return asyncio.run_coroutine_threadsafe(task(), loop)
class Turn:
@classmethod
def run(cls, facet, action, zombie_turn = False):
@ -191,31 +207,38 @@ class Turn:
action(turn)
except:
ei = sys.exc_info()
self.log.error('%s', ''.join(traceback.format_exception(*ei)))
facet.log.error('%s', ''.join(traceback.format_exception(*ei)))
Turn.run(facet.actor.root, lambda turn: facet.actor.terminate(turn, ei[1]))
else:
turn._deliver()
@classmethod
def external(cls, loop, facet, action):
return queue_task_threadsafe(lambda: cls.run(facet, action), loop)
def __init__(self, facet):
self.facet = facet
self._facet = facet
self.queues = {}
@property
def log(self):
return self.facet.actor.log
return self._facet.log
def ref(self, entity):
return Ref(self.facet, entity)
return Ref(self._facet, entity)
def facet(self, boot_proc):
new_facet = Facet(self.facet.actor, self.facet)
new_facet = Facet(self._facet.actor, self._facet)
with ActiveFacet(self, new_facet):
stop_if_inert_after(boot_proc)(self)
return new_facet
def prevent_inert_check(self):
return self._facet.prevent_inert_check()
def stop(self, facet = None, continuation = None):
if facet is None:
facet = self.facet
facet = self._facet
def action(turn):
facet._terminate(turn, True)
if continuation is not None:
@ -227,19 +250,18 @@ class Turn:
new_outbound = {}
if initial_assertions is not None:
for handle in initial_assertions:
new_outbound[handle] = self.facet.outbound[handle]
del self.facet.outbound[handle]
new_outbound[handle] = self._facet.outbound.pop(handle)
queue_task(lambda: Actor(boot_proc,
name = name,
initial_assertions = new_outbound,
daemon = daemon))
self._enqueue(self.facet, action)
self._enqueue(self._facet, action)
def stop_actor(self):
self._enqueue(self.facet.actor.root, lambda turn: self.facet.actor.terminate(turn, True))
self._enqueue(self._facet.actor.root, lambda turn: self._facet.actor.terminate(turn, True))
def crash(self, exn):
self._enqueue(self.facet.actor.root, lambda turn: self.facet.actor.terminate(turn, exn))
self._enqueue(self._facet.actor.root, lambda turn: self._facet.actor.terminate(turn, exn))
def publish(self, ref, assertion):
handle = next(_next_handle)
@ -248,16 +270,18 @@ class Turn:
def _publish(self, ref, assertion, handle):
# TODO: attenuation
assertion = preserve(assertion)
e = OutboundAssertion(handle, ref)
self.facet.outbound[handle] = e
self._facet.outbound[handle] = e
def action(turn):
e.established = True
self.log.debug('%r <-- publish %r handle %r', ref, assertion, handle)
ref.entity.on_publish(turn, assertion, handle)
self._enqueue(ref.facet, action)
def retract(self, handle):
if handle is not None:
e = self.facet.outbound.get(handle, None)
e = self._facet.outbound.pop(handle, None)
if e is not None:
self._retract(e)
@ -267,10 +291,11 @@ class Turn:
return new_handle
def _retract(self, e):
del self.facet.outbound[e.handle]
# Assumes e has already been removed from self._facet.outbound
def action(turn):
if e.established:
e.established = False
self.log.debug('%r <-- retract handle %r', e.ref, e.handle)
e.ref.entity.on_retract(turn, e.handle)
self._enqueue(e.ref.facet, action)
@ -281,11 +306,17 @@ class Turn:
self._sync(ref, self.ref(SyncContinuation()))
def _sync(self, ref, peer):
self._enqueue(ref.facet, lambda turn: ref.entity.on_sync(turn, peer))
peer = preserve(peer)
def action(turn):
self.log.debug('%r <-- sync peer %r', ref, peer)
ref.entity.on_sync(turn, peer)
self._enqueue(ref.facet, action)
def send(self, ref, message):
# TODO: attenuation
message = preserve(message)
def action(turn):
self.log.debug('%r <-- message %r', ref, message)
ref.entity.on_message(turn, message)
self._enqueue(ref.facet, action)
@ -309,10 +340,10 @@ def stop_if_inert_after(action):
def wrapped_action(turn):
action(turn)
def check_action(turn):
if (turn.facet.parent is not None and not turn.facet.parent.alive) \
or turn.facet.isinert():
if (turn._facet.parent is not None and not turn._facet.parent.alive) \
or turn._facet.isinert():
turn.stop()
turn._enqueue(turn.facet, check_action)
turn._enqueue(turn._facet, check_action)
return wrapped_action
class Ref:
@ -343,4 +374,15 @@ class Entity:
def on_sync(self, turn, peer):
turn.send(peer, True)
_inert_actor = None
_inert_facet = None
_inert_ref = None
_inert_entity = Entity()
def __boot_inert(turn):
global _inert_actor, _inert_facet, _inert_ref
_inert_actor = turn._facet.actor
_inert_facet = turn._facet
_inert_ref = turn.ref(_inert_entity)
async def __run_inert():
Actor(__boot_inert, name = '_inert_actor')
asyncio.get_event_loop().run_until_complete(__run_inert())

View File

@ -27,12 +27,13 @@ class During(actor.Entity):
def on_publish(self, turn, v, handle):
retract_handler = self._on_add(turn, *self._wrap(v))
if retract_handler is not None:
self.retract_handlers[handle] = retract_handler
if isinstance(retract_handler, actor.Facet):
self.retract_handlers[handle] = lambda turn: turn.stop(retract_handler)
else:
self.retract_handlers[handle] = retract_handler
def on_retract(self, turn, handle):
if handle in self.retract_handlers:
self.retract_handlers[handle](turn)
del self.retract_handlers[handle]
self.retract_handlers.pop(handle, lambda turn: ())(turn)
def on_message(self, turn, v):
self._on_msg(turn, *self._wrap(v))

32
syndicate/patterns.py Normal file
View File

@ -0,0 +1,32 @@
from .schema import dataspacePatterns as P
from . import Symbol
_dict = dict ## we're about to shadow the builtin
_ = P.Pattern.DDiscard(P.DDiscard())
def bind(p):
return P.Pattern.DBind(P.DBind(p))
CAPTURE = bind(_)
def lit(v):
return P.Pattern.DLit(P.DLit(v))
def rec(labelstr, *members):
return _rec(Symbol(labelstr), *members)
def _rec(label, *members):
return P.Pattern.DCompound(P.DCompound.rec(
P.CRec(label, len(members)),
_dict(enumerate(members))))
def arr(*members):
return P.Pattern.DCompound(P.DCompound.arr(
P.CArr(len(members)),
_dict(enumerate(members))))
def dict(*kvs):
return P.Pattern.DCompound(P.DCompound.dict(
P.CDict(),
_dict(kvs)))

415
syndicate/relay.py Normal file
View File

@ -0,0 +1,415 @@
import asyncio
import websockets
import logging
from preserves import Embedded, stringify
from preserves.fold import map_embeddeds
from . import actor, encode, transport, Decoder
from .actor import _inert_ref, Turn, adjust_engine_inhabitant_count
from .idgen import IdGenerator
from .schema import externalProtocol as protocol, sturdy, transportAddress
class InboundAssertion:
def __init__(self, remote_handle, local_handle, wire_symbols):
self.remote_handle = remote_handle
self.local_handle = local_handle
self.wire_symbols = wire_symbols
_next_local_oid = IdGenerator()
class WireSymbol:
def __init__(self, oid, ref):
self.oid = oid
self.ref = ref
self.count = 0
def __repr__(self):
return '<ws:%d/%d:%r>' % (self.oid, self.count, self.ref)
class Membrane:
def __init__(self):
self.oid_map = {}
self.ref_map = {}
def _get(self, map, key, is_transient, ws_maker):
ws = map.get(key, None)
if ws is None:
ws = ws_maker()
self.oid_map[ws.oid] = ws
self.ref_map[ws.ref] = ws
if not is_transient:
ws.count = ws.count + 1
return ws
def get_ref(self, local_ref, is_transient, ws_maker):
return self._get(self.ref_map, local_ref, is_transient, ws_maker)
def get_oid(self, remote_oid, ws_maker):
return self._get(self.oid_map, remote_oid, False, ws_maker)
def drop(self, ws):
ws.count = ws.count - 1
if ws.count == 0:
del self.oid_map[ws.oid]
del self.ref_map[ws.ref]
# There are other kinds of relay. This one has exactly two participants connected to each other.
class TunnelRelay:
def __init__(self,
turn,
address,
gatekeeper_peer = None,
gatekeeper_oid = 0,
on_connected = None,
on_disconnected = None,
):
self.ref = turn.ref(self)
self.facet = turn._facet
self.facet.on_stop(self._shutdown)
self.address = address
self.gatekeeper_peer = gatekeeper_peer
self.gatekeeper_oid = gatekeeper_oid
self._reset()
actor.queue_task(lambda: self._reconnecting_main(asyncio.get_running_loop(),
on_connected = on_connected,
on_disconnected = on_disconnected))
def _reset(self):
self.inbound_assertions = {} # map remote handle to InboundAssertion
self.outbound_assertions = {} # map local handle to wire_symbols
self.exported_references = Membrane()
self.imported_references = Membrane()
self.pending_turn = []
self._connected = False
self.gatekeeper_handle = None
@property
def connected(self):
return self._connected
def _shutdown(self, turn):
self._disconnect()
def deregister(self, handle):
for ws in self.outbound_assertions.pop(handle, ()):
self.exported_references.drop(ws)
def _lookup(self, local_oid):
ws = self.exported_references.oid_map.get(local_oid, None)
return _inert_ref if ws is None else ws.ref
def register(self, assertion, maybe_handle):
exported = []
rewritten = map_embeddeds(
lambda r: Embedded(self.rewrite_ref_out(r, maybe_handle is None, exported)),
assertion)
if maybe_handle is not None:
self.outbound_assertions[maybe_handle] = exported
return rewritten
def rewrite_ref_out(self, r, is_transient, exported):
if isinstance(r.entity, RelayEntity) and r.entity.relay == self:
# TODO attenuation
return sturdy.WireRef.yours(sturdy.Oid(r.entity.oid), ())
else:
ws = self.exported_references.get_ref(
r, is_transient, lambda: WireSymbol(next(_next_local_oid), r))
exported.append(ws)
return sturdy.WireRef.mine(sturdy.Oid(ws.oid))
def rewrite_in(self, turn, assertion):
imported = []
rewritten = map_embeddeds(
lambda wire_ref: Embedded(self.rewrite_ref_in(turn, wire_ref, imported)),
assertion)
return (rewritten, imported)
def rewrite_ref_in(self, turn, wire_ref, imported):
if wire_ref.VARIANT.name == 'mine':
oid = wire_ref.oid.value
ws = self.imported_references.get_oid(
oid, lambda: WireSymbol(oid, turn.ref(RelayEntity(self, oid))))
imported.append(ws)
return ws.ref
else:
oid = wire_ref.oid.value
local_ref = self._lookup(oid)
attenuation = wire_ref.attenuation
if len(attenuation) > 0:
raise NotImplementedError('Non-empty attenuations not yet implemented') # TODO
return local_ref
def _on_disconnected(self):
self._connected = False
def retract_inbound(turn):
for ia in self.inbound_assertions.values():
turn.retract(ia.local_handle)
if self.gatekeeper_handle is not None:
turn.retract(self.gatekeeper_handle)
self._reset()
Turn.run(self.facet, retract_inbound)
self._disconnect()
def _on_connected(self):
self._connected = True
if self.gatekeeper_peer is not None:
def connected_action(turn):
gk = self.rewrite_ref_in(turn,
sturdy.WireRef.mine(sturdy.Oid(self.gatekeeper_oid)),
[])
self.gatekeeper_handle = turn.publish(self.gatekeeper_peer, Embedded(gk))
Turn.run(self.facet, connected_action)
def _on_event(self, v):
Turn.run(self.facet, lambda turn: self._handle_event(turn, v))
def _handle_event(self, turn, v):
packet = protocol.Packet.decode(v)
variant = packet.VARIANT.name
if variant == 'Turn': self._handle_turn_events(turn, packet.value.value)
elif variant == 'Error': self._on_error(turn, packet.value.message, packet.value.detail)
def _on_error(self, turn, message, detail):
self.facet.log.error('Error from server: %r (detail: %r)', message, detail)
self._disconnect()
def _handle_turn_events(self, turn, events):
for e in events:
ref = self._lookup(e.oid.value)
event = e.event
variant = event.VARIANT.name
if variant == 'Assert':
self._handle_publish(turn, ref, event.value.assertion.value, event.value.handle.value)
elif variant == 'Retract':
self._handle_retract(turn, ref, event.value.handle.value)
elif variant == 'Message':
self._handle_message(turn, ref, event.value.body.value)
elif variant == 'Sync':
self._handle_sync(turn, ref, event.value.peer)
def _handle_publish(self, turn, ref, assertion, remote_handle):
(assertion, imported) = self.rewrite_in(turn, assertion)
self.inbound_assertions[remote_handle] = \
InboundAssertion(remote_handle, turn.publish(ref, assertion), imported)
def _handle_retract(self, turn, ref, remote_handle):
ia = self.inbound_assertions.pop(remote_handle, None)
if ia is None:
raise ValueError('Peer retracted invalid handle %s' % (remote_handle,))
for ws in ia.wire_symbols:
self.imported_references.drop(ws)
turn.retract(ia.local_handle)
def _handle_message(self, turn, ref, message):
(message, imported) = self.rewrite_in(turn, message)
if len(imported) > 0:
raise ValueError('Cannot receive transient reference')
turn.send(ref, message)
def _handle_sync(self, turn, ref, wire_peer):
imported = []
peer = self.rewrite_ref_in(turn, wire_peer, imported)
def done(turn):
turn.send(peer, True)
for ws in imported:
self.imported_references.drop(ws)
turn.sync(ref, done)
def _send(self, remote_oid, turn_event):
if len(self.pending_turn) == 0:
def flush_pending(turn):
packet = protocol.Packet.Turn(protocol.Turn(self.pending_turn))
self.pending_turn = []
self._send_bytes(encode(packet))
actor.queue_task(lambda: Turn.run(self.facet, flush_pending))
self.pending_turn.append(protocol.TurnEvent(protocol.Oid(remote_oid), turn_event))
def _send_bytes(self, bs):
raise Exception('subclassresponsibility')
def _disconnect(self):
raise Exception('subclassresponsibility')
async def _reconnecting_main(self, loop, on_connected=None, on_disconnected=None):
adjust_engine_inhabitant_count(1)
should_run = True
while should_run and self.facet.alive:
did_connect = await self.main(loop, on_connected=(on_connected or _default_on_connected))
should_run = await (on_disconnected or _default_on_disconnected)(self, did_connect)
adjust_engine_inhabitant_count(-1)
@staticmethod
def from_str(turn, s, **kwargs):
return transport.connection_from_str(turn, s, **kwargs)
class RelayEntity(actor.Entity):
def __init__(self, relay, oid):
self.relay = relay
self.oid = oid
def __repr__(self):
return '<Relay %s %s>' % (stringify(self.relay.address), self.oid)
def _send(self, e):
self.relay._send(self.oid, e)
def on_publish(self, turn, assertion, handle):
self._send(protocol.Event.Assert(protocol.Assert(
protocol.Assertion(self.relay.register(assertion, handle)),
protocol.Handle(handle))))
def on_retract(self, turn, handle):
self.relay.deregister(handle)
self._send(protocol.Event.Retract(protocol.Retract(protocol.Handle(handle))))
def on_message(self, turn, message):
self._send(protocol.Event.Message(protocol.Message(
protocol.Assertion(self.relay.register(message, None)))))
def on_sync(self, turn, peer):
exported = []
entity = SyncPeerEntity(self.relay, peer, exported)
rewritten = Embedded(self.relay.rewrite_ref_out(turn.ref(entity), False, exported))
self._send(protocol.Event.Sync(protocol.Sync(rewritten)))
class SyncPeerEntity(actor.Entity):
def __init__(self, relay, peer, exported):
self.relay = relay
self.peer = peer
self.exported = exported
def on_message(self, turn, body):
self.relay.exported_references.drop(self.exported[0])
turn.send(self.peer, body)
async def _default_on_connected(relay):
relay.facet.log.info('Connected')
async def _default_on_disconnected(relay, did_connect):
if did_connect:
# Reconnect immediately
relay.facet.log.info('Disconnected')
else:
await asyncio.sleep(2)
return True
class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol):
def __init__(self, turn, address, **kwargs):
super().__init__(turn, address, **kwargs)
self.decoder = None
self.stop_signal = None
self.transport = None
def connection_lost(self, exc):
self._on_disconnected()
def connection_made(self, transport):
self.transport = transport
self._on_connected()
def data_received(self, chunk):
self.decoder.extend(chunk)
while True:
v = self.decoder.try_next()
if v is None: break
self._on_event(v)
def _send_bytes(self, bs):
if self.transport:
self.transport.write(bs)
def _disconnect(self):
if self.stop_signal:
self.stop_signal.get_loop().call_soon_threadsafe(
lambda: self.stop_signal.set_result(True))
async def _create_connection(self, loop):
raise Exception('subclassresponsibility')
async def main(self, loop, on_connected=None):
if self.transport is not None:
raise Exception('Cannot run connection twice!')
self.decoder = Decoder(decode_embedded = sturdy.WireRef.decode)
self.stop_signal = loop.create_future()
try:
_transport, _protocol = await self._create_connection(loop)
except OSError as e:
log.error('%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
return False
try:
if on_connected: await on_connected(self)
await self.stop_signal
return True
finally:
self.transport.close()
self.transport = None
self.stop_signal = None
self.decoder = None
@transport.address(transportAddress.Tcp)
class TcpTunnelRelay(_StreamTunnelRelay):
async def _create_connection(self, loop):
return await loop.create_connection(lambda: self, self.address.host, self.address.port)
@transport.address(transportAddress.Unix)
class UnixSocketTunnelRelay(_StreamTunnelRelay):
async def _create_connection(self, loop):
return await loop.create_unix_connection(lambda: self, self.address.path)
@transport.address(transportAddress.WebSocket)
class WebsocketTunnelRelay(TunnelRelay):
def __init__(self, turn, address, **kwargs):
super().__init__(turn, address, **kwargs)
self.loop = None
self.ws = None
def _send_bytes(self, bs):
if self.loop:
def _do_send():
if self.ws:
self.loop.create_task(self.ws.send(bs))
self.loop.call_soon_threadsafe(_do_send)
def _disconnect(self):
if self.loop:
def _do_disconnect():
if self.ws:
self.loop.create_task(self.ws.close())
self.loop.call_soon_threadsafe(_do_disconnect)
def __connection_error(self, e):
self.facet.log.error('Could not connect to server: %s' % (e,))
return False
async def main(self, loop, on_connected=None):
if self.ws is not None:
raise Exception('Cannot run connection twice!')
self.loop = loop
try:
self.ws = await websockets.connect(self.address.url)
except OSError as e:
return self.__connection_error(e)
except websockets.exceptions.InvalidHandshake as e:
return self.__connection_error(e)
try:
if on_connected: await on_connected(self)
self._on_connected()
while True:
chunk = await self.ws.recv()
self._on_event(Decoder(chunk, decode_embedded = sturdy.WireRef.decode).next())
except websockets.exceptions.WebSocketException:
pass
finally:
self._on_disconnected()
if self.ws:
await self.ws.close()
self.loop = None
self.ws = None
return True

View File

@ -1,6 +1,7 @@
from preserves.schema import load_schema_file
import pathlib
for (n, ns) in load_schema_file(pathlib.Path(__file__).parent /
'../../syndicate-protocols/schema-bundle.bin')._items().items():
globals()[n] = ns
def __load():
from preserves.schema import load_schema_file
import pathlib
for (n, ns) in load_schema_file(pathlib.Path(__file__).parent /
'../../../syndicate-protocols/schema-bundle.bin')._items().items():
globals()[n] = ns
__load()

20
syndicate/transport.py Normal file
View File

@ -0,0 +1,20 @@
from preserves import parse
constructors = {}
class InvalidTransportAddress(ValueError): pass
# decorator
def address(address_class):
def k(connection_factory_class):
constructors[address_class] = connection_factory_class
return connection_factory_class
return k
def connection_from_str(turn, s, **kwargs):
address = parse(s)
for (address_class, factory_class) in constructors.items():
decoded_address = address_class.try_decode(address)
if decoded_address is not None:
return factory_class(turn, decoded_address, **kwargs)
raise InvalidTransportAddress('Invalid transport address', address)