import asyncio
|
|
import dataclasses
|
|
import logging
|
|
import os
|
|
from typing import Dict, List, Set
|
|
|
|
from paircd.message import IRCMessage
|
|
|
|
# Connected clients that have sent NICK, keyed by their current nickname.
clients_by_nick: Dict[str, "Client"] = dict()

# Every channel that has been JOINed at least once, keyed by channel name.
all_channels: Dict[str, "Channel"] = dict()
|
|
|
|
|
|
@dataclasses.dataclass()
class Client:
    """Per-connection state for one IRC client.

    The first four fields are supplied when the connection is accepted; the
    identity fields are filled in later as NICK/USER messages arrive.
    """

    # Peer address the client connected from.
    hostname: str
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    # Encoded outgoing messages waiting to be written to this client.
    msg_queue: asyncio.Queue

    # Identity, populated from NICK and USER.
    nickname: str = ""
    username: str = ""
    realname: str = ""
    # Set once the client has supplied both NICK and USER.
    registered: bool = False

    # Names of the channels this client has joined.
    channels: Set[str] = dataclasses.field(default_factory=set)

    def id(self) -> str:
        """Return the client's ``nick!user@host`` identifier."""
        return "{0}!{1}@{2}".format(self.nickname, self.username, self.hostname)
|
|
|
|
|
|
@dataclasses.dataclass()
class Channel:
    """A named chat room and the clients currently joined to it.

    Messages addressed to the channel are put on ``msg_queue``; a
    ``process_channel`` task started when the channel is created relays them
    to the members.
    """

    # Channel name as given in JOIN (always starts with "#").
    name: str
    # Joined members, keyed by the member's hostname.
    clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict)
    # Encoded messages waiting to be relayed to the members.
    msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)
|
|
|
|
|
|
_LOG_FORMAT = "%(asctime)s [%(levelname)s] - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# The root logger (no name passed to getLogger).
logger = logging.getLogger()
|
|
|
|
|
|
def _drop_client(client: Client) -> None:
    """Forget a disconnected client and close its connection.

    Removes ``client`` from ``clients_by_nick`` (only if that entry still
    points at it) and from every channel it joined, then closes its writer.
    """
    if clients_by_nick.get(client.nickname) is client:
        del clients_by_nick[client.nickname]
    for channel_name in client.channels:
        channel = all_channels.get(channel_name)
        if channel is None:
            continue
        # Match members by identity: the dict key (hostname) can be shared by
        # several connections from the same address.
        stale_keys = [
            key for key, member in channel.clients_by_host.items() if member is client
        ]
        for key in stale_keys:
            del channel.clients_by_host[key]
    client.channels.clear()
    client.writer.close()


async def handle_reader(client: Client) -> None:
    """Read CRLF-terminated IRC lines from ``client`` and dispatch each one.

    Runs until the peer closes the connection or reading fails. On exit, for
    any reason, the client is removed from server state and its socket is
    closed. A message rejected by the handlers is logged and skipped, and the
    connection keeps being served.
    """
    try:
        while True:
            try:
                raw_msg = await client.reader.readuntil(b"\r\n")
            except asyncio.IncompleteReadError:
                # EOF before a complete line: the peer went away.
                logging.info(f"{client.hostname} ({client.id()}) disconnected")
                return
            except (asyncio.LimitOverrunError, ConnectionError) as err:
                # Over-long line or broken socket: the stream can't be resynced.
                logging.warning(f"{client.hostname}: dropping connection: {err!r}")
                return
            try:
                msg = IRCMessage.parse(raw_msg)
                await handle_irc_msg(client, msg)
            except (RuntimeError, IndexError, ValueError) as err:
                # Protocol errors (bad args, wrong state, unknown recipient)
                # only affect this message. NOTE(review): parse()'s failure
                # modes aren't visible here; ValueError/IndexError are assumed.
                logging.warning(f"{client.hostname}: rejected message: {err}")
    finally:
        _drop_client(client)
|
|
|
|
|
|
def _require_args(msg: IRCMessage, count: int) -> None:
    """Raise RuntimeError unless ``msg`` carries at least ``count`` arguments."""
    if len(msg.args) < count:
        raise RuntimeError(
            f"{msg.cmd} had {len(msg.args)} arguments (expected at least {count})"
        )


def _complete_registration(client: Client) -> None:
    """Mark ``client`` registered once it has both a nickname and a username.

    Does nothing if the client is already registered, so a later NICK change
    does not re-run (or re-log) registration.
    """
    if client.registered or not (client.nickname and client.username):
        return
    client.registered = True
    logging.info(f"{client.hostname} ({client.id()}) registered")


async def handle_irc_msg(client: Client, msg: IRCMessage) -> None:
    """Apply one parsed IRC message from ``client`` to the server state.

    Handles NICK, USER, JOIN and PRIVMSG; CAP and any other command is only
    logged. Registration completes as soon as both NICK and USER have been
    received, in either order.

    Raises:
        RuntimeError: the message has too few arguments; it is not allowed in
            the client's current state (USER after registration, JOIN/PRIVMSG
            before it); the nickname is empty, starts with "#", or is held by
            another client; the channel name does not start with "#"; or
            ``privmsg`` rejects the recipient.
    """
    if msg.cmd == "CAP":
        # https://ircv3.net/specs/extensions/capability-negotiation.html
        logging.warning("TODO: implement support for client capability negotiation")

    elif msg.cmd == "NICK":
        _require_args(msg, 1)
        new_nick = msg.args[0]
        # "#" names are channels; privmsg() looks nicknames up first, so a "#"
        # nick would shadow a channel of the same name.
        if not new_nick or new_nick.startswith("#"):
            raise RuntimeError(f"invalid nickname {new_nick!r}")
        holder = clients_by_nick.get(new_nick)
        if holder is not None and holder is not client:
            # Otherwise this client would silently take over the other
            # client's entry (and its private messages).
            raise RuntimeError(f"nickname {new_nick} already in use")
        # Only drop the old entry if it still belongs to this client.
        if clients_by_nick.get(client.nickname) is client:
            del clients_by_nick[client.nickname]
        client.nickname = new_nick
        clients_by_nick[new_nick] = client
        _complete_registration(client)

    elif msg.cmd == "USER":
        if client.registered:
            raise RuntimeError("USER command issued after registration")
        _require_args(msg, 4)
        client.username = msg.args[0]
        # args[1] and args[2] are not used.
        client.realname = msg.args[3]
        _complete_registration(client)

    elif msg.cmd == "JOIN":
        assert_registered(client, msg)
        _require_args(msg, 1)
        channel_name = msg.args[0]
        if not channel_name.startswith("#"):
            raise RuntimeError("invalid channel name")
        channel = all_channels.get(channel_name)
        if channel is None:
            channel = all_channels[channel_name] = Channel(name=channel_name)
            # One relay task per channel, started when the channel is created.
            asyncio.create_task(process_channel(channel))
        client.channels.add(channel_name)
        # NOTE(review): members are keyed by hostname, so two connections from
        # the same address overwrite each other here; consider a per-connection key.
        channel.clients_by_host[client.hostname] = client
        logging.info(f"{client.hostname} ({client.id()}) joined {channel_name}")
        await channel.msg_queue.put(
            IRCMessage(cmd="JOIN", args=[channel_name], prefix=client.nickname).encode()
        )

    elif msg.cmd == "PRIVMSG":
        assert_registered(client, msg)
        _require_args(msg, 2)
        await privmsg(client, msg.args[0], msg.args[1])

    else:
        logging.warning(f"unsupported message {msg.cmd}")
|
|
|
|
|
|
async def privmsg(client: Client, recipient: str, raw_msg: str) -> None:
    """Deliver a PRIVMSG from ``client`` to the user or channel ``recipient``.

    Nicknames are looked up before channel names. Delivery only enqueues the
    encoded message on the target's ``msg_queue``.

    Raises:
        RuntimeError: if ``recipient`` is neither a known nickname nor a
            known channel.
    """
    msg = IRCMessage(
        "PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.nickname
    ).encode()

    # Direct dict lookups instead of scanning every entry.
    target = clients_by_nick.get(recipient)
    if target is not None:
        target.msg_queue.put_nowait(msg)
        return

    channel = all_channels.get(recipient)
    if channel is not None:
        channel.msg_queue.put_nowait(msg)
        return

    raise RuntimeError(f"unknown recipient {recipient}")
