diff --git a/paircd/channel.py b/paircd/channel.py new file mode 100644 index 0000000..5a06301 --- /dev/null +++ b/paircd/channel.py @@ -0,0 +1,24 @@ +from asyncio import Queue +from dataclasses import dataclass, field +from typing import Dict + +from paircd.client import Client + + +@dataclass +class Channel: + name: str + clients_by_nick: Dict[str, Client] = field(default_factory=dict) + msg_queue: Queue = field(default_factory=Queue) + + def add_client(self, client: Client) -> None: + self.clients_by_nick[client.nickname] = client + + async def process(self) -> None: + while True: + msg = await self.msg_queue.get() + for client in self.clients_by_nick.values(): + # Don't broadcast client's messages back to themselves + if msg.startswith(f":{client.id()} PRIVMSG".encode("utf-8")): + continue + client.msg_queue.put_nowait(msg) diff --git a/paircd/client.py b/paircd/client.py new file mode 100644 index 0000000..70a15eb --- /dev/null +++ b/paircd/client.py @@ -0,0 +1,27 @@ +from asyncio import StreamReader, StreamWriter, Queue +from dataclasses import dataclass, field +from typing import Any, Set + + +@dataclass +class Client: + hostname: str + reader: StreamReader + writer: StreamWriter + msg_queue: Queue = field(default_factory=Queue) + + nickname: str = "" + username: str = "" + realname: str = "" + registered: bool = False + + channels: Set[str] = field(default_factory=set) + + def id(self) -> str: + return f"{self.nickname}!{self.username}@{self.hostname}" + + async def write_forever(self) -> None: + while True: + msg = await self.msg_queue.get() + self.writer.write(msg) + await self.writer.drain() diff --git a/paircd/handle.py b/paircd/handle.py new file mode 100644 index 0000000..a4ffbde --- /dev/null +++ b/paircd/handle.py @@ -0,0 +1,21 @@ +import logging +from typing import Any, Dict + +from paircd.client import Client +from paircd.message import Message +from paircd.server import Server + +CMD_HANDLERS: Dict[str, Any] = {} +CMD_EXPECTED_ARGC: Dict[str, int] = {} + + +def register_cmd_handler(cmd: str, argc: int, handler) -> None: + CMD_EXPECTED_ARGC[cmd] = argc + CMD_HANDLERS[cmd] = handler + + +async def handle_cmd(server: Server, client: Client, msg: Message) -> None: + if msg.cmd not in CMD_HANDLERS: + logging.warning(f"Unknown command: {msg.cmd}") + return + await CMD_HANDLERS[msg.cmd](server, client, msg) diff --git a/paircd/handler/__init__.py b/paircd/handler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/paircd/handler/join.py b/paircd/handler/join.py new file mode 100644 index 0000000..4e11a52 --- /dev/null +++ b/paircd/handler/join.py @@ -0,0 +1,26 @@ +from asyncio import create_task +import logging + +from paircd.channel import Channel +from paircd.client import Client +from paircd.message import Message +from paircd.server import Server + + +async def handle_join(server: Server, client: Client, msg: Message) -> None: + channel_name = msg.args[0] + if not client.registered: + raise RuntimeError("JOIN: not registered") + + if not channel_name.startswith("#"): + raise RuntimeError("invalid channel name") + + channel = server.get_channel_by_name(channel_name) + channel.add_client(client) + + client.channels.add(channel_name) + logging.info(f"{client.hostname} ({client.id()}) joined {channel_name}") + + await channel.msg_queue.put( + Message(cmd="JOIN", args=[channel_name], prefix=client.id()).encode() + ) diff --git a/paircd/handler/nick.py b/paircd/handler/nick.py new file mode 100644 index 0000000..851eec3 --- /dev/null +++ b/paircd/handler/nick.py @@ -0,0 +1,20 @@ +import logging + +from paircd.client import Client +from paircd.message import Message +from paircd.server import Server + + +async def handle_nick(server: Server, client: Client, msg: Message) -> None: + nickname = msg.args[0] + + if client.nickname: + del server.clients_by_nick[client.nickname] + # TODO: Update all channel references + + client.nickname = nickname + server.add_client(client) + + if client.username and client.realname: + client.registered = True + logging.info(f"{client.hostname} ({client.id()}) registered") diff --git a/paircd/handler/privmsg.py b/paircd/handler/privmsg.py new file mode 100644 index 0000000..bf40000 --- /dev/null +++ b/paircd/handler/privmsg.py @@ -0,0 +1,22 @@ +from paircd.client import Client +from paircd.message import Message +from paircd.server import Server + + +async def handle_privmsg(server: Server, client: Client, msg: Message) -> None: + recipient = msg.args[0] + raw_msg = msg.args[1] + + msg = Message("PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.id()).encode() + + for name, other_client in server.clients_by_nick.items(): + if name == recipient: + other_client.msg_queue.put_nowait(msg) + return + + for name, channel in server.channels_by_name.items(): + if name == recipient: + channel.msg_queue.put_nowait(msg) + return + + raise RuntimeError(f"Unknown recipient {recipient}") diff --git a/paircd/handler/user.py b/paircd/handler/user.py new file mode 100644 index 0000000..1debfc0 --- /dev/null +++ b/paircd/handler/user.py @@ -0,0 +1,17 @@ +import logging + +from paircd.client import Client +from paircd.message import Message +from paircd.server import Server + + +async def handle_user(server: Server, client: Client, msg: Message) -> None: + if client.registered: + raise RuntimeError("USER command issued after registration") + + client.username = msg.args[0] + client.realname = msg.args[3] + + if client.nickname: + client.registered = True + logging.info(f"{client.hostname} ({client.id()}) registered") diff --git a/paircd/main.py b/paircd/main.py index 854e076..2a26710 100644 --- a/paircd/main.py +++ b/paircd/main.py @@ -1,39 +1,15 @@ import asyncio -import dataclasses import logging import os -from typing import Dict, List, Set - -from paircd.message import IRCMessage - -clients_by_nick: Dict[str, "Client"] = {} -all_channels: Dict[str, "Channel"] = {} - - -@dataclasses.dataclass -class Client: - hostname: str - reader: asyncio.StreamReader - writer: asyncio.StreamWriter - msg_queue: asyncio.Queue - - nickname: str = "" - username: str = "" - realname: str = "" - registered: bool = False - - channels: Set[str] = dataclasses.field(default_factory=set) - - def id(self) -> str: - return f"{self.nickname}!{self.username}@{self.hostname}" - - -@dataclasses.dataclass -class Channel: - name: str - clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict) - msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue) +from paircd.client import Client +from paircd.handle import handle_cmd, register_cmd_handler +from paircd.handler.join import handle_join +from paircd.handler.nick import handle_nick +from paircd.handler.privmsg import handle_privmsg +from paircd.handler.user import handle_user +from paircd.message import parse_message +from paircd.server import Server logging.basicConfig( format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO @@ -42,116 +18,33 @@ logging.basicConfig( logger = logging.getLogger() -async def handle_reader(client: Client) -> None: +async def read_forever(server: Server, client: Client) -> None: while True: raw_msg = await client.reader.readuntil(b"\r\n") - msg = IRCMessage.parse(raw_msg) - await handle_irc_msg(client, msg) - - -async def handle_irc_msg(client: Client, msg: IRCMessage) -> None: - if msg.cmd == "CAP": - # https://ircv3.net/specs/extensions/capability-negotiation.html - logging.warning("TODO: implement support for client capability negotiation") - elif msg.cmd == "NICK": - if client.nickname: - del clients_by_nick[client.nickname] - client.nickname = msg.args[0] - clients_by_nick[client.nickname] = client - if client.username and client.realname: - client.registered = True - logging.info(f"{client.hostname} ({client.id()}) registered") - elif msg.cmd == "USER": - if client.registered: - raise RuntimeError("USER command issued after registration") - client.username = msg.args[0] - client.realname = msg.args[3] - if client.nickname: - client.registered = True - logging.info(f"{client.hostname} ({client.id()}) registered") - elif msg.cmd == "JOIN": - channel_name = msg.args[0] - assert_registered(client, msg) - if not channel_name.startswith("#"): - raise RuntimeError("invalid channel name") - if channel_name not in all_channels: - all_channels[channel_name] = Channel(name=channel_name) - asyncio.create_task(process_channel(all_channels[channel_name])) - client.channels.add(channel_name) - channel = all_channels[channel_name] - channel.clients_by_host[client.hostname] = client - logging.info(f"{client.hostname} ({client.id()}) joined {channel_name}") - await channel.msg_queue.put( - IRCMessage(cmd="JOIN", args=[channel_name], prefix=client.nickname).encode() - ) - elif msg.cmd == "PRIVMSG": - assert_registered(client, msg) - await privmsg(client, msg.args[0], msg.args[1]) - else: - logging.warning(f"unsupported message {msg.cmd}") - - -async def privmsg(client: Client, recipient: str, raw_msg: str) -> None: - msg = IRCMessage( - "PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.nickname - ).encode() - - for name, other_client in clients_by_nick.items(): - if name == recipient: - other_client.msg_queue.put_nowait(msg) - return - - for name, channel in all_channels.items(): - if name == recipient: - channel.msg_queue.put_nowait(msg) - return - - raise RuntimeError(f"unknown recipient {recipient}") - - -async def process_channel(channel: Channel) -> None: - while True: - msg = await channel.msg_queue.get() - for client in channel.clients_by_host.values(): - if msg.startswith(f":{client.nickname} PRIVMSG".encode("utf-8")): - continue - client.msg_queue.put_nowait(msg) - - -def assert_registered(client: Client, msg: IRCMessage) -> None: - if client.registered: - return - raise RuntimeError(f"{msg.cmd} issued before client fully registered") - - -def assert_argc(xs: List[str], i: int) -> None: - if len(xs) == i: - return - raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})") - - -async def handle_writer(client: Client) -> None: - while True: - msg = await client.msg_queue.get() - client.writer.write(msg) - await client.writer.drain() - - -async def register_client(reader, writer): - client = Client( - hostname=writer.get_extra_info("peername")[0], - reader=reader, - writer=writer, - msg_queue=asyncio.Queue(), - ) - asyncio.create_task(handle_reader(client)) - asyncio.create_task(handle_writer(client)) + msg = parse_message(raw_msg) + await handle_cmd(server, client, msg) async def main(): bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0" port = os.getenv("PORT") or 6667 + register_cmd_handler("JOIN", 1, handle_join) + register_cmd_handler("NICK", 1, handle_nick) + register_cmd_handler("PRIVMSG", 2, handle_privmsg) + register_cmd_handler("USER", 4, handle_user) + + irc_server = Server() + + async def register_client(reader, writer): + client = Client( + hostname=writer.get_extra_info("peername")[0], + reader=reader, + writer=writer, + ) + asyncio.create_task(read_forever(irc_server, client)) + asyncio.create_task(client.write_forever()) + server = await asyncio.start_server( register_client, bind_addr, diff --git a/paircd/message.py b/paircd/message.py index b52159b..8808329 100644 --- a/paircd/message.py +++ b/paircd/message.py @@ -11,7 +11,7 @@ EXPECTED_ARG_COUNT = { } -class IRCParsingError(Exception): +class ParsingError(Exception): message: str def __init__(self, message: Optional[str] = None) -> None: @@ -19,7 +19,7 @@ class IRCParsingError(Exception): @dataclass -class IRCMessage: +class Message: cmd: str args: List[str] prefix: str = "" @@ -32,34 +32,34 @@ class IRCMessage: # TODO: Raise exception if formatted message exceeds 512 bytes return f"{prefix}{self.cmd} {' '.join(self.args)}\r\n".encode("utf-8") - @staticmethod - def parse(raw: bytes) -> "IRCMessage": - if len(raw) > MAX_MESSAGE_SIZE: - raise IRCParsingError( - f"Message is {len(raw)} bytes, larger than allowed {MAX_MESSAGE_SIZE}" - ) - - if not raw.endswith(b"\r\n"): - raise IRCParsingError("Message does not terminate in CRLF") - - tokens: List[str] = [] - raw_tokens: List[str] = raw.decode("utf-8").split(" ") - - for i, token in enumerate(raw_tokens): - if token.startswith(":"): - trailing = token[1:] + " " + " ".join(raw_tokens[i + 1 :]) - tokens.append(trailing) - break - tokens.append(token) - - if len(tokens) == 0: - raise IRCParsingError("Message has no command") - - cmd = tokens[0] - if cmd in EXPECTED_ARG_COUNT and EXPECTED_ARG_COUNT[cmd] != len(tokens) - 1: - raise IRCParsingError( - f"{cmd} had {len(tokens)-1} arguments, expected {EXPECTED_ARG_COUNT[cmd]}" - ) - - tokens[-1] = tokens[-1].strip() - return IRCMessage(cmd=tokens[0], args=tokens[1:]) + +def parse_message(raw: bytes) -> Message: + if len(raw) > MAX_MESSAGE_SIZE: + raise ParsingError( + f"Message is {len(raw)} bytes, larger than allowed {MAX_MESSAGE_SIZE}" + ) + + if not raw.endswith(b"\r\n"): + raise ParsingError("Message does not terminate in CRLF") + + tokens: List[str] = [] + raw_tokens: List[str] = raw.decode("utf-8").split(" ") + + for i, token in enumerate(raw_tokens): + if token.startswith(":"): + trailing = token[1:] + " " + " ".join(raw_tokens[i + 1 :]) + tokens.append(trailing) + break + tokens.append(token) + + if len(tokens) == 0: + raise ParsingError("Message has no command") + + cmd = tokens[0] + if cmd in EXPECTED_ARG_COUNT and EXPECTED_ARG_COUNT[cmd] != len(tokens) - 1: + raise ParsingError( + f"{cmd} had {len(tokens)-1} arguments, expected {EXPECTED_ARG_COUNT[cmd]}" + ) + + tokens[-1] = tokens[-1].strip() + return Message(cmd=tokens[0], args=tokens[1:]) diff --git a/paircd/server.py b/paircd/server.py new file mode 100644 index 0000000..ec86e89 --- /dev/null +++ b/paircd/server.py @@ -0,0 +1,21 @@ +from asyncio import create_task +from dataclasses import dataclass, field +from typing import Dict + +from paircd.client import Client +from paircd.channel import Channel + + +@dataclass +class Server: + clients_by_nick: Dict[str, Client] = field(default_factory=dict) + channels_by_name: Dict[str, Channel] = field(default_factory=dict) + + def add_client(self, client: Client) -> None: + self.clients_by_nick[client.nickname] = client + + def get_channel_by_name(self, name: str) -> Channel: + if name not in self.channels_by_name: + self.channels_by_name[name] = Channel(name=name) + create_task(self.channels_by_name[name].process()) + return self.channels_by_name[name]