python ircd using asyncio
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

168 lines
5.0 KiB

import asyncio
import dataclasses
import logging
import os
from typing import Dict, List, Set
from paircd.message import IRCMessage
# Every client that has claimed a nickname, keyed by its current nickname.
clients_by_nick: Dict[str, "Client"] = dict()
# Every channel created so far (a channel is created on its first JOIN),
# keyed by channel name.
all_channels: Dict[str, "Channel"] = dict()
@dataclasses.dataclass
class Client:
    """Server-side state for one connected IRC client.

    ``reader``/``writer`` are the connection's asyncio streams, and
    ``msg_queue`` holds encoded (bytes) lines waiting to be written to it.
    """

    # Peer address the connection came from.
    hostname: str
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    # Outgoing encoded lines for this client.
    msg_queue: asyncio.Queue
    # Set by NICK.
    nickname: str = ""
    # Set by USER (first and fourth arguments respectively).
    username: str = ""
    realname: str = ""
    # Becomes True once both NICK and USER have been processed.
    registered: bool = False
    # Names of the channels this client has joined.
    channels: Set[str] = dataclasses.field(default_factory=set)

    def id(self) -> str:
        """Return this client's ``nick!user@host`` identifier."""
        return "{}!{}@{}".format(self.nickname, self.username, self.hostname)
@dataclasses.dataclass
class Channel:
    """A channel: its name, its members, and a queue of traffic to fan out.

    Encoded messages put on ``msg_queue`` are delivered to each member by a
    per-channel task that is started when the channel is first joined.
    """

    # Channel name, including its leading "#".
    name: str
    # Members keyed by their hostname.
    # NOTE(review): two clients connecting from the same address share a key,
    # so the later one replaces the earlier one here. Confirm this is intended.
    clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict)
    # Channel traffic awaiting delivery to the members.
    msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)
# Process-wide log format and level. basicConfig is a no-op if the root
# logger already has handlers.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO
)
# A module-named logger rather than the root logger, so this module's records
# can be filtered on their own. Records propagate to the root handler
# configured above, so the output format is unchanged.
logger = logging.getLogger(__name__)
def _forget_client(client: Client) -> None:
    """Remove *client* from the nickname registry and from every channel it joined."""
    # Only drop the nickname entry if it still points at this client; another
    # client may hold that key now.
    if client.nickname and clients_by_nick.get(client.nickname) is client:
        del clients_by_nick[client.nickname]
    for channel_name in client.channels:
        channel = all_channels.get(channel_name)
        if channel is None:
            continue
        stale = [key for key, member in channel.clients_by_host.items() if member is client]
        for key in stale:
            del channel.clients_by_host[key]
    client.channels.clear()


async def handle_reader(client: Client) -> None:
    """Read CRLF-terminated lines from *client* and dispatch each as an IRC command.

    Runs until the peer disconnects, the connection fails, or the client sends
    a line longer than the stream limit. On the way out it removes the client
    from ``clients_by_nick`` and from its channels, so nothing more is routed
    to it, and closes its writer.

    A command rejected by ``handle_irc_msg`` (RuntimeError, or IndexError for
    missing arguments) is logged and skipped; the session continues.
    """
    try:
        while True:
            try:
                raw_msg = await client.reader.readuntil(b"\r\n")
            except asyncio.IncompleteReadError:
                # EOF: the peer closed the connection, possibly mid-line.
                logger.info(f"{client.hostname} ({client.id()}) disconnected")
                return
            msg = IRCMessage.parse(raw_msg)
            try:
                await handle_irc_msg(client, msg)
            except (RuntimeError, IndexError) as err:
                # One bad command should not end the whole session.
                logger.warning(
                    f"{client.hostname} ({client.id()}) {msg.cmd} rejected: {err}"
                )
    except asyncio.LimitOverrunError:
        # The over-long data stays in the stream buffer, so reading again
        # would fail forever; drop the connection instead.
        logger.warning(f"{client.hostname} ({client.id()}) sent an over-long line")
    except ConnectionError as err:
        logger.info(f"{client.hostname} ({client.id()}) connection lost: {err!r}")
    finally:
        _forget_client(client)
        client.writer.close()
