From 561aa01fea1aa6f71e63e411613370b17788aa29 Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Fri, 29 Mar 2024 14:06:34 +0100 Subject: [PATCH] Support connection_timeout --- syndicate/relay.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/syndicate/relay.py b/syndicate/relay.py index d47f5e7..e4d4d0e 100644 --- a/syndicate/relay.py +++ b/syndicate/relay.py @@ -71,6 +71,7 @@ class TunnelRelay: publish_oid = 0, on_connected = None, on_disconnected = None, + connection_timeout = None, ): self.facet = turn.active_facet() self.facet.on_stop(self._shutdown) @@ -79,6 +80,7 @@ class TunnelRelay: self.gatekeeper_oid = gatekeeper_oid self.publish_service = publish_service self.publish_oid = publish_oid + self.connection_timeout = connection_timeout self._reset() self.facet.linked_task( lambda facet: self._reconnecting_main(facet.actor._system, @@ -341,10 +343,6 @@ class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol): def connection_lost(self, exc): self._on_disconnected() - def connection_made(self, transport): - self.transport = transport - self._on_connected() - def data_received(self, chunk): self.decoder.extend(chunk) while True: @@ -375,17 +373,26 @@ class _StreamTunnelRelay(TunnelRelay, asyncio.Protocol): self.decoder = Decoder(decode_embedded = sturdy.WireRef.decode) self.stop_signal = system.loop.create_future() try: - _transport, _protocol = await self._create_connection(system) - except OSError as e: - self.facet.log.error('%s: Could not connect to server: %s' % (self.__class__.__qualname__, e)) - return False + try: + transport, _protocol = await asyncio.wait_for( + self._create_connection(system), timeout=self.connection_timeout) + except asyncio.TimeoutError: + self.facet.log.error( + '%s: Timeout connecting to server' % (self.__class__.__qualname__,)) + return False + except OSError as e: + self.facet.log.error( + '%s: Could not connect to server: %s' % (self.__class__.__qualname__, e)) + return False - try: + self.transport = transport + self._on_connected() if on_connected: await on_connected(self) await self.stop_signal return True finally: - self.transport.close() + if self.transport: + self.transport.close() self.transport = None self.stop_signal = None self.decoder = None @@ -434,7 +441,10 @@ class WebsocketTunnelRelay(TunnelRelay): self.system = system try: - self.ws = await websockets.connect(self.address.url) + self.ws = await websockets.connect( + self.address.url, open_timeout=self.connection_timeout) + except asyncio.TimeoutError: + return self.__connection_error('timeout') except OSError as e: return self.__connection_error(e) except websockets.exceptions.InvalidHandshake as e: