From e4a2c3405d6835ee5b68b26cf0d4d77fade5eee7 Mon Sep 17 00:00:00 2001 From: Forest Belton <65484+forestbelton@users.noreply.github.com> Date: Mon, 21 Jun 2021 21:43:52 -0400 Subject: [PATCH] Implement basic IRCd --- README.md | 14 ++++ README.rst | 0 paircd/main.py | 188 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 202 insertions(+) create mode 100644 README.md delete mode 100644 README.rst create mode 100644 paircd/main.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..a42c22a --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +# paircd +a **P**ython **A**synchronous **IRC** **D**aemon + +## setup + +``` +$ poetry install +``` + +## run + +``` +$ poetry run python +``` \ No newline at end of file diff --git a/README.rst b/README.rst deleted file mode 100644 index e69de29..0000000 diff --git a/paircd/main.py b/paircd/main.py new file mode 100644 index 0000000..506b0d6 --- /dev/null +++ b/paircd/main.py @@ -0,0 +1,188 @@ +import asyncio +import dataclasses +import logging +import os +from typing import Dict, List, Set + +clients_by_nick: Dict[str, "Client"] = {} +all_channels: Dict[str, "Channel"] = {} + + +@dataclasses.dataclass +class Client: + hostname: str + reader: asyncio.StreamReader + writer: asyncio.StreamWriter + msg_queue: asyncio.Queue + + nickname: str = "" + username: str = "" + realname: str = "" + registered: bool = False + + channels: Set[str] = dataclasses.field(default_factory=set) + + def id(self) -> str: + return f"{self.nickname}!{self.username}@{self.hostname}" + + +@dataclasses.dataclass +class Channel: + name: str + clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict) + msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue) + + +logging.basicConfig( + format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO +) + +logger = logging.getLogger() + + +async def handle_reader(client: Client) -> None: + while True: + raw_msg = await client.reader.readuntil(b"\r\n") + if not raw_msg.endswith(b"\r\n"): + raise RuntimeError("malformed message") + msg = parse_irc_msg(raw_msg) + await handle_irc_msg(client, msg) + + +async def handle_irc_msg(client: Client, msg: List[str]) -> None: + if msg[0] == "CAP": + # https://ircv3.net/specs/extensions/capability-negotiation.html + logging.warning("TODO: implement support for client capability negotiation") + elif msg[0] == "NICK": + assert_argc(msg, 2) + if client.nickname: + del clients_by_nick[client.nickname] + client.nickname = msg[1] + clients_by_nick[client.nickname] = client + if client.username and client.realname: + client.registered = True + logging.info(f"{client.hostname} ({client.id()}) registered") + elif msg[0] == "USER": + assert_argc(msg, 5) + if client.registered: + raise RuntimeError("USER command issued after registration") + client.username = msg[1] + client.realname = msg[4] + if client.nickname: + client.registered = True + logging.info(f"{client.hostname} ({client.id()}) registered") + elif msg[0] == "JOIN": + assert_argc(msg, 2) + assert_registered(client, msg) + if not msg[1].startswith("#"): + raise RuntimeError("invalid channel name") + if msg[1] not in all_channels: + all_channels[msg[1]] = Channel(name=msg[1]) + asyncio.create_task(process_channel(all_channels[msg[1]])) + channel = all_channels[msg[1]] + client.channels.add(channel.name) + channel.clients_by_host[client.hostname] = client + logging.info(f"{client.hostname} ({client.id()}) joined {msg[1]}") + await channel.msg_queue.put( + f":{client.nickname} JOIN {msg[1]}\r\n".encode("utf-8") + ) + elif msg[0] == "PRIVMSG": + assert_argc(msg, 3) + assert_registered(client, msg) + await privmsg(client, msg[1], msg[2]) + else: + logging.warning(f"unsupported message {msg[0]}") + + +async def privmsg(client: Client, recipient: str, raw_msg: str) -> None: + msg = f":{client.nickname} PRIVMSG {recipient} :{raw_msg}\r\n".encode("utf-8") + + for name, other_client in clients_by_nick.items(): + if name == recipient: + other_client.msg_queue.put_nowait(msg) + return + + for name, channel in all_channels.items(): + if name == recipient: + channel.msg_queue.put_nowait(msg) + return + + raise RuntimeError(f"unknown recipient {recipient}") + + +async def process_channel(channel: Channel) -> None: + while True: + msg = await channel.msg_queue.get() + for client in channel.clients_by_host.values(): + if msg.startswith(f":{client.nickname} PRIVMSG".encode("utf-8")): + continue + client.msg_queue.put_nowait(msg) + + +def assert_registered(client: Client, msg: List[str]) -> None: + if client.registered: + return + raise RuntimeError(f"{msg[0]} command issued before client fully registered") + + +def assert_argc(xs: List[str], i: int) -> None: + if len(xs) == i: + return + raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})") + + +def parse_irc_msg(raw_msg: str) -> List[str]: + tokens: List[str] = [] + raw_tokens: List[str] = raw_msg.decode("utf-8").split(" ") + + for i, token in enumerate(raw_tokens): + if token.startswith(":"): + trailing = token[1:] + " ".join(raw_tokens[i + 1 :]) + tokens.append(trailing) + break + tokens.append(token) + + if len(tokens) > 0: + tokens[-1] = tokens[-1].strip() + else: + raise RuntimeError("empty message") + + return tokens + + +async def handle_writer(client: Client) -> None: + while True: + msg = await client.msg_queue.get() + client.writer.write(msg) + await client.writer.drain() + + +async def register_client(reader, writer): + client = Client( + hostname=writer.get_extra_info("peername")[0], + reader=reader, + writer=writer, + msg_queue=asyncio.Queue(), + ) + asyncio.create_task(handle_reader(client)) + asyncio.create_task(handle_writer(client)) + + +async def main(): + bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0" + port = os.getenv("PORT") or 6667 + + server = await asyncio.start_server( + register_client, + bind_addr, + port, + reuse_port=True, + ) + + logger.info(f"Listening on {bind_addr}:{port}...") + async with server: + await server.serve_forever() + + +if __name__ == "__main__": + asyncio.run(main())