mini-syndicate-py/syndicate/mini/core.py

331 lines
9.9 KiB
Python

import asyncio
import secrets
import logging
import websockets
log = logging.getLogger(__name__)
import syndicate.mini.protocol as protocol
from syndicate.mini.protocol import Capture, Discard, Observe
CAPTURE = Capture(Discard())
from preserves import *
def _encode(event):
e = protocol.Encoder()
e.append(event)
return e.contents()
_instance_id = secrets.token_urlsafe(8)
_uuid_counter = 0
def uuid(prefix='__@syndicate'):
global _uuid_counter
c = _uuid_counter
_uuid_counter = c + 1
return prefix + '_' + _instance_id + '_' + str(c)
def _ignore(*args, **kwargs):
pass
class Turn(object):
def __init__(self, conn):
self.conn = conn
self.items = []
def _extend(self, item):
self.items.append(item)
def _reset(self):
self.items.clear()
def _commit(self):
if self.items:
self.conn._send(protocol.Turn(self.items))
self._reset()
def send(self, message):
self._extend(protocol.Message(message))
def __enter__(self):
if self.items:
raise Exception('Cannot reenter with statement for Turn')
return self
def __exit__(self, t, v, tb):
if t is None:
self._commit()
def _fresh_id(assertion):
return uuid('sub' if Observe.isClassOf(assertion) else 'pub')
class Endpoint(object):
def __init__(self, turn, assertion,
on_add=None, on_del=None, on_msg=None):
self.assertion = None
self.id = None
self.on_add = on_add or _ignore
self.on_del = on_del or _ignore
self.on_msg = on_msg or _ignore
self.cache = set()
self.set(turn, assertion)
def set(self, turn, new_assertion, on_transition=None):
if self.id is not None:
turn.conn._unmap_endpoint(turn, self, on_end=on_transition)
self.id = None
self.assertion = new_assertion
if self.assertion is not None:
self.id = _fresh_id(self.assertion)
turn.conn._map_endpoint(turn, self)
def clear(self, turn, on_cleared=None):
self.set(turn, None, on_transition=on_cleared)
def _reset(self, turn):
for captures in set(self.cache):
self._del(turn, captures)
def _add(self, turn, captures):
if captures in self.cache:
log.error('Server error: duplicate captures %r added for endpoint %r %r' % (
captures,
self.id,
self.assertion))
else:
self.cache.add(captures)
self.on_add(turn, *captures)
def _del(self, turn, captures):
if captures in self.cache:
self.cache.discard(captures)
self.on_del(turn, *captures)
else:
log.error('Server error: nonexistent captures %r removed from endpoint %r %r' % (
captures,
self.id,
self.assertion))
def _msg(self, turn, captures):
self.on_msg(turn, *captures)
class DummyEndpoint(object):
def _add(self, turn, captures): pass
def _del(self, turn, captures): pass
def _msg(self, turn, captures): pass
_dummy_endpoint = DummyEndpoint()
class Connection(object):
def __init__(self, scope):
self.scope = scope
self.endpoints = {}
self.end_callbacks = {}
def _each_endpoint(self):
return list(self.endpoints.values())
def turn(self):
return Turn(self)
def destroy(self):
with self.turn() as t:
for ep in self._each_endpoint():
ep.clear(t)
t._reset() ## don't actually Clear the endpoints, we are about to disconnect
self._disconnect()
def _unmap_endpoint(self, turn, ep, on_end=None):
del self.endpoints[ep.id]
if on_end:
self.end_callbacks[ep.id] = on_end
turn._extend(protocol.Clear(ep.id))
def _on_end(self, turn, id):
if id in self.end_callbacks:
self.end_callbacks[id](turn)
del self.end_callbacks[id]
def _map_endpoint(self, turn, ep):
self.endpoints[ep.id] = ep
turn._extend(protocol.Assert(ep.id, ep.assertion))
def _on_disconnected(self):
with self.turn() as t:
for ep in self._each_endpoint():
ep._reset(t)
t._reset() ## we have been disconnected, no point in keeping the actions
self._disconnect()
def _on_connected(self):
self._send(protocol.Connect(self.scope))
with self.turn() as t:
for ep in self._each_endpoint():
self._map_endpoint(t, ep)
def _lookup(self, endpointId):
return self.endpoints.get(endpointId, _dummy_endpoint)
def _on_event(self, v):
with self.turn() as t:
self._handle_event(t, v)
def _handle_event(self, turn, v):
if protocol.Turn.isClassOf(v):
for item in protocol.Turn._items(v):
if protocol.Add.isClassOf(item): self._lookup(item[0])._add(turn, item[1])
elif protocol.Del.isClassOf(item): self._lookup(item[0])._del(turn, item[1])
elif protocol.Msg.isClassOf(item): self._lookup(item[0])._msg(turn, item[1])
elif protocol.End.isClassOf(item): self._on_end(turn, item[0])
else: log.error('Unhandled server Turn item: %r' % (item,))
return
elif protocol.Err.isClassOf(v):
self._on_error(v[0], v[1])
return
elif protocol.Ping.isClassOf(v):
self._send_bytes(_encode(protocol.Pong()))
return
else:
log.error('Unhandled server message: %r' % (v,))
def _on_error(self, detail, context):
log.error('%s: error from server: %r (context: %r)' % (
self.__class__.__qualname__, detail, context))
self._disconnect()
def _send(self, m):
return self._send_bytes(_encode(m))
def _send_bytes(self, bs, commitNeeded = False):
raise Exception('subclassresponsibility')
def _disconnect(self):
raise Exception('subclassresponsibility')
class _StreamConnection(Connection, asyncio.Protocol):
def __init__(self, scope):
super().__init__(scope)
self.decoder = None
self.stop_signal = None
self.transport = None
def connection_lost(self, exc):
self._on_disconnected()
def connection_made(self, transport):
self.transport = transport
self._on_connected()
def data_received(self, chunk):
self.decoder.extend(chunk)
while True:
v = self.decoder.try_next()
if v is None: break
self._on_event(v)
def _send_bytes(self, bs, commitNeeded = False):
if self.transport:
self.transport.write(bs)
if commitNeeded:
self.commitNeeded = True
def _disconnect(self):
if self.stop_signal:
self.stop_signal.get_loop().call_soon_threadsafe(
lambda: self.stop_signal.set_result(True))
async def _create_connection(self, loop):
raise Exception('subclassresponsibility')
async def main(self, loop, on_connected=None):
if self.transport is not None:
raise Exception('Cannot run connection twice!')
self.decoder = protocol.Decoder()
self.stop_signal = loop.create_future()
try:
_transport, _protocol = await self._create_connection(loop)
except OSError as e:
log.error('%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
return False
try:
if on_connected: on_connected()
await self.stop_signal
return True
finally:
self.transport.close()
self.transport = None
self.stop_signal = None
self.decoder = None
class TcpConnection(_StreamConnection):
def __init__(self, host, port, scope):
super().__init__(scope)
self.host = host
self.port = port
async def _create_connection(self, loop):
return await loop.create_connection(lambda: self, self.host, self.port)
class UnixSocketConnection(_StreamConnection):
def __init__(self, path, scope):
super().__init__(scope)
self.path = path
async def _create_connection(self, loop):
return await loop.create_unix_connection(lambda: self, self.path)
class WebsocketConnection(Connection):
def __init__(self, url, scope):
super().__init__(scope)
self.url = url
self.loop = None
self.ws = None
def _send_bytes(self, bs, commitNeeded = False):
if self.loop:
def _do_send():
if self.ws:
self.loop.create_task(self.ws.send(bs))
self.loop.call_soon_threadsafe(_do_send)
if commitNeeded:
self.commitNeeded = True
def _disconnect(self):
if self.loop:
def _do_disconnect():
if self.ws:
self.loop.create_task(self.ws.close())
self.loop.call_soon_threadsafe(_do_disconnect)
async def main(self, loop, on_connected=None):
if self.ws is not None:
raise Exception('Cannot run connection twice!')
self.loop = loop
try:
async with websockets.connect(self.url) as ws:
if on_connected: on_connected()
self.ws = ws
self._on_connected()
try:
while True:
chunk = await ws.recv()
self._on_event(protocol.Decoder(chunk).next())
except websockets.exceptions.ConnectionClosed:
pass
except OSError as e:
log.error('%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
return False
finally:
self._on_disconnected()
if self.ws:
await self.ws.close()
self.loop = None
self.ws = None
return True