Compare commits
4 Commits
0364c38068
...
561aa01fea
Author | SHA1 | Date |
---|---|---|
Tony Garnock-Jones | 561aa01fea | |
Tony Garnock-Jones | 0429e59ad1 | |
Tony Garnock-Jones | f2b8b433cc | |
Tony Garnock-Jones | 6f6993ce4c |
|
@ -230,11 +230,7 @@ class Facet:
|
|||
if run_in_executor:
|
||||
inner_coro_fn = coro_fn
|
||||
async def outer_coro_fn(facet):
|
||||
try:
|
||||
await self.loop.run_in_executor(None, lambda: inner_coro_fn(facet))
|
||||
except:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
await self.loop.run_in_executor(None, lambda: inner_coro_fn(facet))
|
||||
coro_fn = outer_coro_fn
|
||||
@self.on_stop_or_crash
|
||||
def cancel_linked_task():
|
||||
|
@ -246,6 +242,9 @@ class Facet:
|
|||
async def guarded_task():
|
||||
try:
|
||||
await coro_fn(self)
|
||||
except:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
Turn.external(self, cancel_linked_task)
|
||||
task = self.loop.create_task(guarded_task())
|
||||
|
|
|
@ -71,6 +71,7 @@ class TunnelRelay:
|
|||
publish_oid = 0,
|
||||
on_connected = None,
|
||||
on_disconnected = None,
|
||||
connection_timeout = None,
|
||||
):
|
||||
self.facet = turn.active_facet()
|
||||
self.facet.on_stop(self._shutdown)
|
||||
|
@ -79,6 +80,7 @@ class TunnelRelay:
|
|||
self.gatekeeper_oid = gatekeeper_oid
|
||||
self.publish_service = publish_service
|
||||
self.publish_oid = publish_oid
|
||||
self.connection_timeout = connection_timeout
|
||||
self._reset()
|
||||
self.facet.linked_task(
|
||||
lambda facet: self._reconnecting_main(facet.actor._system,
|
||||
|
@ -265,11 +267,14 @@ class TunnelRelay:
|
|||
return transport.connection_from_str(conn_str, **kwargs)
|
||||
|
||||
# decorator
|
||||
def connect(conn_str, cap, **kwargs):
|
||||
def connect(conn_str, cap = None, **kwargs):
|
||||
def prepare_resolution_handler(handler):
|
||||
@During().add_handler
|
||||
def handle_gatekeeper(gk):
|
||||
gatekeeper.resolve(gk.embeddedValue, cap)(handler)
|
||||
if cap is None:
|
||||
handler(gk.embeddedValue)
|
||||
else:
|
||||
gatekeeper.resolve(gk.embeddedValue, cap)(handler)
|
||||
return transport.connection_from_str(
|
||||
conn_str,
|
||||
gatekeeper_peer = turn.ref(handle_gatekeeper),
|
||||
|
@ -338,10 +343,6 @@ class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol):
|
|||
def connection_lost(self, exc):
|
||||
self._on_disconnected()
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
self._on_connected()
|
||||
|
||||
def data_received(self, chunk):
|
||||
self.decoder.extend(chunk)
|
||||
while True:
|
||||
|
@ -372,17 +373,26 @@ class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol):
|
|||
self.decoder = Decoder(decode_embedded = sturdy.WireRef.decode)
|
||||
self.stop_signal = system.loop.create_future()
|
||||
try:
|
||||
_transport, _protocol = await self._create_connection(system)
|
||||
except OSError as e:
|
||||
log.error('%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
|
||||
return False
|
||||
try:
|
||||
transport, _protocol = await asyncio.wait_for(
|
||||
self._create_connection(system), timeout=self.connection_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self.facet.log.error(
|
||||
'%s: Timeout connecting to server' % (self.__class__.__qualname__,))
|
||||
return False
|
||||
except OSError as e:
|
||||
self.facet.log.error(
|
||||
'%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
|
||||
return False
|
||||
|
||||
try:
|
||||
self.transport = transport
|
||||
self._on_connected()
|
||||
if on_connected: await on_connected(self)
|
||||
await self.stop_signal
|
||||
return True
|
||||
finally:
|
||||
self.transport.close()
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
self.stop_signal = None
|
||||
self.decoder = None
|
||||
|
@ -431,7 +441,10 @@ class WebsocketTunnelRelay(TunnelRelay):
|
|||
self.system = system
|
||||
|
||||
try:
|
||||
self.ws = await websockets.connect(self.address.url)
|
||||
self.ws = await websockets.connect(
|
||||
self.address.url, open_timeout=self.connection_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return self.__connection_error('timeout')
|
||||
except OSError as e:
|
||||
return self.__connection_error(e)
|
||||
except websockets.exceptions.InvalidHandshake as e:
|
||||
|
|
Loading…
Reference in New Issue