diff --git a/syndicate/relay.py b/syndicate/relay.py index 20ded56..df200f2 100644 --- a/syndicate/relay.py +++ b/syndicate/relay.py @@ -13,48 +13,57 @@ from .idgen import IdGenerator from .schema import externalProtocol as protocol, sturdy, transportAddress class InboundAssertion: - def __init__(self, remote_handle, local_handle, wire_symbols): + def __init__(self, remote_handle, local_handle, pins): self.remote_handle = remote_handle self.local_handle = local_handle - self.wire_symbols = wire_symbols + self.pins = pins _next_local_oid = IdGenerator() class WireSymbol: - def __init__(self, oid, ref): + def __init__(self, oid, ref, membrane): self.oid = oid self.ref = ref + self.membrane = membrane self.count = 0 def __repr__(self): return '' % (self.oid, self.count, self.ref) + def grab(self, pins): + self.count = self.count + 1 + pins.append(self) + + def drop(self): + self.count = self.count - 1 + if self.count == 0: + del self.membrane.oid_map[self.oid] + del self.membrane.ref_map[self.ref] + class Membrane: def __init__(self): self.oid_map = {} self.ref_map = {} - def _get(self, map, key, is_transient, ws_maker): + def _get(self, pins, map, key, is_transient, ws_maker): ws = map.get(key, None) - if ws is None: + if ws is None and ws_maker is not None: ws = ws_maker() self.oid_map[ws.oid] = ws self.ref_map[ws.ref] = ws - if not is_transient: - ws.count = ws.count + 1 + if not is_transient and ws is not None: + ws.grab(pins) return ws - def get_ref(self, local_ref, is_transient, ws_maker): - return self._get(self.ref_map, local_ref, is_transient, ws_maker) + def get_ref(self, pins, local_ref, is_transient, ws_maker): + return self._get(pins, self.ref_map, local_ref, is_transient, ws_maker) - def get_oid(self, remote_oid, ws_maker): - return self._get(self.oid_map, remote_oid, False, ws_maker) + def get_oid(self, pins, remote_oid, ws_maker): + return self._get(pins, self.oid_map, remote_oid, False, ws_maker) - def drop(self, ws): - ws.count = ws.count - 1 - if ws.count == 0: - del self.oid_map[ws.oid] - del self.ref_map[ws.ref] +def drop_all(wss): + for ws in wss: + ws.drop() # There are other kinds of relay. This one has exactly two participants connected to each other. class TunnelRelay: @@ -79,7 +88,7 @@ class TunnelRelay: def _reset(self): self.inbound_assertions = {} # map remote handle to InboundAssertion - self.outbound_assertions = {} # map local handle to wire_symbols + self.outbound_assertions = {} # map local handle to `WireSymbol`s self.exported_references = Membrane() self.imported_references = Membrane() self.pending_turn = [] @@ -94,49 +103,53 @@ class TunnelRelay: self._disconnect() def deregister(self, handle): - for ws in self.outbound_assertions.pop(handle, ()): - self.exported_references.drop(ws) + drop_all(self.outbound_assertions.pop(handle, ())) - def _lookup(self, local_oid): - ws = self.exported_references.oid_map.get(local_oid, None) - return _inert_ref if ws is None else ws.ref + def _lookup_exported_oid(self, local_oid, pins): + ws = self.exported_references.get_oid(pins, local_oid, None) + if ws is None: + return _inert_ref + return ws.ref - def register(self, assertion, maybe_handle): - exported = [] + def register_imported_oid(self, remote_oid, pins): + self.imported_references.get_oid(pins, remote_oid, None) + + def register(self, target_oid, assertion, maybe_handle): + pins = [] + self.register_imported_oid(target_oid, pins) rewritten = map_embeddeds( - lambda r: Embedded(self.rewrite_ref_out(r, maybe_handle is None, exported)), + lambda r: Embedded(self.rewrite_ref_out(r, maybe_handle is None, pins)), assertion) if maybe_handle is not None: - self.outbound_assertions[maybe_handle] = exported + self.outbound_assertions[maybe_handle] = pins return rewritten - def rewrite_ref_out(self, r, is_transient, exported): + def rewrite_ref_out(self, r, is_transient, pins): if isinstance(r.entity, RelayEntity) and r.entity.relay == self: # TODO attenuation return sturdy.WireRef.yours(sturdy.Oid(r.entity.oid), ()) else: ws = self.exported_references.get_ref( - r, is_transient, lambda: WireSymbol(next(_next_local_oid), r)) - exported.append(ws) + pins, r, is_transient, lambda: WireSymbol(next(_next_local_oid), r, + self.exported_references)) return sturdy.WireRef.mine(sturdy.Oid(ws.oid)) - def rewrite_in(self, turn, assertion): - imported = [] + def rewrite_in(self, turn, assertion, pins): rewritten = map_embeddeds( - lambda wire_ref: Embedded(self.rewrite_ref_in(turn, wire_ref, imported)), + lambda wire_ref: Embedded(self.rewrite_ref_in(turn, wire_ref, pins)), assertion) - return (rewritten, imported) + return rewritten - def rewrite_ref_in(self, turn, wire_ref, imported): + def rewrite_ref_in(self, turn, wire_ref, pins): if wire_ref.VARIANT.name == 'mine': oid = wire_ref.oid.value ws = self.imported_references.get_oid( - oid, lambda: WireSymbol(oid, turn.ref(RelayEntity(self, oid)))) - imported.append(ws) + pins, oid, lambda: WireSymbol(oid, turn.ref(RelayEntity(self, oid)), + self.imported_references)) return ws.ref else: oid = wire_ref.oid.value - local_ref = self._lookup(oid) + local_ref = self._lookup_exported_oid(oid, pins) attenuation = wire_ref.attenuation if len(attenuation) > 0: raise NotImplementedError('Non-empty attenuations not yet implemented') # TODO @@ -178,44 +191,45 @@ class TunnelRelay: def _handle_turn_events(self, turn, events): for e in events: - ref = self._lookup(e.oid.value) + pins = [] + ref = self._lookup_exported_oid(e.oid.value, pins) event = e.event variant = event.VARIANT.name if variant == 'Assert': - self._handle_publish(turn, ref, event.value.assertion.value, event.value.handle.value) + self._handle_publish(pins, turn, ref, event.value.assertion.value, event.value.handle.value) elif variant == 'Retract': - self._handle_retract(turn, ref, event.value.handle.value) + self._handle_retract(pins, turn, ref, event.value.handle.value) elif variant == 'Message': - self._handle_message(turn, ref, event.value.body.value) + self._handle_message(pins, turn, ref, event.value.body.value) elif variant == 'Sync': - self._handle_sync(turn, ref, event.value.peer) + self._handle_sync(pins, turn, ref, event.value.peer) - def _handle_publish(self, turn, ref, assertion, remote_handle): - (assertion, imported) = self.rewrite_in(turn, assertion) + def _handle_publish(self, pins, turn, ref, assertion, remote_handle): + assertion = self.rewrite_in(turn, assertion, pins) self.inbound_assertions[remote_handle] = \ - InboundAssertion(remote_handle, turn.publish(ref, assertion), imported) + InboundAssertion(remote_handle, turn.publish(ref, assertion), pins) - def _handle_retract(self, turn, ref, remote_handle): + def _handle_retract(self, pins, turn, ref, remote_handle): ia = self.inbound_assertions.pop(remote_handle, None) if ia is None: raise ValueError('Peer retracted invalid handle %s' % (remote_handle,)) - for ws in ia.wire_symbols: - self.imported_references.drop(ws) + drop_all(ia.pins) + drop_all(pins) turn.retract(ia.local_handle) - def _handle_message(self, turn, ref, message): - (message, imported) = self.rewrite_in(turn, message) - if len(imported) > 0: - raise ValueError('Cannot receive transient reference') + def _handle_message(self, pins, turn, ref, message): + message = self.rewrite_in(turn, message, pins) + for ws in pins: + if ws.count == 1: + raise ValueError('Cannot receive transient reference') turn.send(ref, message) + drop_all(pins) - def _handle_sync(self, turn, ref, wire_peer): - imported = [] - peer = self.rewrite_ref_in(turn, wire_peer, imported) + def _handle_sync(self, pins, turn, ref, wire_peer): + peer = self.rewrite_ref_in(turn, wire_peer, pins) def done(turn): turn.send(peer, True) - for ws in imported: - self.imported_references.drop(ws) + drop_all(pins) turn.sync(ref, done) def _send(self, remote_oid, turn_event): @@ -269,7 +283,7 @@ class RelayEntity(actor.Entity): def on_publish(self, turn, assertion, handle): self._send(protocol.Event.Assert(protocol.Assert( - protocol.Assertion(self.relay.register(assertion, handle)), + protocol.Assertion(self.relay.register(self.oid, assertion, handle)), protocol.Handle(handle)))) def on_retract(self, turn, handle): @@ -278,22 +292,23 @@ class RelayEntity(actor.Entity): def on_message(self, turn, message): self._send(protocol.Event.Message(protocol.Message( - protocol.Assertion(self.relay.register(message, None))))) + protocol.Assertion(self.relay.register(self.oid, message, None))))) def on_sync(self, turn, peer): - exported = [] - entity = SyncPeerEntity(self.relay, peer, exported) - rewritten = Embedded(self.relay.rewrite_ref_out(turn.ref(entity), False, exported)) + pins = [] + self.relay.register_imported_oid(self.oid, pins) + entity = SyncPeerEntity(self.relay, peer, pins) + rewritten = Embedded(self.relay.rewrite_ref_out(turn.ref(entity), False, pins)) self._send(protocol.Event.Sync(protocol.Sync(rewritten))) class SyncPeerEntity(actor.Entity): - def __init__(self, relay, peer, exported): + def __init__(self, relay, peer, pins): self.relay = relay self.peer = peer - self.exported = exported + self.pins = pins def on_message(self, turn, body): - self.relay.exported_references.drop(self.exported[0]) + drop_all(self.pins) turn.send(self.peer, body) async def _default_on_connected(relay):