diff --git a/chat.py b/chat.py index c3f4f3c..2e2f47c 100644 --- a/chat.py +++ b/chat.py @@ -7,19 +7,15 @@ import syndicate.mini.core as S Present = S.Record.makeConstructor('Present', 'who') Says = S.Record.makeConstructor('Says', 'who what') -if len(sys.argv) == 4: - conn = S.TcpConnection(sys.argv[1], int(sys.argv[2]), sys.argv[3]) -elif len(sys.argv) == 3: - if sys.argv[1].startswith('ws:') or sys.argv[1].startswith('wss:'): - conn = S.WebsocketConnection(sys.argv[1], sys.argv[2]) - else: - conn = S.UnixSocketConnection(sys.argv[1], sys.argv[2]) -elif len(sys.argv) == 1: - conn = S.WebsocketConnection('ws://localhost:8000/', 'chat') +if len(sys.argv) == 1: + conn_url = 'ws://localhost:8000/#chat' +elif len(sys.argv) == 2: + conn_url = sys.argv[1] else: sys.stderr.write( - 'Usage: chat.py [ HOST PORT SCOPE | WEBSOCKETURL SCOPE | UNIXSOCKETPATH SCOPE ]\n') + 'Usage: chat.py [ tcp://HOST[:PORT]#SCOPE | ws://HOST[:PORT]#SCOPE | unix:PATH#SCOPE ]\n') sys.exit(1) +conn = S.Connection.from_url(conn_url) _print = print def print(*items): diff --git a/ovlinfo.py b/ovlinfo.py index be788e0..f1e9c44 100644 --- a/ovlinfo.py +++ b/ovlinfo.py @@ -6,7 +6,7 @@ import syndicate.mini.core as S OverlayLink = S.Record.makeConstructor('OverlayLink', 'downNode upNode') -conn = S.WebsocketConnection(sys.argv[1], sys.argv[2]) +conn = S.Connection.from_url(sys.argv[1]) uplinks = {} def add_uplink(turn, src, tgt): diff --git a/setup.py b/setup.py index ba7bc45..c1d8040 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ except ImportError: setup( name="mini-syndicate", - version="0.0.1", + version="0.0.2", author="Tony Garnock-Jones", author_email="tonyg@leastfixedpoint.com", license="GNU General Public License v3 or later (GPLv3+)", diff --git a/syndicate/mini/core.py b/syndicate/mini/core.py index c5f06ad..7677f06 100644 --- a/syndicate/mini/core.py +++ b/syndicate/mini/core.py @@ -2,10 +2,13 @@ import asyncio import secrets import logging import websockets +import re +from urllib.parse import urlparse, urlunparse log = logging.getLogger(__name__) import syndicate.mini.protocol as protocol +import syndicate.mini.url as url from syndicate.mini.protocol import Capture, Discard, Observe CAPTURE = Capture(Discard()) @@ -203,6 +206,10 @@ class Connection(object): def _disconnect(self): raise Exception('subclassresponsibility') + @classmethod + def from_url(cls, s): + return url.connection_from_url(s) + class _StreamConnection(Connection, asyncio.Protocol): def __init__(self, scope): super().__init__(scope) @@ -260,6 +267,7 @@ class _StreamConnection(Connection, asyncio.Protocol): self.stop_signal = None self.decoder = None +@url.schema('tcp') class TcpConnection(_StreamConnection): def __init__(self, host, port, scope): super().__init__(scope) @@ -269,6 +277,19 @@ class TcpConnection(_StreamConnection): async def _create_connection(self, loop): return await loop.create_connection(lambda: self, self.host, self.port) + @classmethod + def default_port(cls): + return 21369 + + @classmethod + def from_url(cls, s): + u = urlparse(s) + host, port = url._hostport(u.netloc, cls.default_port()) + if not host: return + scope = u.fragment + return cls(host, port, scope) + +@url.schema('unix') class UnixSocketConnection(_StreamConnection): def __init__(self, path, scope): super().__init__(scope) @@ -277,6 +298,13 @@ class UnixSocketConnection(_StreamConnection): async def _create_connection(self, loop): return await loop.create_unix_connection(lambda: self, self.path) + @classmethod + def from_url(cls, s): + u = urlparse(s) + return cls(u.path, u.fragment) + +@url.schema('ws') +@url.schema('wss') class WebsocketConnection(Connection): def __init__(self, url, scope): super().__init__(scope) @@ -328,3 +356,8 @@ class WebsocketConnection(Connection): self.loop = None self.ws = None return True + + @classmethod + def from_url(cls, s): + u = urlparse(s) + return cls(urlunparse(u._replace(fragment='')), u.fragment) diff --git a/syndicate/mini/url.py b/syndicate/mini/url.py new file mode 100644 index 0000000..a5aa2fe --- /dev/null +++ b/syndicate/mini/url.py @@ -0,0 +1,33 @@ +# URLs denoting Syndicate servers. + +class InvalidSyndicateUrl(ValueError): pass + +schemas = {} + +def schema(schema_name): + def k(factory_class): + schemas[schema_name] = factory_class + return factory_class + return k + +def _bad_url(u): + raise InvalidSyndicateUrl('Invalid Syndicate server URL', u) + +def connection_from_url(u): + pieces = u.split(':', 1) + if len(pieces) != 2: _bad_url(u) + schema_name, _rest = pieces + if schema_name not in schemas: _bad_url(u) + conn = schemas[schema_name].from_url(u) + if not conn: _bad_url(u) + return conn + +def _hostport(s, default_port): + try: + i = s.rindex(':') + except ValueError: + i = None + if i is not None: + return (s[:i], int(s[i+1:])) + else: + return (s, default_port)