Compare commits

...

4 Commits

2 changed files with 30 additions and 18 deletions

View File

@ -230,11 +230,7 @@ class Facet:
if run_in_executor:
inner_coro_fn = coro_fn
async def outer_coro_fn(facet):
try:
await self.loop.run_in_executor(None, lambda: inner_coro_fn(facet))
except:
import traceback
traceback.print_exc()
await self.loop.run_in_executor(None, lambda: inner_coro_fn(facet))
coro_fn = outer_coro_fn
@self.on_stop_or_crash
def cancel_linked_task():
@ -246,6 +242,9 @@ class Facet:
async def guarded_task():
try:
await coro_fn(self)
except:
import traceback
traceback.print_exc()
finally:
Turn.external(self, cancel_linked_task)
task = self.loop.create_task(guarded_task())

View File

@ -71,6 +71,7 @@ class TunnelRelay:
publish_oid = 0,
on_connected = None,
on_disconnected = None,
connection_timeout = None,
):
self.facet = turn.active_facet()
self.facet.on_stop(self._shutdown)
@ -79,6 +80,7 @@ class TunnelRelay:
self.gatekeeper_oid = gatekeeper_oid
self.publish_service = publish_service
self.publish_oid = publish_oid
self.connection_timeout = connection_timeout
self._reset()
self.facet.linked_task(
lambda facet: self._reconnecting_main(facet.actor._system,
@ -265,11 +267,14 @@ class TunnelRelay:
return transport.connection_from_str(conn_str, **kwargs)
# decorator
def connect(conn_str, cap, **kwargs):
def connect(conn_str, cap = None, **kwargs):
def prepare_resolution_handler(handler):
@During().add_handler
def handle_gatekeeper(gk):
gatekeeper.resolve(gk.embeddedValue, cap)(handler)
if cap is None:
handler(gk.embeddedValue)
else:
gatekeeper.resolve(gk.embeddedValue, cap)(handler)
return transport.connection_from_str(
conn_str,
gatekeeper_peer = turn.ref(handle_gatekeeper),
@ -338,10 +343,6 @@ class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol):
def connection_lost(self, exc):
self._on_disconnected()
def connection_made(self, transport):
self.transport = transport
self._on_connected()
def data_received(self, chunk):
self.decoder.extend(chunk)
while True:
@ -372,17 +373,26 @@ class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol):
self.decoder = Decoder(decode_embedded = sturdy.WireRef.decode)
self.stop_signal = system.loop.create_future()
try:
_transport, _protocol = await self._create_connection(system)
except OSError as e:
log.error('%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
return False
try:
transport, _protocol = await asyncio.wait_for(
self._create_connection(system), timeout=self.connection_timeout)
except asyncio.TimeoutError:
self.facet.log.error(
'%s: Timeout connecting to server' % (self.__class__.__qualname__,))
return False
except OSError as e:
self.facet.log.error(
'%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
return False
try:
self.transport = transport
self._on_connected()
if on_connected: await on_connected(self)
await self.stop_signal
return True
finally:
self.transport.close()
if self.transport:
self.transport.close()
self.transport = None
self.stop_signal = None
self.decoder = None
@ -431,7 +441,10 @@ class WebsocketTunnelRelay(TunnelRelay):
self.system = system
try:
self.ws = await websockets.connect(self.address.url)
self.ws = await websockets.connect(
self.address.url, open_timeout=self.connection_timeout)
except asyncio.TimeoutError:
return self.__connection_error('timeout')
except OSError as e:
return self.__connection_error(e)
except websockets.exceptions.InvalidHandshake as e: