syndicate-py/syndicate/relay.py

491 lines
18 KiB
Python

import sys
import asyncio
from preserves import Embedded, stringify
from preserves.fold import map_embeddeds
from . import actor, encode, transport, Decoder, gatekeeper, turn
from .during import During
from .actor import _inert_ref
from .idgen import IdGenerator
from .schema import protocol, sturdy, transportAddress
class InboundAssertion:
def __init__(self, remote_handle, local_handle, pins):
self.remote_handle = remote_handle
self.local_handle = local_handle
self.pins = pins
class WireSymbol:
def __init__(self, oid, ref, membrane):
self.oid = oid
self.ref = ref
self.membrane = membrane
self.count = 0
def __repr__(self):
return '<ws:%d/%d:%r>' % (self.oid, self.count, self.ref)
def grab(self, pins):
self.count = self.count + 1
pins.append(self)
def drop(self):
self.count = self.count - 1
if self.count == 0:
del self.membrane.oid_map[self.oid]
del self.membrane.ref_map[self.ref]
class Membrane:
def __init__(self):
self.oid_map = {}
self.ref_map = {}
def _get(self, pins, map, key, is_transient, ws_maker):
ws = map.get(key, None)
if ws is None and ws_maker is not None:
ws = ws_maker()
self.oid_map[ws.oid] = ws
self.ref_map[ws.ref] = ws
if not is_transient and ws is not None:
ws.grab(pins)
return ws
def get_ref(self, pins, local_ref, is_transient, ws_maker):
return self._get(pins, self.ref_map, local_ref, is_transient, ws_maker)
def get_oid(self, pins, remote_oid, ws_maker):
return self._get(pins, self.oid_map, remote_oid, False, ws_maker)
def drop_all(wss):
for ws in wss:
ws.drop()
# There are other kinds of relay. This one has exactly two participants connected to each other.
class TunnelRelay:
def __init__(self,
address,
gatekeeper_peer = None,
gatekeeper_oid = 0,
publish_service = None,
publish_oid = 0,
on_connected = None,
on_disconnected = None,
connection_timeout = None,
):
self.facet = turn.active_facet()
self.facet.on_stop(self._shutdown)
self.address = address
self.gatekeeper_peer = gatekeeper_peer
self.gatekeeper_oid = gatekeeper_oid
self.publish_service = publish_service
self.publish_oid = publish_oid
self.connection_timeout = connection_timeout
self._reset()
self.facet.linked_task(
lambda facet: self._reconnecting_main(facet.actor._system,
on_connected = on_connected,
on_disconnected = on_disconnected))
def _reset(self):
self.inbound_assertions = {} # map remote handle to InboundAssertion
self.outbound_assertions = {} # map local handle to `WireSymbol`s
self.exported_references = Membrane()
self.imported_references = Membrane()
self.pending_turn = []
self._connected = False
self.gatekeeper_handle = None
if self.publish_service is None:
self.next_local_oid = IdGenerator(initial_value=0)
else:
self.next_local_oid = IdGenerator(initial_value=(self.publish_oid + 1))
# Very specific specialization of logic in rewrite_ref_out
ws = WireSymbol(self.publish_oid, self.publish_service, self.exported_references)
self.exported_references.get_ref([], self.publish_service, False, lambda: ws)
@property
def connected(self):
return self._connected
def _shutdown(self):
self._disconnect()
def deregister(self, handle):
drop_all(self.outbound_assertions.pop(handle, ()))
def _lookup_exported_oid(self, local_oid, pins):
ws = self.exported_references.get_oid(pins, local_oid, None)
if ws is None:
return _inert_ref
return ws.ref
def register_imported_oid(self, remote_oid, pins):
self.imported_references.get_oid(pins, remote_oid, None)
def register(self, target_oid, assertion, maybe_handle):
pins = []
self.register_imported_oid(target_oid, pins)
rewritten = map_embeddeds(
lambda r: Embedded(self.rewrite_ref_out(r, maybe_handle is None, pins)),
assertion)
if maybe_handle is not None:
self.outbound_assertions[maybe_handle] = pins
return rewritten
def rewrite_ref_out(self, r, is_transient, pins):
if isinstance(r.entity, RelayEntity) and r.entity.relay == self:
# TODO attenuation
return sturdy.WireRef.yours(sturdy.Oid(r.entity.oid), ())
else:
ws = self.exported_references.get_ref(
pins, r, is_transient, lambda: WireSymbol(next(self.next_local_oid), r,
self.exported_references))
return sturdy.WireRef.mine(sturdy.Oid(ws.oid))
def rewrite_in(self, assertion, pins):
rewritten = map_embeddeds(
lambda wire_ref: Embedded(self.rewrite_ref_in(wire_ref, pins)),
assertion)
return rewritten
def rewrite_ref_in(self, wire_ref, pins):
if wire_ref.VARIANT.name == 'mine':
oid = wire_ref.oid.value
ws = self.imported_references.get_oid(
pins, oid, lambda: WireSymbol(oid, turn.ref(RelayEntity(self, oid)),
self.imported_references))
return ws.ref
else:
oid = wire_ref.oid.value
local_ref = self._lookup_exported_oid(oid, pins)
attenuation = wire_ref.attenuation
if len(attenuation) > 0:
raise NotImplementedError('Non-empty attenuations not yet implemented') # TODO
return local_ref
def _on_disconnected(self):
self._connected = False
def retract_inbound():
for ia in self.inbound_assertions.values():
turn.retract(ia.local_handle)
if self.gatekeeper_handle is not None:
turn.retract(self.gatekeeper_handle)
self._reset()
turn.run(self.facet, retract_inbound)
self._disconnect()
def _on_connected(self):
self._connected = True
if self.gatekeeper_peer is not None:
def connected_action():
gk = self.rewrite_ref_in(sturdy.WireRef.mine(sturdy.Oid(self.gatekeeper_oid)), [])
self.gatekeeper_handle = turn.publish(self.gatekeeper_peer, Embedded(gk))
turn.run(self.facet, connected_action)
def _on_event(self, v):
turn.run(self.facet, lambda: self._handle_event(v))
def _handle_event(self, v):
packet = protocol.Packet.decode(v)
# self.facet.log.info('IN: %r', packet)
variant = packet.VARIANT.name
if variant == 'Turn': self._handle_turn_events(packet.value.value)
elif variant == 'Error': self._on_error(packet.value.message, packet.value.detail)
elif variant == 'Extension': pass
def _on_error(self, message, detail):
self.facet.log.error('Error from server: %r (detail: %r)', message, detail)
self._disconnect()
def _handle_turn_events(self, events):
for e in events:
pins = []
ref = self._lookup_exported_oid(e.oid.value, pins)
event = e.event
variant = event.VARIANT.name
if variant == 'Assert':
self._handle_publish(pins, ref, event.value.assertion.value, event.value.handle.value)
elif variant == 'Retract':
self._handle_retract(pins, ref, event.value.handle.value)
elif variant == 'Message':
self._handle_message(pins, ref, event.value.body.value)
elif variant == 'Sync':
self._handle_sync(pins, ref, event.value.peer)
def _handle_publish(self, pins, ref, assertion, remote_handle):
assertion = self.rewrite_in(assertion, pins)
self.inbound_assertions[remote_handle] = \
InboundAssertion(remote_handle, turn.publish(ref, assertion), pins)
def _handle_retract(self, pins, ref, remote_handle):
ia = self.inbound_assertions.pop(remote_handle, None)
if ia is None:
raise ValueError('Peer retracted invalid handle %s' % (remote_handle,))
drop_all(ia.pins)
drop_all(pins)
turn.retract(ia.local_handle)
def _handle_message(self, pins, ref, message):
message = self.rewrite_in(message, pins)
for ws in pins:
if ws.count == 1:
raise ValueError('Cannot receive transient reference')
turn.send(ref, message)
drop_all(pins)
def _handle_sync(self, pins, ref, wire_peer):
peer = self.rewrite_ref_in(wire_peer, pins)
def done():
turn.send(peer, True)
drop_all(pins)
turn.sync(ref, done)
def _send(self, remote_oid, turn_event):
if len(self.pending_turn) == 0:
def flush_pending():
packet = protocol.Packet.Turn(protocol.Turn(self.pending_turn))
self.pending_turn = []
# self.facet.log.info('OUT: %r', packet)
self._send_bytes(encode(packet))
self.facet.actor._system.queue_task(lambda: turn.run(self.facet, flush_pending))
self.pending_turn.append(protocol.TurnEvent(protocol.Oid(remote_oid), turn_event))
def _send_bytes(self, bs):
raise Exception('subclassresponsibility')
def _disconnect(self):
raise Exception('subclassresponsibility')
async def _reconnecting_main(self, system, on_connected=None, on_disconnected=None):
should_run = True
while should_run and self.facet.alive:
did_connect = await self.main(system, on_connected=(on_connected or _default_on_connected))
should_run = await (on_disconnected or _default_on_disconnected)(self, did_connect)
@staticmethod
def from_str(conn_str, **kwargs):
return transport.connection_from_str(conn_str, **kwargs)
# decorator
def connect(conn_str, cap = None, **kwargs):
def prepare_resolution_handler(handler):
@During().add_handler
def handle_gatekeeper(gk):
if cap is None:
handler(gk.embeddedValue)
else:
gatekeeper.resolve(gk.embeddedValue, cap)(handler)
return transport.connection_from_str(
conn_str,
gatekeeper_peer = turn.ref(handle_gatekeeper),
**kwargs)
return prepare_resolution_handler
class RelayEntity(actor.Entity):
def __init__(self, relay, oid):
self.relay = relay
self.oid = oid
def __repr__(self):
return '<Relay %s %s>' % (stringify(self.relay.address), self.oid)
def _send(self, e):
self.relay._send(self.oid, e)
def on_publish(self, assertion, handle):
self._send(protocol.Event.Assert(protocol.Assert(
protocol.Assertion(self.relay.register(self.oid, assertion, handle)),
protocol.Handle(handle))))
def on_retract(self, handle):
self.relay.deregister(handle)
self._send(protocol.Event.Retract(protocol.Retract(protocol.Handle(handle))))
def on_message(self, message):
self._send(protocol.Event.Message(protocol.Message(
protocol.Assertion(self.relay.register(self.oid, message, None)))))
def on_sync(self, peer):
pins = []
self.relay.register_imported_oid(self.oid, pins)
entity = SyncPeerEntity(self.relay, peer, pins)
rewritten = Embedded(self.relay.rewrite_ref_out(turn.ref(entity), False, pins))
self._send(protocol.Event.Sync(protocol.Sync(rewritten)))
class SyncPeerEntity(actor.Entity):
def __init__(self, relay, peer, pins):
self.relay = relay
self.peer = peer
self.pins = pins
def on_message(self, body):
drop_all(self.pins)
turn.send(self.peer, body)
async def _default_on_connected(relay):
relay.facet.log.info('Connected')
async def _default_on_disconnected(relay, did_connect):
if did_connect:
# Reconnect immediately
relay.facet.log.info('Disconnected')
else:
await asyncio.sleep(2)
return True
class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol):
def __init__(self, address, **kwargs):
super().__init__(address, **kwargs)
self.decoder = None
self.stop_signal = None
self.transport = None
def connection_lost(self, exc):
self._on_disconnected()
def data_received(self, chunk):
self.decoder.extend(chunk)
while True:
v = self.decoder.try_next()
if v is None: break
self._on_event(v)
def _send_bytes(self, bs):
if self.transport:
self.transport.write(bs)
def _disconnect(self):
if self.stop_signal:
def set_stop_signal():
try:
self.stop_signal.set_result(True)
except:
pass
self.stop_signal.get_loop().call_soon_threadsafe(set_stop_signal)
async def _create_connection(self, system):
raise Exception('subclassresponsibility')
async def main(self, system, on_connected=None):
if self.transport is not None:
raise Exception('Cannot run connection twice!')
self.decoder = Decoder(decode_embedded = sturdy.WireRef.decode)
self.stop_signal = system.loop.create_future()
try:
try:
transport, _protocol = await asyncio.wait_for(
self._create_connection(system), timeout=self.connection_timeout)
except asyncio.TimeoutError:
self.facet.log.error(
'%s: Timeout connecting to server' % (self.__class__.__qualname__,))
return False
except OSError as e:
self.facet.log.error(
'%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
return False
self.transport = transport
self._on_connected()
if on_connected: await on_connected(self)
await self.stop_signal
return True
finally:
if self.transport:
self.transport.close()
self.transport = None
self.stop_signal = None
self.decoder = None
@transport.address(transportAddress.Tcp)
class TcpTunnelRelay(_StreamTunnelRelay):
async def _create_connection(self, system):
return await system.loop.create_connection(lambda: self, self.address.host, self.address.port)
@transport.address(transportAddress.Unix)
class UnixSocketTunnelRelay(_StreamTunnelRelay):
async def _create_connection(self, system):
return await system.loop.create_unix_connection(lambda: self, self.address.path)
@transport.address(transportAddress.WebSocket)
class WebsocketTunnelRelay(TunnelRelay):
def __init__(self, address, **kwargs):
super().__init__(address, **kwargs)
self.system = None
self.ws = None
def _send_bytes(self, bs):
if self.system:
def _do_send():
if self.ws:
self.system.queue_task(lambda: self.ws.send(bs))
self.system.loop.call_soon_threadsafe(_do_send)
def _disconnect(self):
if self.system:
def _do_disconnect():
if self.ws:
self.system.queue_task(lambda: self.ws.close())
self.system.loop.call_soon_threadsafe(_do_disconnect)
def __connection_error(self, e):
self.facet.log.error('Could not connect to server: %s' % (e,))
return False
async def main(self, system, on_connected=None):
import websockets
if self.ws is not None:
raise Exception('Cannot run connection twice!')
self.system = system
try:
self.ws = await websockets.connect(
self.address.url, open_timeout=self.connection_timeout)
except asyncio.TimeoutError:
return self.__connection_error('timeout')
except OSError as e:
return self.__connection_error(e)
except websockets.exceptions.InvalidHandshake as e:
return self.__connection_error(e)
try:
if on_connected: await on_connected(self)
self._on_connected()
while True:
chunk = await self.ws.recv()
self._on_event(Decoder(chunk, decode_embedded = sturdy.WireRef.decode).next())
except websockets.exceptions.WebSocketException:
pass
finally:
self._on_disconnected()
if self.ws:
await self.ws.close()
self.system = None
self.ws = None
return True
@transport.address(transportAddress.Stdio)
class PipeTunnelRelay(_StreamTunnelRelay):
def __init__(self, address, input_fileobj = sys.stdin, output_fileobj = sys.stdout, **kwargs):
super().__init__(address, **kwargs)
self.input_fileobj = input_fileobj
self.output_fileobj = output_fileobj
self.reader = asyncio.StreamReader()
async def _create_connection(self, system):
return await system.loop.connect_read_pipe(lambda: self, self.input_fileobj)
def _send_bytes(self, bs):
self.output_fileobj.buffer.write(bs)
self.output_fileobj.buffer.flush()
def run_stdio_service(entity):
PipeTunnelRelay(transportAddress.Stdio(), publish_service=turn.ref(entity))
# decorator
def service(**kwargs):
return lambda entity: actor.run_system(**kwargs)(lambda: run_stdio_service(entity))