Turn-based protocol

This commit is contained in:
Tony Garnock-Jones 2019-06-12 00:26:40 +01:00
parent 7762529d16
commit c73fb462d2
4 changed files with 137 additions and 103 deletions

18
chat.py
View File

@ -31,14 +31,13 @@ names = ['Daria', 'Kendra', 'Danny', 'Rufus', 'Diana', 'Arnetta', 'Dominick', 'M
me = random.choice(names) + '_' + str(random.randint(10, 1000)) me = random.choice(names) + '_' + str(random.randint(10, 1000))
S.Endpoint(conn, Present(me)) with conn.turn() as t:
S.Endpoint(t, Present(me))
S.Endpoint(conn, S.Observe(Present(S.CAPTURE)), S.Endpoint(t, S.Observe(Present(S.CAPTURE)),
on_add=lambda who: print(who, 'joined'), on_add=lambda t, who: print(who, 'joined'),
on_del=lambda who: print(who, 'left')) on_del=lambda t, who: print(who, 'left'))
S.Endpoint(t, S.Observe(Says(S.CAPTURE, S.CAPTURE)),
S.Endpoint(conn, S.Observe(Says(S.CAPTURE, S.CAPTURE)), on_msg=lambda t, who, what: print(who, 'said', repr(what)))
on_msg=lambda who, what: print(who, 'said', repr(what)))
async def reconnect(loop): async def reconnect(loop):
while conn: while conn:
@ -56,7 +55,8 @@ def accept_input():
conn.destroy() conn.destroy()
conn = None conn = None
break break
conn.send(Says(me, line.strip())) with conn.turn() as t:
t.send(Says(me, line.strip()))
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.set_debug(True) loop.set_debug(True)

View File

@ -9,17 +9,19 @@ OverlayLink = S.Record.makeConstructor('OverlayLink', 'downNode upNode')
conn = S.WebsocketConnection(sys.argv[1], sys.argv[2]) conn = S.WebsocketConnection(sys.argv[1], sys.argv[2])
uplinks = {} uplinks = {}
def add_uplink(src, tgt): def add_uplink(turn, src, tgt):
uplinks[src] = tgt uplinks[src] = tgt
summarise_uplinks() summarise_uplinks()
def del_uplink(src, tgt): def del_uplink(turn, src, tgt):
del uplinks[src] del uplinks[src]
summarise_uplinks() summarise_uplinks()
def summarise_uplinks(): def summarise_uplinks():
print(repr(uplinks)) print(repr(uplinks))
S.Endpoint(conn, S.Observe(OverlayLink(S.CAPTURE, S.CAPTURE)),
on_add=add_uplink, with conn.turn() as t:
on_del=del_uplink) S.Endpoint(t, S.Observe(OverlayLink(S.CAPTURE, S.CAPTURE)),
on_add=add_uplink,
on_del=del_uplink)
async def reconnect(loop): async def reconnect(loop):
while conn: while conn:

View File

@ -12,6 +12,11 @@ CAPTURE = Capture(Discard())
from preserves import * from preserves import *
def _encode(event):
e = protocol.Encoder()
e.append(event)
return e.contents()
_instance_id = secrets.token_urlsafe(8) _instance_id = secrets.token_urlsafe(8)
_uuid_counter = 0 _uuid_counter = 0
@ -24,39 +29,65 @@ def uuid(prefix='__@syndicate'):
def _ignore(*args, **kwargs): def _ignore(*args, **kwargs):
pass pass
class Endpoint(object): class Turn(object):
def __init__(self, conn, assertion, id=None, def __init__(self, conn):
on_add=None, on_del=None, on_msg=None, on_end=None):
self.conn = conn self.conn = conn
self.assertion = assertion self.items = []
self.id = id or uuid('sub' if Observe.isClassOf(assertion) else 'pub')
def _extend(self, item):
self.items.append(item)
def _reset(self):
self.items.clear()
def _commit(self):
if self.items:
self.conn._send(protocol.Turn(self.items))
self._reset()
def send(self, message):
self._extend(protocol.Message(message))
def __enter__(self):
if self.items:
raise Exception('Cannot reenter with statement for Turn')
return self
def __exit__(self, t, v, tb):
if t is None:
self._commit()
def _fresh_id(assertion):
return uuid('sub' if Observe.isClassOf(assertion) else 'pub')
class Endpoint(object):
def __init__(self, turn, assertion,
on_add=None, on_del=None, on_msg=None):
self.assertion = None
self.id = None
self.on_add = on_add or _ignore self.on_add = on_add or _ignore
self.on_del = on_del or _ignore self.on_del = on_del or _ignore
self.on_msg = on_msg or _ignore self.on_msg = on_msg or _ignore
self.on_end = on_end or _ignore
self.cache = set() self.cache = set()
self.conn._update_endpoint(self) self.set(turn, assertion)
def set(self, new_assertion): def set(self, turn, new_assertion, on_transition=None):
if self.id is not None:
turn.conn._unmap_endpoint(turn, self, on_end=on_transition)
self.id = None
self.assertion = new_assertion self.assertion = new_assertion
if self.conn: if self.assertion is not None:
self.conn._update_endpoint(self) self.id = _fresh_id(self.assertion)
turn.conn._map_endpoint(turn, self)
def send(self, message): def clear(self, turn, on_cleared=None):
'''Shortcut to Connection.send.''' self.set(turn, None, on_transition=on_cleared)
if self.conn:
self.conn.send(message)
def destroy(self): def _reset(self, turn):
if self.conn:
self.conn._clear_endpoint(self)
self.conn = None
def _reset(self):
for captures in set(self.cache): for captures in set(self.cache):
self._del(captures) self._del(turn, captures)
def _add(self, captures): def _add(self, turn, captures):
if captures in self.cache: if captures in self.cache:
log.error('Server error: duplicate captures %r added for endpoint %r %r' % ( log.error('Server error: duplicate captures %r added for endpoint %r %r' % (
captures, captures,
@ -64,107 +95,109 @@ class Endpoint(object):
self.assertion)) self.assertion))
else: else:
self.cache.add(captures) self.cache.add(captures)
self.on_add(*captures) self.on_add(turn, *captures)
def _del(self, captures): def _del(self, turn, captures):
if captures in self.cache: if captures in self.cache:
self.cache.discard(captures) self.cache.discard(captures)
self.on_del(*captures) self.on_del(turn, *captures)
else: else:
log.error('Server error: nonexistent captures %r removed from endpoint %r %r' % ( log.error('Server error: nonexistent captures %r removed from endpoint %r %r' % (
captures, captures,
self.id, self.id,
self.assertion)) self.assertion))
def _msg(self, captures): def _msg(self, turn, captures):
self.on_msg(*captures) self.on_msg(turn, *captures)
def _end(self):
self.on_end()
class DummyEndpoint(object): class DummyEndpoint(object):
def _add(self, captures): pass def _add(self, turn, captures): pass
def _del(self, captures): pass def _del(self, turn, captures): pass
def _msg(self, captures): pass def _msg(self, turn, captures): pass
def _end(self): pass
_dummy_endpoint = DummyEndpoint() _dummy_endpoint = DummyEndpoint()
class Connection(object): class Connection(object):
def __init__(self, scope): def __init__(self, scope):
self.endpoints = {}
self.scope = scope self.scope = scope
self.commitNeeded = False self.endpoints = {}
self.worklist = [] self.end_callbacks = {}
def _each_endpoint(self): def _each_endpoint(self):
return list(self.endpoints.values()) return list(self.endpoints.values())
def turn(self):
return Turn(self)
def destroy(self): def destroy(self):
for ep in self._each_endpoint(): with self.turn() as t:
ep.destroy() for ep in self._each_endpoint():
ep.clear(t)
t._reset() ## don't actually Clear the endpoints, we are about to disconnect
self._disconnect() self._disconnect()
def _encode(self, event): def _unmap_endpoint(self, turn, ep, on_end=None):
e = protocol.Encoder() del self.endpoints[ep.id]
e.append(event) if on_end:
return e.contents() self.end_callbacks[ep.id] = on_end
turn._extend(protocol.Clear(ep.id))
def _update_endpoint(self, ep): def _on_end(self, turn, id):
if id in self.end_callbacks:
self.end_callbacks[id](turn)
del self.end_callbacks[id]
def _map_endpoint(self, turn, ep):
self.endpoints[ep.id] = ep self.endpoints[ep.id] = ep
self._send(self._encode(protocol.Assert(ep.id, ep.assertion)), commitNeeded = True) turn._extend(protocol.Assert(ep.id, ep.assertion))
def _clear_endpoint(self, ep):
if ep.id in self.endpoints:
del self.endpoints[ep.id]
self._send(self._encode(protocol.Clear(ep.id)), commitNeeded = True)
def send(self, message):
self._send(self._encode(protocol.Message(message)), commitNeeded = True)
self._commit_if_needed()
def _on_disconnected(self): def _on_disconnected(self):
for ep in self._each_endpoint(): with self.turn() as t:
ep._reset() for ep in self._each_endpoint():
ep._reset(t)
t._reset() ## we have been disconnected, no point in keeping the actions
self._disconnect() self._disconnect()
def _on_connected(self): def _on_connected(self):
self._send(self._encode(protocol.Connect(self.scope))) self._send(protocol.Connect(self.scope))
for ep in self._each_endpoint(): with self.turn() as t:
self._update_endpoint(ep) for ep in self._each_endpoint():
self._commit_work() self._map_endpoint(t, ep)
def _lookup(self, endpointId): def _lookup(self, endpointId):
return self.endpoints.get(endpointId, _dummy_endpoint) return self.endpoints.get(endpointId, _dummy_endpoint)
def _push_work(self, thunk):
self.worklist.append(thunk)
def _commit_work(self):
for thunk in self.worklist:
thunk()
self.worklist.clear()
self._commit_if_needed()
def _commit_if_needed(self):
if self.commitNeeded:
self._send(self._encode(protocol.Commit()))
self.commitNeeded = False
def _on_event(self, v): def _on_event(self, v):
if protocol.Add.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._add(v[1])) with self.turn() as t:
if protocol.Del.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._del(v[1])) self._handle_event(t, v)
if protocol.Msg.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._msg(v[1]))
if protocol.End.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._end())
if protocol.Commit.isClassOf(v): return self._commit_work()
if protocol.Err.isClassOf(v): return self._on_error(v[0])
if protocol.Ping.isClassOf(v): self._send(self._encode(protocol.Pong()))
def _on_error(self, detail): def _handle_event(self, turn, v):
log.error('%s: error from server: %r' % (self.__class__.__qualname__, detail)) if protocol.Turn.isClassOf(v):
for item in protocol.Turn._items(v):
if protocol.Add.isClassOf(item): self._lookup(item[0])._add(turn, item[1])
elif protocol.Del.isClassOf(item): self._lookup(item[0])._del(turn, item[1])
elif protocol.Msg.isClassOf(item): self._lookup(item[0])._msg(turn, item[1])
elif protocol.End.isClassOf(item): self._on_end(turn, item[0])
else: log.error('Unhandled server Turn item: %r' % (item,))
return
elif protocol.Err.isClassOf(v):
self._on_error(v[0], v[1])
return
elif protocol.Ping.isClassOf(v):
self._send_bytes(_encode(protocol.Pong()))
return
else:
log.error('Unhandled server message: %r' % (v,))
def _on_error(self, detail, context):
log.error('%s: error from server: %r (context: %r)' % (
self.__class__.__qualname__, detail, context))
self._disconnect() self._disconnect()
def _send(self, bs, commitNeeded = False): def _send(self, m):
return self._send_bytes(_encode(m))
def _send_bytes(self, bs, commitNeeded = False):
raise Exception('subclassresponsibility') raise Exception('subclassresponsibility')
def _disconnect(self): def _disconnect(self):
@ -191,7 +224,7 @@ class _StreamConnection(Connection, asyncio.Protocol):
if v is None: break if v is None: break
self._on_event(v) self._on_event(v)
def _send(self, bs, commitNeeded = False): def _send_bytes(self, bs, commitNeeded = False):
if self.transport: if self.transport:
self.transport.write(bs) self.transport.write(bs)
if commitNeeded: if commitNeeded:
@ -251,7 +284,7 @@ class WebsocketConnection(Connection):
self.loop = None self.loop = None
self.ws = None self.ws = None
def _send(self, bs, commitNeeded = False): def _send_bytes(self, bs, commitNeeded = False):
if self.loop: if self.loop:
def _do_send(): def _do_send():
if self.ws: if self.ws:

View File

@ -3,10 +3,9 @@ from preserves import Record, Symbol
## Enrolment ## Enrolment
Connect = Record.makeConstructor('Connect', 'scope') Connect = Record.makeConstructor('Connect', 'scope')
Peer = Record.makeConstructor('Peer', 'scope')
## Bidirectional ## Bidirectional
Commit = Record.makeConstructor('Commit', '') Turn = Record.makeConstructor('Turn', 'items')
## Client -> Server ## Client -> Server
Assert = Record.makeConstructor('Assert', 'endpointName assertion') Assert = Record.makeConstructor('Assert', 'endpointName assertion')
@ -18,7 +17,7 @@ Add = Record.makeConstructor('Add', 'endpointName captures')
Del = Record.makeConstructor('Del', 'endpointName captures') Del = Record.makeConstructor('Del', 'endpointName captures')
Msg = Record.makeConstructor('Msg', 'endpointName captures') Msg = Record.makeConstructor('Msg', 'endpointName captures')
End = Record.makeConstructor('End', 'endpointName') End = Record.makeConstructor('End', 'endpointName')
Err = Record.makeConstructor('Err', 'detail') Err = Record.makeConstructor('Err', 'detail context')
## Bidirectional ## Bidirectional
Ping = Record.makeConstructor('Ping', '') Ping = Record.makeConstructor('Ping', '')