def _require_args(msg: IRCMessage, count: int) -> None:
    """Raise RuntimeError unless *msg* carries at least *count* arguments."""
    if len(msg.args) < count:
        raise RuntimeError(
            f"{msg.cmd} needs at least {count} arguments, got {len(msg.args)}"
        )


def _set_nickname(client: Client, nickname: str) -> None:
    """Give *client* the nickname *nickname*, releasing its previous one.

    Raises:
        RuntimeError: *nickname* is empty, starts with "#" or ":", contains
            whitespace, or is already held by a different client.
    """
    # A nick starting with "#" would shadow the channel of the same name,
    # because PRIVMSG routing consults clients_by_nick before all_channels.
    if not nickname or nickname[0] in "#:" or any(ch.isspace() for ch in nickname):
        raise RuntimeError(f"invalid nickname {nickname!r}")
    holder = clients_by_nick.get(nickname)
    if holder is not None and holder is not client:
        raise RuntimeError(f"nickname {nickname} already in use")
    # Release the old nick only if it still maps to this client.
    if clients_by_nick.get(client.nickname) is client:
        del clients_by_nick[client.nickname]
    client.nickname = nickname
    clients_by_nick[nickname] = client


def _maybe_register(client: Client) -> None:
    """Mark *client* registered, once, after both NICK and USER have been seen."""
    if client.registered or not client.nickname or not client.username:
        return
    client.registered = True
    logger.info(f"{client.hostname} ({client.id()}) registered")


async def handle_irc_msg(client: Client, msg: IRCMessage) -> None:
    """Apply one parsed IRC command from *client* to the server state.

    Handles NICK, USER, JOIN and PRIVMSG. CAP and any other command are
    logged and otherwise ignored.

    Raises:
        RuntimeError: the command has too few arguments, names an invalid or
            taken nickname or an invalid channel, or is not allowed in the
            client's current registration state.
    """
    if msg.cmd == "CAP":
        # https://ircv3.net/specs/extensions/capability-negotiation.html
        logger.warning("TODO: implement support for client capability negotiation")
    elif msg.cmd == "NICK":
        _require_args(msg, 1)
        _set_nickname(client, msg.args[0])
        _maybe_register(client)
    elif msg.cmd == "USER":
        if client.registered:
            raise RuntimeError("USER command issued after registration")
        _require_args(msg, 4)
        client.username = msg.args[0]
        client.realname = msg.args[3]
        _maybe_register(client)
    elif msg.cmd == "JOIN":
        _require_args(msg, 1)
        channel_name = msg.args[0]
        assert_registered(client, msg)
        if not channel_name.startswith("#"):
            raise RuntimeError("invalid channel name")
        channel = all_channels.get(channel_name)
        if channel is None:
            channel = Channel(name=channel_name)
            all_channels[channel_name] = channel
            asyncio.create_task(process_channel(channel))
        elif channel.clients_by_host.get(client.hostname) is client:
            # Already a member: do not announce the JOIN a second time.
            return
        client.channels.add(channel_name)
        # NOTE(review): members are keyed by hostname, so a second client from
        # the same address replaces the first one in this channel.
        channel.clients_by_host[client.hostname] = client
        logger.info(f"{client.hostname} ({client.id()}) joined {channel_name}")
        await channel.msg_queue.put(
            IRCMessage(cmd="JOIN", args=[channel_name], prefix=client.nickname).encode()
        )
    elif msg.cmd == "PRIVMSG":
        assert_registered(client, msg)
        _require_args(msg, 2)
        await privmsg(client, msg.args[0], msg.args[1])
    else:
        logger.warning(f"unsupported message {msg.cmd}")
async def privmsg(client: Client, recipient: str, raw_msg: str) -> None:
    """Route a PRIVMSG with text *raw_msg* from *client* to *recipient*.

    *recipient* is first looked up as a nickname and then as a channel name.
    The encoded message goes onto that client's queue, or onto the channel's
    queue for fan-out to its members.

    Raises:
        RuntimeError: no client or channel is named *recipient*.
    """
    # NOTE(review): the ":" trailing-parameter marker is added here by hand.
    # Confirm that IRCMessage.encode does not also add one.
    msg = IRCMessage(
        "PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.nickname
    ).encode()
    # Direct O(1) lookups rather than scanning every registered name.
    # Nicknames still take precedence over channels.
    target = clients_by_nick.get(recipient)
    if target is not None:
        target.msg_queue.put_nowait(msg)
        return
    channel = all_channels.get(recipient)
    if channel is not None:
        channel.msg_queue.put_nowait(msg)
        return
    raise RuntimeError(f"unknown recipient {recipient}")
