From 9dbf4a8c5a4a2a204873276acb28f97cca0a029a Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Thu, 30 May 2019 22:35:56 +0100 Subject: [PATCH] Turn commits --- syndicate/mini/core.py | 41 +++++++++++++++++++++++++++++--------- syndicate/mini/protocol.py | 3 +++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/syndicate/mini/core.py b/syndicate/mini/core.py index c0d8156..23cd98d 100644 --- a/syndicate/mini/core.py +++ b/syndicate/mini/core.py @@ -88,6 +88,8 @@ class Connection(object): def __init__(self, scope): self.endpoints = {} self.scope = scope + self.commitNeeded = False + self.worklist = [] def _each_endpoint(self): return list(self.endpoints.values()) @@ -104,15 +106,16 @@ class Connection(object): def _update_endpoint(self, ep): self.endpoints[ep.id] = ep - self._send(self._encode(protocol.Assert(ep.id, ep.assertion))) + self._send(self._encode(protocol.Assert(ep.id, ep.assertion)), commitNeeded = True) def _clear_endpoint(self, ep): if ep.id in self.endpoints: del self.endpoints[ep.id] - self._send(self._encode(protocol.Clear(ep.id))) + self._send(self._encode(protocol.Clear(ep.id)), commitNeeded = True) def send(self, message): - self._send(self._encode(protocol.Message(message))) + self._send(self._encode(protocol.Message(message)), commitNeeded = True) + self._commit_if_needed() def _on_disconnected(self): for ep in self._each_endpoint(): @@ -123,14 +126,30 @@ class Connection(object): self._send(self._encode(protocol.Connect(self.scope))) for ep in self._each_endpoint(): self._update_endpoint(ep) + self._commit_work() def _lookup(self, endpointId): return self.endpoints.get(endpointId, _dummy_endpoint) + def _push_work(self, thunk): + self.worklist.append(thunk) + + def _commit_work(self): + for thunk in self.worklist: + thunk() + self.worklist.clear() + self._commit_if_needed() + + def _commit_if_needed(self): + if self.commitNeeded: + self._send(self._encode(protocol.Commit())) + self.commitNeeded = False + def _on_event(self, v): - if protocol.Add.isClassOf(v): return self._lookup(v[0])._add(v[1]) - if protocol.Del.isClassOf(v): return self._lookup(v[0])._del(v[1]) - if protocol.Msg.isClassOf(v): return self._lookup(v[0])._msg(v[1]) + if protocol.Add.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._add(v[1])) + if protocol.Del.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._del(v[1])) + if protocol.Msg.isClassOf(v): return self._push_work(lambda: self._lookup(v[0])._msg(v[1])) + if protocol.Commit.isClassOf(v): return self._commit_work() if protocol.Err.isClassOf(v): return self._on_error(v[0]) if protocol.Ping.isClassOf(v): self._send(self._encode(protocol.Pong())) @@ -138,7 +157,7 @@ class Connection(object): log.error('%s: error from server: %r' % (self.__class__.__qualname__, detail)) self._disconnect() - def _send(self, bs): + def _send(self, bs, commitNeeded = False): raise Exception('subclassresponsibility') def _disconnect(self): @@ -165,9 +184,11 @@ class _StreamConnection(Connection, asyncio.Protocol): if v is None: break self._on_event(v) - def _send(self, bs): + def _send(self, bs, commitNeeded = False): if self.transport: self.transport.write(bs) + if commitNeeded: + self.commitNeeded = True def _disconnect(self): if self.stop_signal: @@ -223,12 +244,14 @@ class WebsocketConnection(Connection): self.loop = None self.ws = None - def _send(self, bs): + def _send(self, bs, commitNeeded = False): if self.loop: def _do_send(): if self.ws: self.loop.create_task(self.ws.send(bs)) self.loop.call_soon_threadsafe(_do_send) + if commitNeeded: + self.commitNeeded = True def _disconnect(self): if self.loop: diff --git a/syndicate/mini/protocol.py b/syndicate/mini/protocol.py index d809c1f..2a2cc9a 100644 --- a/syndicate/mini/protocol.py +++ b/syndicate/mini/protocol.py @@ -5,6 +5,9 @@ from preserves import Record, Symbol Connect = Record.makeConstructor('Connect', 'scope') Peer = Record.makeConstructor('Peer', 'scope') +## Bidirectional +Commit = Record.makeConstructor('Commit', '') + ## Client -> Server Assert = Record.makeConstructor('Assert', 'endpointName assertion') Clear = Record.makeConstructor('Clear', 'endpointName')