"""Minimal asyncio IRC server.

Handles NICK/USER registration, JOIN of ``#`` channels and PRIVMSG to a
nickname or a channel. Listens on ``$BIND_ADDR:$PORT`` (defaults
``0.0.0.0`` and ``6667``).
"""

import asyncio
import dataclasses
import logging
import os
from typing import Dict, List, Optional, Set

# Clients that have sent NICK, keyed by their current nickname.
clients_by_nick: Dict[str, "Client"] = {}
# Every channel created by a JOIN, keyed by its name (including the "#").
all_channels: Dict[str, "Channel"] = {}


@dataclasses.dataclass
class Client:
    """Per-connection state for one IRC client."""

    hostname: str
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    # Outgoing lines, already encoded as CRLF-terminated UTF-8 bytes.
    msg_queue: asyncio.Queue
    nickname: str = ""
    username: str = ""
    realname: str = ""
    registered: bool = False
    # Names of the channels this client has joined.
    channels: Set[str] = dataclasses.field(default_factory=set)
    # Remote TCP port; together with hostname it identifies the connection.
    port: int = 0

    def id(self) -> str:
        """Return the client's IRC prefix, ``nick!user@host``."""
        return f"{self.nickname}!{self.username}@{self.hostname}"

    def conn_key(self) -> str:
        """Return a key unique to this connection, ``host:port``."""
        return f"{self.hostname}:{self.port}"


@dataclasses.dataclass
class Channel:
    """A channel, its members and the queue drained by process_channel()."""

    name: str
    # Members keyed by Client.conn_key() rather than the bare hostname, so
    # several connections from one address do not overwrite each other.
    clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict)
    # Encoded lines waiting to be fanned out to every member.
    msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)
    # The event loop keeps only weak references to tasks, so hold the
    # broadcaster task here to stop it being garbage-collected mid-run.
    task: Optional[asyncio.Task] = None


logging.basicConfig(
    format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO
)
logger = logging.getLogger()


async def handle_reader(client: Client) -> None:
    """Read CRLF-terminated lines from *client* and dispatch each one.

    Returns when the peer closes or resets the connection, or sends a line
    longer than the stream's buffer limit. A line that fails to parse or is
    rejected by handle_irc_msg() is logged and skipped; the connection stays
    open.
    """
    while True:
        try:
            raw_msg = await client.reader.readuntil(b"\r\n")
        except (asyncio.IncompleteReadError, ConnectionError):
            # EOF (possibly mid-line) or a reset: the peer is gone.
            return
        except asyncio.LimitOverrunError:
            logger.warning(f"{client.hostname}: line too long, disconnecting")
            return
        try:
            msg = parse_irc_msg(raw_msg)
            await handle_irc_msg(client, msg)
        except RuntimeError as err:
            logger.warning(f"{client.hostname} ({client.id()}): {err}")


def _try_register(client: Client) -> None:
    """Mark *client* registered once both a nickname and a username are set."""
    if client.registered or not (client.nickname and client.username):
        return
    client.registered = True
    logger.info(f"{client.hostname} ({client.id()}) registered")


async def handle_irc_msg(client: Client, msg: List[str]) -> None:
    """Apply one parsed command from *client*.

    ``msg[0]`` is the command (matched case-insensitively); the remaining
    items are its parameters, as produced by parse_irc_msg().

    Raises:
        RuntimeError: if the command has the wrong number of parameters, is
            issued out of order, or names an invalid or conflicting target.
    """
    command = msg[0].upper()
    if command == "CAP":
        # https://ircv3.net/specs/extensions/capability-negotiation.html
        logger.warning("TODO: implement support for client capability negotiation")
    elif command == "NICK":
        assert_argc(msg, 2)
        nickname = msg[1]
        if not nickname or nickname.startswith("#") or " " in nickname:
            raise RuntimeError(f"invalid nickname {nickname!r}")
        holder = clients_by_nick.get(nickname)
        if holder is not None and holder is not client:
            # Refuse rather than silently rerouting another user's messages.
            raise RuntimeError(f"nickname {nickname} already in use")
        if client.nickname:
            clients_by_nick.pop(client.nickname, None)
        client.nickname = nickname
        clients_by_nick[nickname] = client
        _try_register(client)
    elif command == "USER":
        assert_argc(msg, 5)
        if client.registered:
            raise RuntimeError("USER command issued after registration")
        client.username = msg[1]
        client.realname = msg[4]
        _try_register(client)
    elif command == "JOIN":
        assert_argc(msg, 2)
        assert_registered(client, msg)
        name = msg[1]
        if not name.startswith("#"):
            raise RuntimeError("invalid channel name")
        channel = all_channels.get(name)
        if channel is None:
            channel = all_channels[name] = Channel(name=name)
            channel.task = asyncio.create_task(process_channel(channel))
        client.channels.add(name)
        channel.clients_by_host[client.conn_key()] = client
        logger.info(f"{client.hostname} ({client.id()}) joined {name}")
        await channel.msg_queue.put(
            f":{client.nickname} JOIN {name}\r\n".encode("utf-8")
        )
    elif command == "PRIVMSG":
        assert_argc(msg, 3)
        assert_registered(client, msg)
        await privmsg(client, msg[1], msg[2])
    else:
        logger.warning(f"unsupported message {msg[0]}")


