"""Minimal asyncio IRC server.

Supports registration (NICK/USER), joining ``#`` channels and PRIVMSG to a
nickname or a channel. Each connection is served by a reader task and a writer
task; each channel has a fan-out task that copies its queue into the queues of
its members.
"""
import asyncio
import contextlib
import dataclasses
import logging
import os
from typing import Dict, List, Set

from paircd.message import IRCMessage

# Clients that have sent NICK, keyed by their current nickname.
clients_by_nick: Dict[str, "Client"] = {}
# Every channel created so far, keyed by its name (including the leading "#").
all_channels: Dict[str, "Channel"] = {}

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so a task nobody holds may be garbage-collected while
# it is still running.
_background_tasks: Set[asyncio.Task] = set()


@dataclasses.dataclass
class Client:
    """State for one connected client."""

    hostname: str
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    # Encoded outgoing lines (bytes); a ``None`` item tells the writer to stop.
    msg_queue: asyncio.Queue
    nickname: str = ""
    username: str = ""
    realname: str = ""
    registered: bool = False
    # Names of the channels this client has joined.
    channels: Set[str] = dataclasses.field(default_factory=set)

    def id(self) -> str:
        """Return the client's ``nick!user@host`` identifier."""
        return f"{self.nickname}!{self.username}@{self.hostname}"


@dataclasses.dataclass
class Channel:
    """A channel plus the queue of encoded messages to fan out to members."""

    name: str
    # NOTE(review): keyed by hostname, so two connections from the same address
    # (e.g. two local clients) overwrite each other's membership. Kept for
    # compatibility; consider keying by a per-connection identifier instead.
    clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict)
    msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)


logging.basicConfig(
    format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO
)
logger = logging.getLogger()