async def process_channel(channel: Channel) -> None:
    """Forever deliver each message queued on *channel* to its members.

    A member is skipped when the message begins with a PRIVMSG prefixed by
    that member's own nickname (i.e. it is the sender). Every other message,
    JOIN included, goes to all members.
    """
    while True:
        payload = await channel.msg_queue.get()
        for member in channel.clients_by_host.values():
            own_privmsg = f":{member.nickname} PRIVMSG".encode("utf-8")
            if not payload.startswith(own_privmsg):
                member.msg_queue.put_nowait(payload)
def assert_registered(client: Client, msg: IRCMessage) -> None:
    """Raise RuntimeError unless *client* has completed registration.

    *msg* is used only to name the offending command in the error text.
    """
    if not client.registered:
        raise RuntimeError(f"{msg.cmd} issued before client fully registered")
def assert_argc(xs: List[str], i: int) -> None:
    """Raise RuntimeError unless *xs* has exactly *i* elements.

    The error names ``xs[0]``, so *xs* presumably starts with the command
    name. No caller is visible here; TODO confirm. An empty *xs* used to
    raise IndexError while building the message; it now raises the intended
    RuntimeError.
    """
    if len(xs) == i:
        return
    label = xs[0] if xs else "<empty message>"
    raise RuntimeError(f"{label} had {len(xs)} arguments (expected {i})")
async def handle_writer(client: Client) -> None:
    """Write each encoded message from *client*'s queue to its socket, in order.

    Returns quietly once the connection is gone, whether it was closed
    locally or reset by the peer. Previously such a failure escaped as an
    unhandled task exception.
    """
    while True:
        msg = await client.msg_queue.get()
        if client.writer.is_closing():
            # The transport is shutting down; anything written now is lost.
            return
        client.writer.write(msg)
        try:
            await client.writer.drain()
        except ConnectionError as err:
            logger.info(f"{client.hostname}: stopped writing ({err!r})")
            return
# The event loop keeps only weak references to tasks, so per-connection tasks
# are held here until they finish, to keep them from being garbage-collected.
_background_tasks: Set[asyncio.Task] = set()


async def register_client(reader, writer):
    """Connection callback for ``asyncio.start_server``: set up a new client.

    Builds the Client for this connection and starts its reader and writer
    tasks, then returns. When the reader task ends, for whatever reason, the
    writer task is cancelled and the connection closed. If the reader died
    with an exception, that exception is logged.
    """
    client = Client(
        hostname=writer.get_extra_info("peername")[0],
        reader=reader,
        writer=writer,
        msg_queue=asyncio.Queue(),
    )
    reader_task = asyncio.create_task(handle_reader(client))
    writer_task = asyncio.create_task(handle_writer(client))
    for task in (reader_task, writer_task):
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)

    def on_reader_done(task: asyncio.Task) -> None:
        # The client can no longer issue commands, so tear the connection
        # down rather than leave it half-open.
        writer_task.cancel()
        writer.close()
        if not task.cancelled() and task.exception() is not None:
            logger.error(
                f"{client.hostname} ({client.id()}) reader failed",
                exc_info=task.exception(),
            )

    reader_task.add_done_callback(on_reader_done)
async def main():
    """Run the IRC server until it is cancelled.

    Listens on $BIND_ADDR (default "0.0.0.0") and $PORT (default 6667) and
    hands each new connection to ``register_client``.

    NOTE(review): ``reuse_port=True`` lets other processes bind the same port.
    All client and channel state lives in this process, so a second instance
    could take a share of the connections without either noticing. The asyncio
    docs also state this option is not supported on Windows. Confirm it is
    intended.
    """
    environment = os.environ
    bind_addr = environment.get("BIND_ADDR") or "0.0.0.0"
    port = environment.get("PORT") or 6667
    server = await asyncio.start_server(
        register_client, host=bind_addr, port=port, reuse_port=True
    )
    logger.info(f"Listening on {bind_addr}:{port}...")
    async with server:
        await server.serve_forever()


if __name__ == "__main__":
    # Script entry point: asyncio.run creates and owns the event loop.
    asyncio.run(main())