async def privmsg(client: Client, recipient: str, raw_msg: str) -> None:
    """Queue a PRIVMSG from *client* for a nickname or, failing that, a channel.

    Raises:
        RuntimeError: if *recipient* is neither a known nickname nor channel.
    """
    msg = f":{client.nickname} PRIVMSG {recipient} :{raw_msg}\r\n".encode("utf-8")
    target = clients_by_nick.get(recipient)
    if target is not None:
        target.msg_queue.put_nowait(msg)
        return
    channel = all_channels.get(recipient)
    if channel is not None:
        channel.msg_queue.put_nowait(msg)
        return
    raise RuntimeError(f"unknown recipient {recipient}")


async def process_channel(channel: Channel) -> None:
    """Forward each line queued on *channel* to every current member.

    A PRIVMSG is not echoed back to its sender, identified by the line
    starting with ``:<member nickname> PRIVMSG``.
    """
    while True:
        msg = await channel.msg_queue.get()
        for client in channel.clients_by_host.values():
            if msg.startswith(f":{client.nickname} PRIVMSG".encode("utf-8")):
                continue
            client.msg_queue.put_nowait(msg)


def assert_registered(client: Client, msg: List[str]) -> None:
    """Raise RuntimeError unless *client* has completed NICK/USER registration."""
    if client.registered:
        return
    raise RuntimeError(f"{msg[0]} command issued before client fully registered")


def assert_argc(xs: List[str], i: int) -> None:
    """Raise RuntimeError unless *xs* holds exactly *i* tokens (command included)."""
    if len(xs) == i:
        return
    raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})")


def parse_irc_msg(raw_msg: bytes) -> List[str]:
    """Split one raw IRC line into ``[command, param, ...]``.

    The line terminator is removed, an optional leading ``:prefix`` is
    discarded, runs of spaces separate parameters, and a parameter starting
    with ``:`` swallows the rest of the line verbatim (spaces included).
    Bytes that are not valid UTF-8 are replaced rather than aborting.

    Raises:
        RuntimeError: if the line contains no command.
    """
    line = raw_msg.decode("utf-8", errors="replace").rstrip("\r\n")
    words = line.split(" ")
    if words[0].startswith(":"):
        # Source prefix; clients normally omit it and it is unused here.
        words = words[1:]
    tokens: List[str] = []
    for i, word in enumerate(words):
        if word.startswith(":"):
            # Trailing parameter: rejoin with the spaces split() removed.
            tokens.append(" ".join([word[1:], *words[i + 1 :]]))
            break
        if word:  # skip the empty strings produced by repeated spaces
            tokens.append(word)
    if not tokens:
        raise RuntimeError("empty message")
    return tokens


async def handle_writer(client: Client) -> None:
    """Write every queued outgoing line to *client*'s socket, forever."""
    while True:
        msg = await client.msg_queue.get()
        client.writer.write(msg)
        await client.writer.drain()


def _forget_client(client: Client) -> None:
    """Remove *client* from the nickname index and every channel it joined."""
    if clients_by_nick.get(client.nickname) is client:
        del clients_by_nick[client.nickname]
    for name in client.channels:
        channel = all_channels.get(name)
        if channel is not None:
            channel.clients_by_host.pop(client.conn_key(), None)
    client.channels.clear()


async def register_client(
    reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
    """Serve one connection until the peer disconnects, then clean up.

    Used as the asyncio.start_server() callback, which runs it as a task.
    """
    peer = writer.get_extra_info("peername")
    client = Client(
        hostname=peer[0],
        reader=reader,
        writer=writer,
        msg_queue=asyncio.Queue(),
        port=peer[1],
    )
    writer_task = asyncio.create_task(handle_writer(client))
    try:
        await handle_reader(client)
    finally:
        _forget_client(client)
        writer_task.cancel()
        # Collect the writer's outcome (cancellation or a send error) so it
        # is not reported as an unretrieved task exception.
        await asyncio.gather(writer_task, return_exceptions=True)
        writer.close()
        try:
            await writer.wait_closed()
        except OSError:
            pass  # the peer may already have reset the connection
        logger.info(f"{client.hostname} ({client.id()}) disconnected")


async def main() -> None:
    """Start the server on $BIND_ADDR:$PORT and serve until cancelled."""
    bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0"
    port = int(os.getenv("PORT") or 6667)
    server = await asyncio.start_server(
        register_client,
        bind_addr,
        port,
        reuse_port=True,
    )
    logger.info(f"Listening on {bind_addr}:{port}...")
    async with server:
        await server.serve_forever()


if __name__ == "__main__":
    asyncio.run(main())