def _spawn(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep it referenced until it finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def handle_reader(client: Client) -> None:
    """Read CRLF-terminated lines from *client* and dispatch each command.

    A command rejected by ``handle_irc_msg`` is logged and reading continues,
    rather than the exception ending this task and leaving the socket open but
    unread. When the peer disconnects (or sends something unreadable) the
    client is removed from all registries and its writer task is told to stop.
    """
    try:
        while True:
            raw_msg = await client.reader.readuntil(b"\r\n")
            try:
                msg = IRCMessage.parse(raw_msg)
                await handle_irc_msg(client, msg)
            except (RuntimeError, IndexError, ValueError) as err:
                # RuntimeError: this module's protocol errors; IndexError: a
                # command with too few arguments; ValueError: presumably a
                # malformed/undecodable line from IRCMessage.parse (verify).
                logger.warning(
                    "%s: rejected %r: %s", client.hostname, raw_msg, err
                )
    except asyncio.CancelledError:
        raise
    except (
        asyncio.IncompleteReadError,
        asyncio.LimitOverrunError,
        ConnectionError,
    ):
        # EOF (possibly mid-line), an over-long line, or a reset connection.
        logger.info("%s (%s) disconnected", client.hostname, client.id())
    except Exception:
        logger.exception("error while serving %s", client.hostname)
    finally:
        unregister_client(client)
        # Wake the writer so it stops and closes the socket.
        client.msg_queue.put_nowait(None)


def _maybe_finish_registration(client: Client) -> None:
    """Mark *client* registered once it has both a nickname and a username."""
    if client.registered or not (client.nickname and client.username):
        return
    client.registered = True
    logger.info("%s (%s) registered", client.hostname, client.id())


async def handle_irc_msg(client: Client, msg: IRCMessage) -> None:
    """Apply one parsed command sent by *client*.

    Raises:
        RuntimeError: the command is not allowed in the client's current state
            or names an invalid/unknown target.
        IndexError: the command carries fewer arguments than it needs.
    """
    if msg.cmd == "CAP":
        # https://ircv3.net/specs/extensions/capability-negotiation.html
        logger.warning("TODO: implement support for client capability negotiation")
    elif msg.cmd == "NICK":
        new_nick = msg.args[0]
        holder = clients_by_nick.get(new_nick)
        if holder is not None and holder is not client:
            # Taking over another client's nick would later let either client
            # delete the other's registry entry.
            raise RuntimeError(f"nickname {new_nick} is already in use")
        if client.nickname and clients_by_nick.get(client.nickname) is client:
            del clients_by_nick[client.nickname]
        client.nickname = new_nick
        clients_by_nick[new_nick] = client
        _maybe_finish_registration(client)
    elif msg.cmd == "USER":
        if client.registered:
            raise RuntimeError("USER command issued after registration")
        # Read both fields first so a short argument list leaves state intact.
        username, realname = msg.args[0], msg.args[3]
        client.username = username
        client.realname = realname
        _maybe_finish_registration(client)
    elif msg.cmd == "JOIN":
        assert_registered(client, msg)
        channel_name = msg.args[0]
        if not channel_name.startswith("#"):
            raise RuntimeError("invalid channel name")
        channel = all_channels.get(channel_name)
        if channel is None:
            channel = all_channels[channel_name] = Channel(name=channel_name)
            _spawn(process_channel(channel))
        client.channels.add(channel_name)
        channel.clients_by_host[client.hostname] = client
        logger.info("%s (%s) joined %s", client.hostname, client.id(), channel_name)
        await channel.msg_queue.put(
            IRCMessage(cmd="JOIN", args=[channel_name], prefix=client.nickname).encode()
        )
    elif msg.cmd == "PRIVMSG":
        assert_registered(client, msg)
        await privmsg(client, msg.args[0], msg.args[1])
    else:
        logger.warning("unsupported message %s", msg.cmd)


async def privmsg(client: Client, recipient: str, raw_msg: str) -> None:
    """Queue a PRIVMSG from *client* to a nickname or, failing that, a channel.

    Raises:
        RuntimeError: *recipient* is neither a known nickname nor a channel.
    """
    msg = IRCMessage(
        "PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.nickname
    ).encode()
    target = clients_by_nick.get(recipient)
    if target is not None:
        target.msg_queue.put_nowait(msg)
        return
    channel = all_channels.get(recipient)
    if channel is not None:
        channel.msg_queue.put_nowait(msg)
        return
    raise RuntimeError(f"unknown recipient {recipient}")


async def process_channel(channel: Channel) -> None:
    """Forever copy messages from *channel*'s queue into each member's queue.

    A PRIVMSG is not echoed back to the member whose nickname it starts with.
    NOTE(review): assumes IRCMessage.encode renders the prefix as ":nick ".
    """
    while True:
        msg = await channel.msg_queue.get()
        for client in channel.clients_by_host.values():
            if msg.startswith(f":{client.nickname} PRIVMSG".encode("utf-8")):
                continue
            client.msg_queue.put_nowait(msg)


def assert_registered(client: Client, msg: IRCMessage) -> None:
    """Raise RuntimeError if *client* has not completed NICK/USER registration."""
    if client.registered:
        return
    raise RuntimeError(f"{msg.cmd} issued before client fully registered")


def assert_argc(xs: List[str], i: int) -> None:
    """Raise RuntimeError unless *xs* has exactly *i* elements."""
    if len(xs) == i:
        return
    # Guard the label lookup so an empty list still yields a RuntimeError
    # instead of an IndexError.
    label = xs[0] if xs else "message"
    raise RuntimeError(f"{label} had {len(xs)} arguments (expected {i})")


def unregister_client(client: Client) -> None:
    """Remove *client* from the nickname registry and every channel it joined.

    Entries are only removed when they still point at this exact client, so a
    different connection that owns the same key is left alone.
    """
    if client.nickname and clients_by_nick.get(client.nickname) is client:
        del clients_by_nick[client.nickname]
    for channel_name in client.channels:
        channel = all_channels.get(channel_name)
        if channel is not None and channel.clients_by_host.get(client.hostname) is client:
            del channel.clients_by_host[client.hostname]
    client.channels.clear()


async def handle_writer(client: Client) -> None:
    """Send queued lines to *client* until a ``None`` sentinel, then close."""
    try:
        while True:
            msg = await client.msg_queue.get()
            if msg is None:  # queued by handle_reader when the client leaves
                break
            client.writer.write(msg)
            await client.writer.drain()
    except ConnectionError:
        logger.info("%s: connection lost while writing", client.hostname)
    finally:
        client.writer.close()
        # Best effort: the peer may already have reset the connection.
        with contextlib.suppress(ConnectionError):
            await client.writer.wait_closed()


async def register_client(
    reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
    """Connection callback: create a Client and start its reader/writer tasks."""
    client = Client(
        hostname=writer.get_extra_info("peername")[0],
        reader=reader,
        writer=writer,
        msg_queue=asyncio.Queue(),
    )
    _spawn(handle_reader(client))
    _spawn(handle_writer(client))


async def main() -> None:
    """Listen on $BIND_ADDR:$PORT (default 0.0.0.0:6667) and serve forever."""
    bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0"
    port = int(os.getenv("PORT") or 6667)
    server = await asyncio.start_server(
        register_client,
        bind_addr,
        port,
        reuse_port=True,
    )
    logger.info("Listening on %s:%s...", bind_addr, port)
    async with server:
        await server.serve_forever()


if __name__ == "__main__":
    asyncio.run(main())