Fix refcounting in relay

This commit is contained in:
Tony Garnock-Jones 2021-09-07 14:57:24 +02:00
parent 03522c32ce
commit d6dc75e41d
1 changed files with 80 additions and 65 deletions

View File

@ -13,48 +13,57 @@ from .idgen import IdGenerator
from .schema import externalProtocol as protocol, sturdy, transportAddress from .schema import externalProtocol as protocol, sturdy, transportAddress
class InboundAssertion: class InboundAssertion:
def __init__(self, remote_handle, local_handle, wire_symbols): def __init__(self, remote_handle, local_handle, pins):
self.remote_handle = remote_handle self.remote_handle = remote_handle
self.local_handle = local_handle self.local_handle = local_handle
self.wire_symbols = wire_symbols self.pins = pins
_next_local_oid = IdGenerator() _next_local_oid = IdGenerator()
class WireSymbol: class WireSymbol:
def __init__(self, oid, ref): def __init__(self, oid, ref, membrane):
self.oid = oid self.oid = oid
self.ref = ref self.ref = ref
self.membrane = membrane
self.count = 0 self.count = 0
def __repr__(self): def __repr__(self):
return '<ws:%d/%d:%r>' % (self.oid, self.count, self.ref) return '<ws:%d/%d:%r>' % (self.oid, self.count, self.ref)
def grab(self, pins):
self.count = self.count + 1
pins.append(self)
def drop(self):
self.count = self.count - 1
if self.count == 0:
del self.membrane.oid_map[self.oid]
del self.membrane.ref_map[self.ref]
class Membrane: class Membrane:
def __init__(self): def __init__(self):
self.oid_map = {} self.oid_map = {}
self.ref_map = {} self.ref_map = {}
def _get(self, map, key, is_transient, ws_maker): def _get(self, pins, map, key, is_transient, ws_maker):
ws = map.get(key, None) ws = map.get(key, None)
if ws is None: if ws is None and ws_maker is not None:
ws = ws_maker() ws = ws_maker()
self.oid_map[ws.oid] = ws self.oid_map[ws.oid] = ws
self.ref_map[ws.ref] = ws self.ref_map[ws.ref] = ws
if not is_transient: if not is_transient and ws is not None:
ws.count = ws.count + 1 ws.grab(pins)
return ws return ws
def get_ref(self, local_ref, is_transient, ws_maker): def get_ref(self, pins, local_ref, is_transient, ws_maker):
return self._get(self.ref_map, local_ref, is_transient, ws_maker) return self._get(pins, self.ref_map, local_ref, is_transient, ws_maker)
def get_oid(self, remote_oid, ws_maker): def get_oid(self, pins, remote_oid, ws_maker):
return self._get(self.oid_map, remote_oid, False, ws_maker) return self._get(pins, self.oid_map, remote_oid, False, ws_maker)
def drop(self, ws): def drop_all(wss):
ws.count = ws.count - 1 for ws in wss:
if ws.count == 0: ws.drop()
del self.oid_map[ws.oid]
del self.ref_map[ws.ref]
# There are other kinds of relay. This one has exactly two participants connected to each other. # There are other kinds of relay. This one has exactly two participants connected to each other.
class TunnelRelay: class TunnelRelay:
@ -79,7 +88,7 @@ class TunnelRelay:
def _reset(self): def _reset(self):
self.inbound_assertions = {} # map remote handle to InboundAssertion self.inbound_assertions = {} # map remote handle to InboundAssertion
self.outbound_assertions = {} # map local handle to wire_symbols self.outbound_assertions = {} # map local handle to `WireSymbol`s
self.exported_references = Membrane() self.exported_references = Membrane()
self.imported_references = Membrane() self.imported_references = Membrane()
self.pending_turn = [] self.pending_turn = []
@ -94,49 +103,53 @@ class TunnelRelay:
self._disconnect() self._disconnect()
def deregister(self, handle): def deregister(self, handle):
for ws in self.outbound_assertions.pop(handle, ()): drop_all(self.outbound_assertions.pop(handle, ()))
self.exported_references.drop(ws)
def _lookup(self, local_oid): def _lookup_exported_oid(self, local_oid, pins):
ws = self.exported_references.oid_map.get(local_oid, None) ws = self.exported_references.get_oid(pins, local_oid, None)
return _inert_ref if ws is None else ws.ref if ws is None:
return _inert_ref
return ws.ref
def register(self, assertion, maybe_handle): def register_imported_oid(self, remote_oid, pins):
exported = [] self.imported_references.get_oid(pins, remote_oid, None)
def register(self, target_oid, assertion, maybe_handle):
pins = []
self.register_imported_oid(target_oid, pins)
rewritten = map_embeddeds( rewritten = map_embeddeds(
lambda r: Embedded(self.rewrite_ref_out(r, maybe_handle is None, exported)), lambda r: Embedded(self.rewrite_ref_out(r, maybe_handle is None, pins)),
assertion) assertion)
if maybe_handle is not None: if maybe_handle is not None:
self.outbound_assertions[maybe_handle] = exported self.outbound_assertions[maybe_handle] = pins
return rewritten return rewritten
def rewrite_ref_out(self, r, is_transient, exported): def rewrite_ref_out(self, r, is_transient, pins):
if isinstance(r.entity, RelayEntity) and r.entity.relay == self: if isinstance(r.entity, RelayEntity) and r.entity.relay == self:
# TODO attenuation # TODO attenuation
return sturdy.WireRef.yours(sturdy.Oid(r.entity.oid), ()) return sturdy.WireRef.yours(sturdy.Oid(r.entity.oid), ())
else: else:
ws = self.exported_references.get_ref( ws = self.exported_references.get_ref(
r, is_transient, lambda: WireSymbol(next(_next_local_oid), r)) pins, r, is_transient, lambda: WireSymbol(next(_next_local_oid), r,
exported.append(ws) self.exported_references))
return sturdy.WireRef.mine(sturdy.Oid(ws.oid)) return sturdy.WireRef.mine(sturdy.Oid(ws.oid))
def rewrite_in(self, turn, assertion): def rewrite_in(self, turn, assertion, pins):
imported = []
rewritten = map_embeddeds( rewritten = map_embeddeds(
lambda wire_ref: Embedded(self.rewrite_ref_in(turn, wire_ref, imported)), lambda wire_ref: Embedded(self.rewrite_ref_in(turn, wire_ref, pins)),
assertion) assertion)
return (rewritten, imported) return rewritten
def rewrite_ref_in(self, turn, wire_ref, imported): def rewrite_ref_in(self, turn, wire_ref, pins):
if wire_ref.VARIANT.name == 'mine': if wire_ref.VARIANT.name == 'mine':
oid = wire_ref.oid.value oid = wire_ref.oid.value
ws = self.imported_references.get_oid( ws = self.imported_references.get_oid(
oid, lambda: WireSymbol(oid, turn.ref(RelayEntity(self, oid)))) pins, oid, lambda: WireSymbol(oid, turn.ref(RelayEntity(self, oid)),
imported.append(ws) self.imported_references))
return ws.ref return ws.ref
else: else:
oid = wire_ref.oid.value oid = wire_ref.oid.value
local_ref = self._lookup(oid) local_ref = self._lookup_exported_oid(oid, pins)
attenuation = wire_ref.attenuation attenuation = wire_ref.attenuation
if len(attenuation) > 0: if len(attenuation) > 0:
raise NotImplementedError('Non-empty attenuations not yet implemented') # TODO raise NotImplementedError('Non-empty attenuations not yet implemented') # TODO
@ -178,44 +191,45 @@ class TunnelRelay:
def _handle_turn_events(self, turn, events): def _handle_turn_events(self, turn, events):
for e in events: for e in events:
ref = self._lookup(e.oid.value) pins = []
ref = self._lookup_exported_oid(e.oid.value, pins)
event = e.event event = e.event
variant = event.VARIANT.name variant = event.VARIANT.name
if variant == 'Assert': if variant == 'Assert':
self._handle_publish(turn, ref, event.value.assertion.value, event.value.handle.value) self._handle_publish(pins, turn, ref, event.value.assertion.value, event.value.handle.value)
elif variant == 'Retract': elif variant == 'Retract':
self._handle_retract(turn, ref, event.value.handle.value) self._handle_retract(pins, turn, ref, event.value.handle.value)
elif variant == 'Message': elif variant == 'Message':
self._handle_message(turn, ref, event.value.body.value) self._handle_message(pins, turn, ref, event.value.body.value)
elif variant == 'Sync': elif variant == 'Sync':
self._handle_sync(turn, ref, event.value.peer) self._handle_sync(pins, turn, ref, event.value.peer)
def _handle_publish(self, turn, ref, assertion, remote_handle): def _handle_publish(self, pins, turn, ref, assertion, remote_handle):
(assertion, imported) = self.rewrite_in(turn, assertion) assertion = self.rewrite_in(turn, assertion, pins)
self.inbound_assertions[remote_handle] = \ self.inbound_assertions[remote_handle] = \
InboundAssertion(remote_handle, turn.publish(ref, assertion), imported) InboundAssertion(remote_handle, turn.publish(ref, assertion), pins)
def _handle_retract(self, turn, ref, remote_handle): def _handle_retract(self, pins, turn, ref, remote_handle):
ia = self.inbound_assertions.pop(remote_handle, None) ia = self.inbound_assertions.pop(remote_handle, None)
if ia is None: if ia is None:
raise ValueError('Peer retracted invalid handle %s' % (remote_handle,)) raise ValueError('Peer retracted invalid handle %s' % (remote_handle,))
for ws in ia.wire_symbols: drop_all(ia.pins)
self.imported_references.drop(ws) drop_all(pins)
turn.retract(ia.local_handle) turn.retract(ia.local_handle)
def _handle_message(self, turn, ref, message): def _handle_message(self, pins, turn, ref, message):
(message, imported) = self.rewrite_in(turn, message) message = self.rewrite_in(turn, message, pins)
if len(imported) > 0: for ws in pins:
raise ValueError('Cannot receive transient reference') if ws.count == 1:
raise ValueError('Cannot receive transient reference')
turn.send(ref, message) turn.send(ref, message)
drop_all(pins)
def _handle_sync(self, turn, ref, wire_peer): def _handle_sync(self, pins, turn, ref, wire_peer):
imported = [] peer = self.rewrite_ref_in(turn, wire_peer, pins)
peer = self.rewrite_ref_in(turn, wire_peer, imported)
def done(turn): def done(turn):
turn.send(peer, True) turn.send(peer, True)
for ws in imported: drop_all(pins)
self.imported_references.drop(ws)
turn.sync(ref, done) turn.sync(ref, done)
def _send(self, remote_oid, turn_event): def _send(self, remote_oid, turn_event):
@ -269,7 +283,7 @@ class RelayEntity(actor.Entity):
def on_publish(self, turn, assertion, handle): def on_publish(self, turn, assertion, handle):
self._send(protocol.Event.Assert(protocol.Assert( self._send(protocol.Event.Assert(protocol.Assert(
protocol.Assertion(self.relay.register(assertion, handle)), protocol.Assertion(self.relay.register(self.oid, assertion, handle)),
protocol.Handle(handle)))) protocol.Handle(handle))))
def on_retract(self, turn, handle): def on_retract(self, turn, handle):
@ -278,22 +292,23 @@ class RelayEntity(actor.Entity):
def on_message(self, turn, message): def on_message(self, turn, message):
self._send(protocol.Event.Message(protocol.Message( self._send(protocol.Event.Message(protocol.Message(
protocol.Assertion(self.relay.register(message, None))))) protocol.Assertion(self.relay.register(self.oid, message, None)))))
def on_sync(self, turn, peer): def on_sync(self, turn, peer):
exported = [] pins = []
entity = SyncPeerEntity(self.relay, peer, exported) self.relay.register_imported_oid(self.oid, pins)
rewritten = Embedded(self.relay.rewrite_ref_out(turn.ref(entity), False, exported)) entity = SyncPeerEntity(self.relay, peer, pins)
rewritten = Embedded(self.relay.rewrite_ref_out(turn.ref(entity), False, pins))
self._send(protocol.Event.Sync(protocol.Sync(rewritten))) self._send(protocol.Event.Sync(protocol.Sync(rewritten)))
class SyncPeerEntity(actor.Entity): class SyncPeerEntity(actor.Entity):
def __init__(self, relay, peer, exported): def __init__(self, relay, peer, pins):
self.relay = relay self.relay = relay
self.peer = peer self.peer = peer
self.exported = exported self.pins = pins
def on_message(self, turn, body): def on_message(self, turn, body):
self.relay.exported_references.drop(self.exported[0]) drop_all(self.pins)
turn.send(self.peer, body) turn.send(self.peer, body)
async def _default_on_connected(relay): async def _default_on_connected(relay):