Turn-based protocol

This commit is contained in:
Tony Garnock-Jones 2019-06-12 00:26:40 +01:00
parent 7762529d16
commit c73fb462d2
4 changed files with 137 additions and 103 deletions

18
chat.py
View File

@ -31,14 +31,13 @@ names = ['Daria', 'Kendra', 'Danny', 'Rufus', 'Diana', 'Arnetta', 'Dominick', 'M
me = random.choice(names) + '_' + str(random.randint(10, 1000))
S.Endpoint(conn, Present(me))
S.Endpoint(conn, S.Observe(Present(S.CAPTURE)),
on_add=lambda who: print(who, 'joined'),
on_del=lambda who: print(who, 'left'))
S.Endpoint(conn, S.Observe(Says(S.CAPTURE, S.CAPTURE)),
on_msg=lambda who, what: print(who, 'said', repr(what)))
with conn.turn() as t:
S.Endpoint(t, Present(me))
S.Endpoint(t, S.Observe(Present(S.CAPTURE)),
on_add=lambda t, who: print(who, 'joined'),
on_del=lambda t, who: print(who, 'left'))
S.Endpoint(t, S.Observe(Says(S.CAPTURE, S.CAPTURE)),
on_msg=lambda t, who, what: print(who, 'said', repr(what)))
async def reconnect(loop):
while conn:
@ -56,7 +55,8 @@ def accept_input():
conn.destroy()
conn = None
break
conn.send(Says(me, line.strip()))
with conn.turn() as t:
t.send(Says(me, line.strip()))
loop = asyncio.get_event_loop()
loop.set_debug(True)

View File

@ -9,17 +9,19 @@ OverlayLink = S.Record.makeConstructor('OverlayLink', 'downNode upNode')
conn = S.WebsocketConnection(sys.argv[1], sys.argv[2])
uplinks = {}
def add_uplink(src, tgt):
def add_uplink(turn, src, tgt):
uplinks[src] = tgt
summarise_uplinks()
def del_uplink(src, tgt):
def del_uplink(turn, src, tgt):
del uplinks[src]
summarise_uplinks()
def summarise_uplinks():
print(repr(uplinks))
S.Endpoint(conn, S.Observe(OverlayLink(S.CAPTURE, S.CAPTURE)),
on_add=add_uplink,
on_del=del_uplink)
with conn.turn() as t:
S.Endpoint(t, S.Observe(OverlayLink(S.CAPTURE, S.CAPTURE)),
on_add=add_uplink,
on_del=del_uplink)
async def reconnect(loop):
while conn:

View File

@ -12,6 +12,11 @@ CAPTURE = Capture(Discard())
from preserves import *
def _encode(event):
e = protocol.Encoder()
e.append(event)
return e.contents()
_instance_id = secrets.token_urlsafe(8)
_uuid_counter = 0
@ -24,39 +29,65 @@ def uuid(prefix='__@syndicate'):
def _ignore(*args, **kwargs):
pass
class Endpoint(object):
def __init__(self, conn, assertion, id=None,
on_add=None, on_del=None, on_msg=None, on_end=None):
class Turn(object):
def __init__(self, conn):
self.conn = conn
self.assertion = assertion
self.id = id or uuid('sub' if Observe.isClassOf(assertion) else 'pub')
self.items = []
def _extend(self, item):
self.items.append(item)
def _reset(self):
self.items.clear()
def _commit(self):
if self.items:
self.conn._send(protocol.Turn(self.items))
self._reset()
def send(self, message):
self._extend(protocol.Message(message))
def __enter__(self):
if self.items:
raise Exception('Cannot reenter with statement for Turn')
return self
def __exit__(self, t, v, tb):
if t is None:
self._commit()
def _fresh_id(assertion):
return uuid('sub' if Observe.isClassOf(assertion) else 'pub')
class Endpoint(object):
def __init__(self, turn, assertion,
on_add=None, on_del=None, on_msg=None):
self.assertion = None
self.id = None
self.on_add = on_add or _ignore
self.on_del = on_del or _ignore
self.on_msg = on_msg or _ignore
self.on_end = on_end or _ignore
self.cache = set()
self.conn._update_endpoint(self)
self.set(turn, assertion)
def set(self, new_assertion):
def set(self, turn, new_assertion, on_transition=None):
if self.id is not None:
turn.conn._unmap_endpoint(turn, self, on_end=on_transition)
self.id = None
self.assertion = new_assertion
if self.conn:
self.conn._update_endpoint(self)
if self.assertion is not None:
self.id = _fresh_id(self.assertion)
turn.conn._map_endpoint(turn, self)
def send(self, message):
'''Shortcut to Connection.send.'''
if self.conn:
self.conn.send(message)
def clear(self, turn, on_cleared=None):
self.set(turn, None, on_transition=on_cleared)
def destroy(self):
if self.conn:
self.conn._clear_endpoint(self)
self.conn = None
def _reset(self):
def _reset(self, turn):
for captures in set(self.cache):
self._del(captures)
self._del(turn, captures)
def _add(self, captures):
def _add(self, turn, captures):
if captures in self.cache:
log.error('Server error: duplicate captures %r added for endpoint %r %r' % (
captures,
@ -64,107 +95,109 @@ class Endpoint(object):
self.assertion))
else:
self.cache.add(captures)
self.on_add(*captures)
self.on_add(turn, *captures)
def _del(self, captures):
def _del(self, turn, captures):
if captures in self.cache:
self.cache.discard(captures)
self.on_del(*captures)
self.on_del(turn, *captures)
else:
log.error('Server error: nonexistent captures %r removed from endpoint %r %r' % (
captures,
self.id,
self.assertion))
def _msg(self, captures):
self.on_msg(*captures)
def _end(self):
self.on_end()
def _msg(self, turn, captures):
self.on_msg(turn, *captures)
class DummyEndpoint(object):
def _add(self, captures): pass
def _del(self, captures): pass
def _msg(self, captures): pass
def _end(self): pass
def _add(self, turn, captures): pass
def _del(self, turn, captures): pass
def _msg(self, turn, captures): pass
_dummy_endpoint = DummyEndpoint()
class Connection(object):
def __init__(self, scope):
self.endpoints = {}
self.scope = scope
self.commitNeeded = False
self.worklist = []
self.endpoints = {}
self.end_callbacks = {}
def _each_endpoint(self):
return list(self.endpoints.values())
def turn(self):
return Turn(self)
def destroy(self):
for ep in self._each_endpoint():
ep.destroy()
with self.turn() as t:
for ep in self._each_endpoint():
ep.clear(t)
t._reset() ## don't actually Clear the endpoints, we are about to disconnect
self._disconnect()
def _encode(self, event):
e = protocol.Encoder()
e.append(event)
return e.contents()
def _unmap_endpoint(self, turn, ep, on_end=None):
del self.endpoints[ep.id]
if on_end:
self.end_callbacks[ep.id] = on_end
turn._extend(protocol.Clear(ep.id))
def _update_endpoint(self, ep):
def _on_end(self, turn, id):
if id in self.end_callbacks:
self.end_callbacks[id](turn)
del self.end_callbacks[id]
def _map_endpoint(self, turn, ep):
self.endpoints[ep.id] = ep
self._send(self._encode(protocol.Assert(ep.id, ep.assertion)), commitNeeded = True)
def _clear_endpoint(self, ep):
if ep.id in self.endpoints:
del self.endpoints[ep.id]
self._send(self._encode(protocol.Clear(ep.id)), commitNeeded = True)
def send(self, message):
self._send(self._encode(protocol.Message(message)), commitNeeded = True)
self._commit_if_needed()
turn._extend(protocol.Assert(ep.id, ep.assertion))
def _on_disconnected(self):
for ep in self._each_endpoint():
ep._reset()
with self.turn() as t:
for ep in self._each_endpoint():
ep._reset(t)
t._reset() ## we have been disconnected, no point in keeping the actions
self._disconnect()
def _on_connected(self):
self._send(self._encode(protocol.Connect(self.scope)))
for ep in self._each_endpoint():
self._update_endpoint(ep)
self._commit_work()
self._send(protocol.Connect(self.scope))
with self.turn() as t:
for ep in self._each_endpoint():
self._map_endpoint(t, ep)
def _lookup(self, endpointId):
return self.endpoints.get(endpointId, _dummy_endpoint)
def _push_work(self, thunk):
self.worklist.append(thunk)
def _commit_work(self):
for thunk in self.worklist:
thunk()
self.worklist.clear()
self._commit_if_needed()
def _commit_if_needed(self):
if self.commitNeeded:
self._send(self._encode(protocol.Commit()))
self.commitNeeded = False
def _on_event(self, v):
if protocol.Add.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._add(v[1]))
if protocol.Del.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._del(v[1]))
if protocol.Msg.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._msg(v[1]))
if protocol.End.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._end())
if protocol.Commit.isClassOf(v): return self._commit_work()
if protocol.Err.isClassOf(v): return self._on_error(v[0])
if protocol.Ping.isClassOf(v): self._send(self._encode(protocol.Pong()))
with self.turn() as t:
self._handle_event(t, v)
def _on_error(self, detail):
log.error('%s: error from server: %r' % (self.__class__.__qualname__, detail))
def _handle_event(self, turn, v):
if protocol.Turn.isClassOf(v):
for item in protocol.Turn._items(v):
if protocol.Add.isClassOf(item): self._lookup(item[0])._add(turn, item[1])
elif protocol.Del.isClassOf(item): self._lookup(item[0])._del(turn, item[1])
elif protocol.Msg.isClassOf(item): self._lookup(item[0])._msg(turn, item[1])
elif protocol.End.isClassOf(item): self._on_end(turn, item[0])
else: log.error('Unhandled server Turn item: %r' % (item,))
return
elif protocol.Err.isClassOf(v):
self._on_error(v[0], v[1])
return
elif protocol.Ping.isClassOf(v):
self._send_bytes(_encode(protocol.Pong()))
return
else:
log.error('Unhandled server message: %r' % (v,))
def _on_error(self, detail, context):
log.error('%s: error from server: %r (context: %r)' % (
self.__class__.__qualname__, detail, context))
self._disconnect()
def _send(self, bs, commitNeeded = False):
def _send(self, m):
return self._send_bytes(_encode(m))
def _send_bytes(self, bs, commitNeeded = False):
raise Exception('subclassresponsibility')
def _disconnect(self):
@ -191,7 +224,7 @@ class _StreamConnection(Connection, asyncio.Protocol):
if v is None: break
self._on_event(v)
def _send(self, bs, commitNeeded = False):
def _send_bytes(self, bs, commitNeeded = False):
if self.transport:
self.transport.write(bs)
if commitNeeded:
@ -251,7 +284,7 @@ class WebsocketConnection(Connection):
self.loop = None
self.ws = None
def _send(self, bs, commitNeeded = False):
def _send_bytes(self, bs, commitNeeded = False):
if self.loop:
def _do_send():
if self.ws:

View File

@ -3,10 +3,9 @@ from preserves import Record, Symbol
## Enrolment
Connect = Record.makeConstructor('Connect', 'scope')
Peer = Record.makeConstructor('Peer', 'scope')
## Bidirectional
Commit = Record.makeConstructor('Commit', '')
Turn = Record.makeConstructor('Turn', 'items')
## Client -> Server
Assert = Record.makeConstructor('Assert', 'endpointName assertion')
@ -18,7 +17,7 @@ Add = Record.makeConstructor('Add', 'endpointName captures')
Del = Record.makeConstructor('Del', 'endpointName captures')
Msg = Record.makeConstructor('Msg', 'endpointName captures')
End = Record.makeConstructor('End', 'endpointName')
Err = Record.makeConstructor('Err', 'detail')
Err = Record.makeConstructor('Err', 'detail context')
## Bidirectional
Ping = Record.makeConstructor('Ping', '')