Support connection_timeout

This commit is contained in:
Tony Garnock-Jones 2024-03-29 14:06:34 +01:00
parent 0429e59ad1
commit 561aa01fea
1 changed files with 21 additions and 11 deletions

View File

@ -71,6 +71,7 @@ class TunnelRelay:
publish_oid = 0,
on_connected = None,
on_disconnected = None,
connection_timeout = None,
):
self.facet = turn.active_facet()
self.facet.on_stop(self._shutdown)
@ -79,6 +80,7 @@ class TunnelRelay:
self.gatekeeper_oid = gatekeeper_oid
self.publish_service = publish_service
self.publish_oid = publish_oid
self.connection_timeout = connection_timeout
self._reset()
self.facet.linked_task(
lambda facet: self._reconnecting_main(facet.actor._system,
@ -341,10 +343,6 @@ class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol):
def connection_lost(self, exc):
self._on_disconnected()
def connection_made(self, transport):
self.transport = transport
self._on_connected()
def data_received(self, chunk):
self.decoder.extend(chunk)
while True:
@ -375,17 +373,26 @@ class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol):
self.decoder = Decoder(decode_embedded = sturdy.WireRef.decode)
self.stop_signal = system.loop.create_future()
try:
_transport, _protocol = await self._create_connection(system)
except OSError as e:
self.facet.log.error('%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
return False
try:
transport, _protocol = await asyncio.wait_for(
self._create_connection(system), timeout=self.connection_timeout)
except asyncio.TimeoutError:
self.facet.log.error(
'%s: Timeout connecting to server' % (self.__class__.__qualname__,))
return False
except OSError as e:
self.facet.log.error(
'%s: Could not connect to server: %s' % (self.__class__.__qualname__, e))
return False
try:
self.transport = transport
self._on_connected()
if on_connected: await on_connected(self)
await self.stop_signal
return True
finally:
self.transport.close()
if self.transport:
self.transport.close()
self.transport = None
self.stop_signal = None
self.decoder = None
@ -434,7 +441,10 @@ class WebsocketTunnelRelay(TunnelRelay):
self.system = system
try:
self.ws = await websockets.connect(self.address.url)
self.ws = await websockets.connect(
self.address.url, open_timeout=self.connection_timeout)
except asyncio.TimeoutError:
return self.__connection_error('timeout')
except OSError as e:
return self.__connection_error(e)
except websockets.exceptions.InvalidHandshake as e: