|
|
- import asyncio
- import dataclasses
- import logging
- import os
- from typing import Dict, List, Set
-
- from paircd.message import IRCMessage
-
# Connected clients, keyed by their current nickname (maintained by NICK).
clients_by_nick: Dict[str, "Client"] = dict()
# Channels created on first JOIN, keyed by channel name (names start with "#").
all_channels: Dict[str, "Channel"] = dict()
-
-
@dataclasses.dataclass
class Client:
    """State for one connected IRC peer.

    The connection handles are supplied when the socket is accepted. The
    identity fields start empty and are filled in by the NICK and USER
    commands; `registered` becomes True once both have been received.
    """

    # Connection-level state.
    hostname: str
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    msg_queue: asyncio.Queue  # outbound encoded messages for this peer

    # Identity, populated by NICK / USER.
    nickname: str = ""
    username: str = ""
    realname: str = ""
    registered: bool = False

    # Names of the channels this client has joined.
    channels: Set[str] = dataclasses.field(default_factory=set)

    def id(self) -> str:
        """Return this client's prefix in ``nick!user@host`` form."""
        return "!".join((self.nickname, f"{self.username}@{self.hostname}"))
-
-
@dataclasses.dataclass
class Channel:
    """A named chat room, its members, and its pending outbound messages."""

    name: str
    # Members, keyed by the joining client's hostname.
    # NOTE(review): two connections from the same host share one key, so a
    # later JOIN replaces the earlier member. Confirm this is intended.
    clients_by_host: Dict[str, Client] = dataclasses.field(
        default_factory=lambda: {}
    )
    # Encoded messages waiting to be fanned out to the members.
    msg_queue: asyncio.Queue = dataclasses.field(
        default_factory=lambda: asyncio.Queue()
    )
-
-
# Configure root logging once at import: timestamp, level, then the message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] - %(message)s",
)

# The root logger (no name given).
logger = logging.getLogger(name=None)
-
-
async def handle_reader(client: Client) -> None:
    """Read CRLF-terminated IRC lines from *client* and dispatch each one.

    Runs until the peer disconnects or a command handler raises. A clean
    disconnect (EOF or connection reset) ends the loop normally. Any other
    exception from parsing or handling propagates to the caller.

    In every case the client is unregistered on exit: it is removed from
    the nick table and from its channels, and its socket is closed. This
    stops later messages from being routed to a dead connection.
    """
    try:
        while True:
            try:
                raw_msg = await client.reader.readuntil(b"\r\n")
            except (asyncio.IncompleteReadError, ConnectionError):
                # IncompleteReadError means EOF arrived before a full line:
                # the peer closed the connection.
                logging.info(f"{client.hostname} ({client.id()}) disconnected")
                return
            msg = IRCMessage.parse(raw_msg)
            await handle_irc_msg(client, msg)
    finally:
        _unregister_client(client)


def _unregister_client(client: Client) -> None:
    """Remove *client* from the global routing tables and close its socket."""
    # Only delete entries that still point at this client. Another
    # connection may have claimed the same nick or host key since.
    if clients_by_nick.get(client.nickname) is client:
        del clients_by_nick[client.nickname]
    for channel_name in client.channels:
        channel = all_channels.get(channel_name)
        if channel is not None and channel.clients_by_host.get(client.hostname) is client:
            del channel.clients_by_host[client.hostname]
    client.writer.close()
-
-
# Strong references to the per-channel fan-out tasks. The event loop keeps
# only weak references to tasks, so an unreferenced task may be collected.
_channel_tasks: Set[asyncio.Task] = set()


async def handle_irc_msg(client: Client, msg: IRCMessage) -> None:
    """Apply one parsed IRC command from *client* to server state.

    Handles CAP (logged, not implemented), NICK, USER, JOIN and PRIVMSG.
    Any other command is logged and ignored.

    Raises:
        RuntimeError: USER after registration, JOIN/PRIVMSG before
            registration, an invalid channel name, or an unknown PRIVMSG
            recipient.
        IndexError: a command is missing required arguments.
    """
    if msg.cmd == "CAP":
        # https://ircv3.net/specs/extensions/capability-negotiation.html
        logging.warning("TODO: implement support for client capability negotiation")
    elif msg.cmd == "NICK":
        _handle_nick(client, msg.args)
    elif msg.cmd == "USER":
        _handle_user(client, msg.args)
    elif msg.cmd == "JOIN":
        assert_registered(client, msg)
        await _handle_join(client, msg.args[0])
    elif msg.cmd == "PRIVMSG":
        assert_registered(client, msg)
        await privmsg(client, msg.args[0], msg.args[1])
    else:
        logging.warning(f"unsupported message {msg.cmd}")


def _complete_registration(client: Client) -> None:
    """Mark *client* as registered and log it."""
    client.registered = True
    logging.info(f"{client.hostname} ({client.id()}) registered")


def _handle_nick(client: Client, args: List[str]) -> None:
    """Set or change the client's nickname and update the nick table."""
    new_nick = args[0]
    # Only drop the old entry if it still maps to this client. Nick
    # collisions are not rejected, so another connection may now own it,
    # and deleting that entry would make the other client unreachable.
    if client.nickname and clients_by_nick.get(client.nickname) is client:
        del clients_by_nick[client.nickname]
    client.nickname = new_nick
    clients_by_nick[client.nickname] = client
    # A non-empty username means USER already arrived, so NICK completes
    # registration. realname may legitimately be empty and is not checked.
    # The `registered` guard stops later nick changes from re-registering.
    if not client.registered and client.username:
        _complete_registration(client)


def _handle_user(client: Client, args: List[str]) -> None:
    """Record username/realname; complete registration if NICK was seen."""
    if client.registered:
        raise RuntimeError("USER command issued after registration")
    client.username = args[0]
    client.realname = args[3]
    if client.nickname:
        _complete_registration(client)


async def _handle_join(client: Client, channel_name: str) -> None:
    """Add the client to *channel_name*, creating the channel on first use."""
    if not channel_name.startswith("#"):
        raise RuntimeError("invalid channel name")
    channel = all_channels.get(channel_name)
    if channel is None:
        channel = all_channels[channel_name] = Channel(name=channel_name)
        task = asyncio.create_task(process_channel(channel))
        _channel_tasks.add(task)
        task.add_done_callback(_channel_tasks.discard)
    client.channels.add(channel_name)
    # NOTE(review): keyed by hostname, so two connections from the same host
    # overwrite each other's membership entry.
    channel.clients_by_host[client.hostname] = client
    logging.info(f"{client.hostname} ({client.id()}) joined {channel_name}")
    await channel.msg_queue.put(
        IRCMessage(cmd="JOIN", args=[channel_name], prefix=client.nickname).encode()
    )
-
-
async def privmsg(client: Client, recipient: str, raw_msg: str) -> None:
    """Deliver a PRIVMSG from *client* to a nick or a channel.

    Nicknames are checked before channel names. The encoded message is put
    on the recipient's outbound queue without waiting.

    Raises:
        RuntimeError: if *recipient* is neither a known nick nor a channel.
    """
    msg = IRCMessage(
        "PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.nickname
    ).encode()

    # Direct O(1) dict lookups instead of scanning every entry.
    target = clients_by_nick.get(recipient)
    if target is not None:
        target.msg_queue.put_nowait(msg)
        return

    channel = all_channels.get(recipient)
    if channel is not None:
        channel.msg_queue.put_nowait(msg)
        return

    raise RuntimeError(f"unknown recipient {recipient}")
-
-
async def process_channel(channel: Channel) -> None:
    """Forever fan out messages from the channel queue to every member.

    A member is skipped when the message is that member's own PRIVMSG. This
    is detected by the encoded bytes starting with ``:<nick> PRIVMSG``, so
    senders do not get their own chat lines echoed back.
    """
    while True:
        payload = await channel.msg_queue.get()
        for member in channel.clients_by_host.values():
            own_privmsg = f":{member.nickname} PRIVMSG".encode("utf-8")
            if not payload.startswith(own_privmsg):
                member.msg_queue.put_nowait(payload)
-
-
def assert_registered(client: Client, msg: IRCMessage) -> None:
    """Raise RuntimeError naming ``msg.cmd`` unless *client* has registered."""
    if not client.registered:
        raise RuntimeError(f"{msg.cmd} issued before client fully registered")
-
-
def assert_argc(xs: List[str], i: int) -> None:
    """Raise RuntimeError unless *xs* has exactly *i* elements.

    The error message is labelled with ``xs[0]``. If *xs* is empty, a
    generic label is used rather than indexing into it, which previously
    raised IndexError instead of the intended RuntimeError.
    """
    if len(xs) == i:
        return
    label = xs[0] if xs else "message"
    raise RuntimeError(f"{label} had {len(xs)} arguments (expected {i})")
-
-
async def handle_writer(client: Client) -> None:
    """Forever move queued bytes from ``client.msg_queue`` onto its socket.

    Each chunk is written and then drained before the next one is taken.
    """

    async def _send(chunk: bytes) -> None:
        client.writer.write(chunk)
        await client.writer.drain()

    while True:
        await _send(await client.msg_queue.get())
-
-
async def register_client(
    reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
    """Per-connection callback for ``asyncio.start_server``.

    Builds a Client for the peer, starts its outbound writer task, then
    reads and handles commands inline until the reader loop ends. The
    writer task is cancelled at that point so it does not outlive the
    connection.
    """
    client = Client(
        hostname=writer.get_extra_info("peername")[0],
        reader=reader,
        writer=writer,
        msg_queue=asyncio.Queue(),
    )
    # Keep a reference: the event loop holds only weak references to tasks,
    # so an unreferenced task can be garbage collected mid-execution.
    writer_task = asyncio.create_task(handle_writer(client))
    try:
        await handle_reader(client)
    finally:
        writer_task.cancel()
-
-
async def main():
    """Start the IRC listener and serve connections until cancelled.

    Environment:
        BIND_ADDR: address to bind (default "0.0.0.0" when unset or empty).
        PORT: TCP port (default 6667 when unset or empty); must be an integer.

    Raises:
        ValueError: if PORT is not an integer.
    """
    bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0"
    # Normalize to int so `port` has a single type whether or not PORT is
    # set; a malformed value fails here with a clear ValueError.
    port = int(os.getenv("PORT") or 6667)

    # NOTE: asyncio documents reuse_port as unsupported on Windows.
    server = await asyncio.start_server(
        register_client,
        bind_addr,
        port,
        reuse_port=True,
    )

    logger.info(f"Listening on {bind_addr}:{port}...")
    async with server:
        await server.serve_forever()
-
-
if __name__ == "__main__":
    # Script entry point: run the server coroutine on a fresh event loop.
    entry = main()
    asyncio.run(entry)
|