From 2e6b2787723ec23a2080a4f0f7cc53dca71368c2 Mon Sep 17 00:00:00 2001 From: Forest Belton <65484+forestbelton@users.noreply.github.com> Date: Mon, 21 Jun 2021 22:24:44 -0400 Subject: [PATCH] Factor out message (de)serialization --- paircd/main.py | 76 +++++++++-------------- paircd/message.py | 65 ++++++++++++++++++++ poetry.lock | 152 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 245 insertions(+), 48 deletions(-) create mode 100644 paircd/message.py create mode 100644 poetry.lock diff --git a/paircd/main.py b/paircd/main.py index 506b0d6..854e076 100644 --- a/paircd/main.py +++ b/paircd/main.py @@ -4,6 +4,8 @@ import logging import os from typing import Dict, List, Set +from paircd.message import IRCMessage + clients_by_nick: Dict[str, "Client"] = {} all_channels: Dict[str, "Channel"] = {} @@ -43,59 +45,56 @@ logger = logging.getLogger() async def handle_reader(client: Client) -> None: while True: raw_msg = await client.reader.readuntil(b"\r\n") - if not raw_msg.endswith(b"\r\n"): - raise RuntimeError("malformed message") - msg = parse_irc_msg(raw_msg) + msg = IRCMessage.parse(raw_msg) await handle_irc_msg(client, msg) -async def handle_irc_msg(client: Client, msg: List[str]) -> None: - if msg[0] == "CAP": +async def handle_irc_msg(client: Client, msg: IRCMessage) -> None: + if msg.cmd == "CAP": # https://ircv3.net/specs/extensions/capability-negotiation.html logging.warning("TODO: implement support for client capability negotiation") - elif msg[0] == "NICK": - assert_argc(msg, 2) + elif msg.cmd == "NICK": if client.nickname: del clients_by_nick[client.nickname] - client.nickname = msg[1] + client.nickname = msg.args[0] clients_by_nick[client.nickname] = client if client.username and client.realname: client.registered = True logging.info(f"{client.hostname} ({client.id()}) registered") - elif msg[0] == "USER": - assert_argc(msg, 5) + elif msg.cmd == "USER": if client.registered: raise RuntimeError("USER command issued after registration") - client.username = msg[1] - client.realname = msg[4] + client.username = msg.args[0] + client.realname = msg.args[3] if client.nickname: client.registered = True logging.info(f"{client.hostname} ({client.id()}) registered") - elif msg[0] == "JOIN": - assert_argc(msg, 2) + elif msg.cmd == "JOIN": + channel_name = msg.args[0] assert_registered(client, msg) - if not msg[1].startswith("#"): + if not channel_name.startswith("#"): raise RuntimeError("invalid channel name") - if msg[1] not in all_channels: - all_channels[msg[1]] = Channel(name=msg[1]) - asyncio.create_task(process_channel(all_channels[msg[1]])) - channel = all_channels[msg[1]] - client.channels.add(channel.name) + if channel_name not in all_channels: + all_channels[channel_name] = Channel(name=channel_name) + asyncio.create_task(process_channel(all_channels[channel_name])) + client.channels.add(channel_name) + channel = all_channels[channel_name] channel.clients_by_host[client.hostname] = client - logging.info(f"{client.hostname} ({client.id()}) joined {msg[1]}") + logging.info(f"{client.hostname} ({client.id()}) joined {channel_name}") await channel.msg_queue.put( - f":{client.nickname} JOIN {msg[1]}\r\n".encode("utf-8") + IRCMessage(cmd="JOIN", args=[channel_name], prefix=client.nickname).encode() ) - elif msg[0] == "PRIVMSG": - assert_argc(msg, 3) + elif msg.cmd == "PRIVMSG": assert_registered(client, msg) - await privmsg(client, msg[1], msg[2]) + await privmsg(client, msg.args[0], msg.args[1]) else: - logging.warning(f"unsupported message {msg[0]}") + logging.warning(f"unsupported message {msg.cmd}") async def privmsg(client: Client, recipient: str, raw_msg: str) -> None: - msg = f":{client.nickname} PRIVMSG {recipient} :{raw_msg}\r\n".encode("utf-8") + msg = IRCMessage( + "PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.nickname + ).encode() for name, other_client in clients_by_nick.items(): if name == recipient: @@ -119,10 +118,10 @@ async def process_channel(channel: Channel) -> None: client.msg_queue.put_nowait(msg) -def assert_registered(client: Client, msg: List[str]) -> None: +def assert_registered(client: Client, msg: IRCMessage) -> None: if client.registered: return - raise RuntimeError(f"{msg[0]} command issued before client fully registered") + raise RuntimeError(f"{msg.cmd} issued before client fully registered") def assert_argc(xs: List[str], i: int) -> None: @@ -131,25 +130,6 @@ def assert_argc(xs: List[str], i: int) -> None: raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})") -def parse_irc_msg(raw_msg: str) -> List[str]: - tokens: List[str] = [] - raw_tokens: List[str] = raw_msg.decode("utf-8").split(" ") - - for i, token in enumerate(raw_tokens): - if token.startswith(":"): - trailing = token[1:] + " ".join(raw_tokens[i + 1 :]) - tokens.append(trailing) - break - tokens.append(token) - - if len(tokens) > 0: - tokens[-1] = tokens[-1].strip() - else: - raise RuntimeError("empty message") - - return tokens - - async def handle_writer(client: Client) -> None: while True: msg = await client.msg_queue.get() diff --git a/paircd/message.py b/paircd/message.py new file mode 100644 index 0000000..b52159b --- /dev/null +++ b/paircd/message.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass +from typing import List, Optional + +MAX_MESSAGE_SIZE = 512 + +EXPECTED_ARG_COUNT = { + "NICK": 1, + "USER": 4, + "JOIN": 1, + "PRIVMSG": 2, +} + + +class IRCParsingError(Exception): + message: str + + def __init__(self, message: Optional[str] = None) -> None: + self.message = message + + +@dataclass +class IRCMessage: + cmd: str + args: List[str] + prefix: str = "" + + def encode(self) -> bytes: + prefix = self.prefix + if prefix != "": + prefix = f":{prefix} " + + # TODO: Raise exception if formatted message exceeds 512 bytes + return f"{prefix}{self.cmd} {' '.join(self.args)}\r\n".encode("utf-8") + + @staticmethod + def parse(raw: bytes) -> "IRCMessage": + if len(raw) > MAX_MESSAGE_SIZE: + raise IRCParsingError( + f"Message is {len(raw)} bytes, larger than allowed {MAX_MESSAGE_SIZE}" + ) + + if not raw.endswith(b"\r\n"): + raise IRCParsingError("Message does not terminate in CRLF") + + tokens: List[str] = [] + raw_tokens: List[str] = raw.decode("utf-8").split(" ") + + for i, token in enumerate(raw_tokens): + if token.startswith(":"): + trailing = token[1:] + " " + " ".join(raw_tokens[i + 1 :]) + tokens.append(trailing) + break + tokens.append(token) + + if len(tokens) == 0: + raise IRCParsingError("Message has no command") + + cmd = tokens[0] + if cmd in EXPECTED_ARG_COUNT and EXPECTED_ARG_COUNT[cmd] != len(tokens) - 1: + raise IRCParsingError( + f"{cmd} had {len(tokens)-1} arguments, expected {EXPECTED_ARG_COUNT[cmd]}" + ) + + tokens[-1] = tokens[-1].strip() + return IRCMessage(cmd=tokens[0], args=tokens[1:]) diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 0000000..d72045a --- /dev/null +++ b/poetry.lock @@ -0,0 +1,152 @@ +[[package]] +name = "atomicwrites" +version = "1.4.0" +description = "Atomic file writes." +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "attrs" +version = "21.2.0" +description = "Classes Without Boilerplate" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[package.extras] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"] +docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"] +tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"] + +[[package]] +name = "colorama" +version = "0.4.4" +description = "Cross-platform colored terminal text." +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[[package]] +name = "more-itertools" +version = "8.8.0" +description = "More routines for operating on iterables, beyond itertools" +category = "dev" +optional = false +python-versions = ">=3.5" + +[[package]] +name = "packaging" +version = "20.9" +description = "Core utilities for Python packages" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.dependencies] +pyparsing = ">=2.0.2" + +[[package]] +name = "pluggy" +version = "0.13.1" +description = "plugin and hook calling mechanisms for python" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.extras] +dev = ["pre-commit", "tox"] + +[[package]] +name = "py" +version = "1.10.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "pyparsing" +version = "2.4.7" +description = "Python parsing module" +category = "dev" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" + +[[package]] +name = "pytest" +version = "5.4.3" +description = "pytest: simple powerful testing with Python" +category = "dev" +optional = false +python-versions = ">=3.5" + +[package.dependencies] +atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} +attrs = ">=17.4.0" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +more-itertools = ">=4.0.0" +packaging = "*" +pluggy = ">=0.12,<1.0" +py = ">=1.5.0" +wcwidth = "*" + +[package.extras] +checkqa-mypy = ["mypy (==v0.761)"] +testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] + +[[package]] +name = "wcwidth" +version = "0.2.5" +description = "Measures the displayed width of unicode strings in a terminal" +category = "dev" +optional = false +python-versions = "*" + +[metadata] +lock-version = "1.1" +python-versions = "^3.9" +content-hash = "4d1de49710d78bd295469a572576efe3d5b96e6e8760458e870affe880e8d10e" + +[metadata.files] +atomicwrites = [ + {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, + {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, +] +attrs = [ + {file = "attrs-21.2.0-py2.py3-none-any.whl", hash = "sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1"}, + {file = "attrs-21.2.0.tar.gz", hash = "sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb"}, +] +colorama = [ + {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, + {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, +] +more-itertools = [ + {file = "more-itertools-8.8.0.tar.gz", hash = "sha256:83f0308e05477c68f56ea3a888172c78ed5d5b3c282addb67508e7ba6c8f813a"}, + {file = "more_itertools-8.8.0-py3-none-any.whl", hash = "sha256:2cf89ec599962f2ddc4d568a05defc40e0a587fbc10d5989713638864c36be4d"}, +] +packaging = [ + {file = "packaging-20.9-py2.py3-none-any.whl", hash = "sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a"}, + {file = "packaging-20.9.tar.gz", hash = "sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5"}, +] +pluggy = [ + {file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"}, + {file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"}, +] +py = [ + {file = "py-1.10.0-py2.py3-none-any.whl", hash = "sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a"}, + {file = "py-1.10.0.tar.gz", hash = "sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3"}, +] +pyparsing = [ + {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, + {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, +] +pytest = [ + {file = "pytest-5.4.3-py3-none-any.whl", hash = "sha256:5c0db86b698e8f170ba4582a492248919255fcd4c79b1ee64ace34301fb589a1"}, + {file = "pytest-5.4.3.tar.gz", hash = "sha256:7979331bfcba207414f5e1263b5a0f8f521d0f457318836a7355531ed1a4c7d8"}, +] +wcwidth = [ + {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, + {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, +]