diff --git a/chat.py b/chat.py index 20420b0..f0a47d8 100644 --- a/chat.py +++ b/chat.py @@ -31,14 +31,13 @@ names = ['Daria', 'Kendra', 'Danny', 'Rufus', 'Diana', 'Arnetta', 'Dominick', 'M me = random.choice(names) + '_' + str(random.randint(10, 1000)) -S.Endpoint(conn, Present(me)) - -S.Endpoint(conn, S.Observe(Present(S.CAPTURE)), - on_add=lambda who: print(who, 'joined'), - on_del=lambda who: print(who, 'left')) - -S.Endpoint(conn, S.Observe(Says(S.CAPTURE, S.CAPTURE)), - on_msg=lambda who, what: print(who, 'said', repr(what))) +with conn.turn() as t: + S.Endpoint(t, Present(me)) + S.Endpoint(t, S.Observe(Present(S.CAPTURE)), + on_add=lambda t, who: print(who, 'joined'), + on_del=lambda t, who: print(who, 'left')) + S.Endpoint(t, S.Observe(Says(S.CAPTURE, S.CAPTURE)), + on_msg=lambda t, who, what: print(who, 'said', repr(what))) async def reconnect(loop): while conn: @@ -56,7 +55,8 @@ def accept_input(): conn.destroy() conn = None break - conn.send(Says(me, line.strip())) + with conn.turn() as t: + t.send(Says(me, line.strip())) loop = asyncio.get_event_loop() loop.set_debug(True) diff --git a/ovlinfo.py b/ovlinfo.py index f3fd746..be788e0 100644 --- a/ovlinfo.py +++ b/ovlinfo.py @@ -9,17 +9,19 @@ OverlayLink = S.Record.makeConstructor('OverlayLink', 'downNode upNode') conn = S.WebsocketConnection(sys.argv[1], sys.argv[2]) uplinks = {} -def add_uplink(src, tgt): +def add_uplink(turn, src, tgt): uplinks[src] = tgt summarise_uplinks() -def del_uplink(src, tgt): +def del_uplink(turn, src, tgt): del uplinks[src] summarise_uplinks() def summarise_uplinks(): print(repr(uplinks)) -S.Endpoint(conn, S.Observe(OverlayLink(S.CAPTURE, S.CAPTURE)), - on_add=add_uplink, - on_del=del_uplink) + +with conn.turn() as t: + S.Endpoint(t, S.Observe(OverlayLink(S.CAPTURE, S.CAPTURE)), + on_add=add_uplink, + on_del=del_uplink) async def reconnect(loop): while conn: diff --git a/syndicate/mini/core.py b/syndicate/mini/core.py index 44e9f84..c5f06ad 100644 --- a/syndicate/mini/core.py +++ b/syndicate/mini/core.py @@ -12,6 +12,11 @@ CAPTURE = Capture(Discard()) from preserves import * +def _encode(event): + e = protocol.Encoder() + e.append(event) + return e.contents() + _instance_id = secrets.token_urlsafe(8) _uuid_counter = 0 @@ -24,39 +29,65 @@ def uuid(prefix='__@syndicate'): def _ignore(*args, **kwargs): pass -class Endpoint(object): - def __init__(self, conn, assertion, id=None, - on_add=None, on_del=None, on_msg=None, on_end=None): +class Turn(object): + def __init__(self, conn): self.conn = conn - self.assertion = assertion - self.id = id or uuid('sub' if Observe.isClassOf(assertion) else 'pub') + self.items = [] + + def _extend(self, item): + self.items.append(item) + + def _reset(self): + self.items.clear() + + def _commit(self): + if self.items: + self.conn._send(protocol.Turn(self.items)) + self._reset() + + def send(self, message): + self._extend(protocol.Message(message)) + + def __enter__(self): + if self.items: + raise Exception('Cannot reenter with statement for Turn') + return self + + def __exit__(self, t, v, tb): + if t is None: + self._commit() + +def _fresh_id(assertion): + return uuid('sub' if Observe.isClassOf(assertion) else 'pub') + +class Endpoint(object): + def __init__(self, turn, assertion, + on_add=None, on_del=None, on_msg=None): + self.assertion = None + self.id = None self.on_add = on_add or _ignore self.on_del = on_del or _ignore self.on_msg = on_msg or _ignore - self.on_end = on_end or _ignore self.cache = set() - self.conn._update_endpoint(self) + self.set(turn, assertion) - def set(self, new_assertion): + def set(self, turn, new_assertion, on_transition=None): + if self.id is not None: + turn.conn._unmap_endpoint(turn, self, on_end=on_transition) + self.id = None self.assertion = new_assertion - if self.conn: - self.conn._update_endpoint(self) + if self.assertion is not None: + self.id = _fresh_id(self.assertion) + turn.conn._map_endpoint(turn, self) - def send(self, message): - '''Shortcut to Connection.send.''' - if self.conn: - self.conn.send(message) + def clear(self, turn, on_cleared=None): + self.set(turn, None, on_transition=on_cleared) - def destroy(self): - if self.conn: - self.conn._clear_endpoint(self) - self.conn = None - - def _reset(self): + def _reset(self, turn): for captures in set(self.cache): - self._del(captures) + self._del(turn, captures) - def _add(self, captures): + def _add(self, turn, captures): if captures in self.cache: log.error('Server error: duplicate captures %r added for endpoint %r %r' % ( captures, @@ -64,107 +95,109 @@ class Endpoint(object): self.assertion)) else: self.cache.add(captures) - self.on_add(*captures) + self.on_add(turn, *captures) - def _del(self, captures): + def _del(self, turn, captures): if captures in self.cache: self.cache.discard(captures) - self.on_del(*captures) + self.on_del(turn, *captures) else: log.error('Server error: nonexistent captures %r removed from endpoint %r %r' % ( captures, self.id, self.assertion)) - def _msg(self, captures): - self.on_msg(*captures) - - def _end(self): - self.on_end() + def _msg(self, turn, captures): + self.on_msg(turn, *captures) class DummyEndpoint(object): - def _add(self, captures): pass - def _del(self, captures): pass - def _msg(self, captures): pass - def _end(self): pass + def _add(self, turn, captures): pass + def _del(self, turn, captures): pass + def _msg(self, turn, captures): pass _dummy_endpoint = DummyEndpoint() class Connection(object): def __init__(self, scope): - self.endpoints = {} self.scope = scope - self.commitNeeded = False - self.worklist = [] + self.endpoints = {} + self.end_callbacks = {} def _each_endpoint(self): return list(self.endpoints.values()) + def turn(self): + return Turn(self) + def destroy(self): - for ep in self._each_endpoint(): - ep.destroy() + with self.turn() as t: + for ep in self._each_endpoint(): + ep.clear(t) + t._reset() ## don't actually Clear the endpoints, we are about to disconnect self._disconnect() - def _encode(self, event): - e = protocol.Encoder() - e.append(event) - return e.contents() + def _unmap_endpoint(self, turn, ep, on_end=None): + del self.endpoints[ep.id] + if on_end: + self.end_callbacks[ep.id] = on_end + turn._extend(protocol.Clear(ep.id)) - def _update_endpoint(self, ep): + def _on_end(self, turn, id): + if id in self.end_callbacks: + self.end_callbacks[id](turn) + del self.end_callbacks[id] + + def _map_endpoint(self, turn, ep): self.endpoints[ep.id] = ep - self._send(self._encode(protocol.Assert(ep.id, ep.assertion)), commitNeeded = True) - - def _clear_endpoint(self, ep): - if ep.id in self.endpoints: - del self.endpoints[ep.id] - self._send(self._encode(protocol.Clear(ep.id)), commitNeeded = True) - - def send(self, message): - self._send(self._encode(protocol.Message(message)), commitNeeded = True) - self._commit_if_needed() + turn._extend(protocol.Assert(ep.id, ep.assertion)) def _on_disconnected(self): - for ep in self._each_endpoint(): - ep._reset() + with self.turn() as t: + for ep in self._each_endpoint(): + ep._reset(t) + t._reset() ## we have been disconnected, no point in keeping the actions self._disconnect() def _on_connected(self): - self._send(self._encode(protocol.Connect(self.scope))) - for ep in self._each_endpoint(): - self._update_endpoint(ep) - self._commit_work() + self._send(protocol.Connect(self.scope)) + with self.turn() as t: + for ep in self._each_endpoint(): + self._map_endpoint(t, ep) def _lookup(self, endpointId): return self.endpoints.get(endpointId, _dummy_endpoint) - def _push_work(self, thunk): - self.worklist.append(thunk) - - def _commit_work(self): - for thunk in self.worklist: - thunk() - self.worklist.clear() - self._commit_if_needed() - - def _commit_if_needed(self): - if self.commitNeeded: - self._send(self._encode(protocol.Commit())) - self.commitNeeded = False - def _on_event(self, v): - if protocol.Add.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._add(v[1])) - if protocol.Del.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._del(v[1])) - if protocol.Msg.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._msg(v[1])) - if protocol.End.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._end()) - if protocol.Commit.isClassOf(v): return self._commit_work() - if protocol.Err.isClassOf(v): return self._on_error(v[0]) - if protocol.Ping.isClassOf(v): self._send(self._encode(protocol.Pong())) + with self.turn() as t: + self._handle_event(t, v) - def _on_error(self, detail): - log.error('%s: error from server: %r' % (self.__class__.__qualname__, detail)) + def _handle_event(self, turn, v): + if protocol.Turn.isClassOf(v): + for item in protocol.Turn._items(v): + if protocol.Add.isClassOf(item): self._lookup(item[0])._add(turn, item[1]) + elif protocol.Del.isClassOf(item): self._lookup(item[0])._del(turn, item[1]) + elif protocol.Msg.isClassOf(item): self._lookup(item[0])._msg(turn, item[1]) + elif protocol.End.isClassOf(item): self._on_end(turn, item[0]) + else: log.error('Unhandled server Turn item: %r' % (item,)) + return + elif protocol.Err.isClassOf(v): + self._on_error(v[0], v[1]) + return + elif protocol.Ping.isClassOf(v): + self._send_bytes(_encode(protocol.Pong())) + return + else: + log.error('Unhandled server message: %r' % (v,)) + + def _on_error(self, detail, context): + log.error('%s: error from server: %r (context: %r)' % ( + self.__class__.__qualname__, detail, context)) self._disconnect() - def _send(self, bs, commitNeeded = False): + def _send(self, m): + return self._send_bytes(_encode(m)) + + def _send_bytes(self, bs, commitNeeded = False): raise Exception('subclassresponsibility') def _disconnect(self): @@ -191,7 +224,7 @@ class _StreamConnection(Connection, asyncio.Protocol): if v is None: break self._on_event(v) - def _send(self, bs, commitNeeded = False): + def _send_bytes(self, bs, commitNeeded = False): if self.transport: self.transport.write(bs) if commitNeeded: @@ -251,7 +284,7 @@ class WebsocketConnection(Connection): self.loop = None self.ws = None - def _send(self, bs, commitNeeded = False): + def _send_bytes(self, bs, commitNeeded = False): if self.loop: def _do_send(): if self.ws: diff --git a/syndicate/mini/protocol.py b/syndicate/mini/protocol.py index fc8a4c7..4d3c724 100644 --- a/syndicate/mini/protocol.py +++ b/syndicate/mini/protocol.py @@ -3,10 +3,9 @@ from preserves import Record, Symbol ## Enrolment Connect = Record.makeConstructor('Connect', 'scope') -Peer = Record.makeConstructor('Peer', 'scope') ## Bidirectional -Commit = Record.makeConstructor('Commit', '') +Turn = Record.makeConstructor('Turn', 'items') ## Client -> Server Assert = Record.makeConstructor('Assert', 'endpointName assertion') @@ -18,7 +17,7 @@ Add = Record.makeConstructor('Add', 'endpointName captures') Del = Record.makeConstructor('Del', 'endpointName captures') Msg = Record.makeConstructor('Msg', 'endpointName captures') End = Record.makeConstructor('End', 'endpointName') -Err = Record.makeConstructor('Err', 'detail') +Err = Record.makeConstructor('Err', 'detail context') ## Bidirectional Ping = Record.makeConstructor('Ping', '')