diff --git a/paircd/command_handler.py b/paircd/command_handler.py index e25912d..6fad9f8 100644 --- a/paircd/command_handler.py +++ b/paircd/command_handler.py @@ -1,18 +1,18 @@ -from dataclasses import dataclass -from typing import Awaitable, Callable +from abc import abstractmethod, ABC from paircd.client import Client from paircd.message import Message from paircd.server import Server -HandlerFunc = Callable[ - [Server, Client, Message], - Awaitable[None], -] - -@dataclass -class CommandHandler: +class CommandHandler(ABC): cmd: str argc: int - handler: HandlerFunc + + def __init__(self, cmd: str, argc: int) -> None: + self.cmd = cmd + self.argc = argc + + @abstractmethod + async def handle(self, server: Server, client: Client, msg: Message) -> None: + pass diff --git a/paircd/handler/join.py b/paircd/handler/join.py index 283cef3..36693d4 100644 --- a/paircd/handler/join.py +++ b/paircd/handler/join.py @@ -7,25 +7,26 @@ from paircd.message import Message from paircd.server import Server -async def handle_join(server: Server, client: Client, msg: Message) -> None: - if not client.registered: - log_client(client, "join: not registered", level=logging.WARN) - return +class JoinHandler(CommandHandler): + def __init__(self) -> None: + super().__init__("JOIN", 1) - channel_name = msg.args[0] - if not channel_name.startswith("#"): - log_client(client, "tried to join invalid channel", level=logging.WARN) - return + async def handle(self, server: Server, client: Client, msg: Message) -> None: + if not client.registered: + log_client(client, "join: not registered", level=logging.WARN) + return - channel = server.get_channel_by_name(channel_name) - channel.add_client(client) + channel_name = msg.args[0] + if not channel_name.startswith("#"): + log_client(client, "tried to join invalid channel", level=logging.WARN) + return - client.channels.add(channel_name) - log_client(client, f"joined {channel_name}") + channel = server.get_channel_by_name(channel_name) + channel.add_client(client) - await channel.msg_queue.put( - Message(cmd="JOIN", args=[channel_name], prefix=client.id()).encode() - ) + client.channels.add(channel_name) + log_client(client, f"joined {channel_name}") - -JOIN = CommandHandler("JOIN", 1, handle_join) + await channel.msg_queue.put( + Message(cmd="JOIN", args=[channel_name], prefix=client.id()).encode() + ) diff --git a/paircd/handler/nick.py b/paircd/handler/nick.py index 0a6dba9..b72f432 100644 --- a/paircd/handler/nick.py +++ b/paircd/handler/nick.py @@ -5,19 +5,20 @@ from paircd.message import Message from paircd.server import Server -async def handle_nick(server: Server, client: Client, msg: Message) -> None: - nickname = msg.args[0] +class NickHandler(CommandHandler): + def __init__(self) -> None: + super().__init__("NICK", 1) - if client.nickname: - del server.clients_by_nick[client.nickname] - # TODO: Update all channel references + async def handle(self, server: Server, client: Client, msg: Message) -> None: + nickname = msg.args[0] - client.nickname = nickname - server.add_client(client) + if client.nickname: + del server.clients_by_nick[client.nickname] + # TODO: Update all channel references - if client.username and client.realname: - client.registered = True - log_client(client, "registered") + client.nickname = nickname + server.add_client(client) - -NICK = CommandHandler("NICK", 1, handle_nick) + if client.username and client.realname: + client.registered = True + log_client(client, "registered") diff --git a/paircd/handler/privmsg.py b/paircd/handler/privmsg.py index 2818576..26eb847 100644 --- a/paircd/handler/privmsg.py +++ b/paircd/handler/privmsg.py @@ -7,23 +7,26 @@ from paircd.message import Message from paircd.server import Server -async def handle_privmsg(server: Server, client: Client, msg: Message) -> None: - recipient = msg.args[0] - raw_msg = msg.args[1] +class PrivmsgHandler(CommandHandler): + def __init__(self) -> None: + super().__init__("PRIVMSG", 2) - out = Message("PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.id()).encode() + async def handle(self, server: Server, client: Client, msg: Message) -> None: + recipient = msg.args[0] + raw_msg = msg.args[1] - for name, other_client in server.clients_by_nick.items(): - if name == recipient: - other_client.msg_queue.put_nowait(out) - return + out = Message( + "PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.id() + ).encode() - for name, channel in server.channels_by_name.items(): - if name == recipient: - channel.msg_queue.put_nowait(out) - return + for name, other_client in server.clients_by_nick.items(): + if name == recipient: + other_client.msg_queue.put_nowait(out) + return - log_client(client, "unknown recipient", level=logging.WARN) + for name, channel in server.channels_by_name.items(): + if name == recipient: + channel.msg_queue.put_nowait(out) + return - -PRIVMSG = CommandHandler("PRIVMSG", 2, handle_privmsg) + log_client(client, "unknown recipient", level=logging.WARN) diff --git a/paircd/handler/user.py b/paircd/handler/user.py index 482b883..4c6a311 100644 --- a/paircd/handler/user.py +++ b/paircd/handler/user.py @@ -7,17 +7,18 @@ from paircd.message import Message from paircd.server import Server -async def handle_user(server: Server, client: Client, msg: Message) -> None: - if client.registered: - log_client(client, "USER issued after registration", level=logging.WARN) - return +class UserHandler(CommandHandler): + def __init__(self) -> None: + super().__init__("USER", 4) - client.username = msg.args[0] - client.realname = msg.args[3] + async def handle(self, server: Server, client: Client, msg: Message) -> None: + if client.registered: + log_client(client, "USER issued after registration", level=logging.WARN) + return - if client.nickname: - client.registered = True - log_client(client, "registered") + client.username = msg.args[0] + client.realname = msg.args[3] - -USER = CommandHandler("USER", 4, handle_user) + if client.nickname: + client.registered = True + log_client(client, "registered") diff --git a/paircd/handlers.py b/paircd/handlers.py index 29a7340..3023af3 100644 --- a/paircd/handlers.py +++ b/paircd/handlers.py @@ -7,23 +7,24 @@ from paircd.log import log_client from paircd.message import Message from paircd.server import Server -from paircd.handler.join import JOIN -from paircd.handler.nick import NICK -from paircd.handler.privmsg import PRIVMSG -from paircd.handler.user import USER - -ALL_HANDLERS = [ - JOIN, - NICK, - PRIVMSG, - USER, +from paircd.handler.join import JoinHandler +from paircd.handler.nick import NickHandler +from paircd.handler.privmsg import PrivmsgHandler +from paircd.handler.user import UserHandler + +HANDLER_CLASSES = [ + JoinHandler, + NickHandler, + PrivmsgHandler, + UserHandler, ] CMD_HANDLERS: Dict[str, CommandHandler] = {} def register_cmd_handlers() -> None: - for handler in ALL_HANDLERS: + for handler_cls in HANDLER_CLASSES: + handler = handler_cls() # type: ignore CMD_HANDLERS[handler.cmd] = handler @@ -31,4 +32,4 @@ async def handle_cmd(server: Server, client: Client, msg: Message) -> None: if msg.cmd not in CMD_HANDLERS: log_client(client, f"used unknown command {msg.cmd}", level=WARN) return - await CMD_HANDLERS[msg.cmd].handler(server, client, msg) + await CMD_HANDLERS[msg.cmd].handle(server, client, msg)