|
|
|
|
|
|
async def process_channel(channel: Channel) -> None:
    """Forever relay messages queued on ``channel`` to each member's queue.

    A PRIVMSG is not echoed back to the member whose nickname matches the
    message's prefix (the sender); every other message goes to every member.
    """
    while True:
        payload = await channel.msg_queue.get()
        for member in channel.clients_by_host.values():
            own_privmsg = f":{member.nickname} PRIVMSG".encode("utf-8")
            if not payload.startswith(own_privmsg):
                member.msg_queue.put_nowait(payload)
|
|
|
|
|
|
def assert_registered(client: Client, msg: IRCMessage) -> None:
    """Raise RuntimeError if ``client`` has not completed registration.

    ``msg`` is used only to name the offending command in the error text.
    """
    if not client.registered:
        raise RuntimeError(f"{msg.cmd} issued before client fully registered")
|
|
|
|
|
|
def assert_argc(xs: List[str], i: int) -> None:
    """Raise RuntimeError unless ``xs`` has exactly ``i`` elements.

    The error message names ``xs[0]`` (presumably the command name; confirm
    with callers). When ``xs`` is empty it uses ``<empty>``, so the intended
    RuntimeError is raised instead of an IndexError.
    """
    if len(xs) == i:
        return
    label = xs[0] if xs else "<empty>"
    raise RuntimeError(f"{label} had {len(xs)} arguments (expected {i})")
|
|
|
|
|
|
async def handle_writer(client: Client) -> None:
    """Forever write messages queued on ``client.msg_queue`` to its socket.

    If flushing fails because the connection is gone, the failure is logged,
    the writer is closed and the loop stops. Previously the task died with an
    unretrieved exception.
    """
    while True:
        msg = await client.msg_queue.get()
        client.writer.write(msg)
        try:
            await client.writer.drain()
        except ConnectionError as err:
            logging.warning(f"{client.hostname}: write failed, closing: {err!r}")
            client.writer.close()
            return
|
|
|
|
|
|
# Strong references to the per-connection tasks. The asyncio docs require
# keeping a reference to tasks from create_task() so they are not collected
# mid-execution.
_connection_tasks: Set[asyncio.Task] = set()


async def register_client(reader, writer):
    """Accept a new connection: build its Client and start its I/O tasks.

    Used as the ``asyncio.start_server`` connection callback. Spawns one task
    that reads and dispatches incoming messages and one that writes queued
    outgoing messages. When the reader task finishes, for any reason, the
    writer task is cancelled and the socket closed, so neither outlives the
    connection. An unexpected reader failure is logged rather than left as an
    unretrieved task exception.
    """
    client = Client(
        hostname=writer.get_extra_info("peername")[0],
        reader=reader,
        writer=writer,
        msg_queue=asyncio.Queue(),
    )
    reader_task = asyncio.create_task(handle_reader(client))
    writer_task = asyncio.create_task(handle_writer(client))
    for task in (reader_task, writer_task):
        _connection_tasks.add(task)
        task.add_done_callback(_connection_tasks.discard)

    def _on_reader_done(task: asyncio.Task) -> None:
        if not task.cancelled() and task.exception() is not None:
            logging.error(
                f"{client.hostname}: reader task failed", exc_info=task.exception()
            )
        # Nothing reads from this connection any more: stop writing and close.
        writer_task.cancel()
        writer.close()

    reader_task.add_done_callback(_on_reader_done)
|
|
|
|
|
|
async def main():
    """Start the IRC server and serve connections until cancelled.

    Reads ``BIND_ADDR`` (default ``0.0.0.0``) and ``PORT`` (default 6667) from
    the environment.

    Raises:
        ValueError: if ``PORT`` is set but is not an integer.
    """
    bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0"
    # Environment values are strings. Convert so that the port has one type
    # regardless of source and a malformed PORT fails immediately.
    port = int(os.getenv("PORT") or 6667)

    server = await asyncio.start_server(
        register_client,
        bind_addr,
        port,
        reuse_port=True,
    )

    logger.info(f"Listening on {bind_addr}:{port}...")
    async with server:
        await server.serve_forever()


if __name__ == "__main__":
    asyncio.run(main())
|