import asyncio
|
|
import dataclasses
|
|
import logging
|
|
import os
|
|
from typing import Dict, List, Set
|
|
|
|
# Every client that has chosen a nickname, keyed by that nickname.
clients_by_nick: Dict[str, "Client"] = dict()

# Every channel created so far, keyed by channel name (e.g. "#general").
all_channels: Dict[str, "Channel"] = dict()
|
|
|
|
|
|
@dataclasses.dataclass
class Client:
    """Per-connection state for one IRC client.

    The connection fields are required at construction time; the identity
    fields start out empty and are filled in as the client sends NICK and
    USER commands.
    """

    # --- connection -------------------------------------------------------
    # Address of the remote peer.
    hostname: str
    # Stream the client's commands are read from.
    reader: asyncio.StreamReader
    # Stream replies are written to.
    writer: asyncio.StreamWriter
    # Encoded messages waiting to be written to ``writer``.
    msg_queue: asyncio.Queue

    # --- identity ---------------------------------------------------------
    nickname: str = ""
    username: str = ""
    realname: str = ""
    # True once the client has completed registration.
    registered: bool = False

    # Names of the channels this client has joined.
    channels: Set[str] = dataclasses.field(default_factory=set)

    def id(self) -> str:
        """Return the client's ``nick!user@host`` identifier."""
        return "{}!{}@{}".format(self.nickname, self.username, self.hostname)
|
|
|
|
|
|
@dataclasses.dataclass
class Channel:
    """A named IRC channel: its members and its queue of pending broadcasts."""

    # Channel name.
    name: str

    # Members of the channel.
    # NOTE(review): the JOIN handler keys this by ``client.hostname``, so two
    # clients connecting from the same host would overwrite each other here.
    clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict)

    # Encoded messages waiting to be fanned out to the members.
    msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)
|
|
|
|
|
|
# Configure the root logger at import time: INFO and above, timestamped.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] - %(message)s",
)

# Root logger; the module-level ``logging.info(...)`` helpers write to it too.
logger = logging.getLogger()
|
|
|
|
|
|
def _drop_client(client: Client) -> None:
    """Remove *client* from the nick registry and from every channel it joined."""
    # Only delete the registry entry if it really points at this client.
    if clients_by_nick.get(client.nickname) is client:
        del clients_by_nick[client.nickname]

    for channel_name in client.channels:
        channel = all_channels.get(channel_name)
        if channel is None:
            continue
        # Match members by identity rather than by key, so the cleanup does
        # not depend on how the membership dict happens to be keyed.
        stale_keys = [
            key
            for key, member in channel.clients_by_host.items()
            if member is client
        ]
        for key in stale_keys:
            del channel.clients_by_host[key]
    client.channels.clear()


async def handle_reader(client: Client) -> None:
    """Read CRLF-terminated IRC messages from *client* and dispatch them.

    Runs until the connection ends or a message handler raises. A peer
    disconnect is logged and ends the loop quietly; any other exception
    (for example a ``RuntimeError`` raised for a protocol violation)
    propagates to the task. In every case the client is removed from the
    server's registries and its writer is closed, so a dead connection never
    stays addressable.
    """
    try:
        while True:
            # readuntil() either returns data that ends with the separator or
            # raises, so no extra "does it end with CRLF" check is needed.
            raw_msg = await client.reader.readuntil(b"\r\n")
            msg = parse_irc_msg(raw_msg)
            await handle_irc_msg(client, msg)
    except (asyncio.IncompleteReadError, ConnectionError):
        # EOF before a full line arrived, or the socket was reset.
        logging.info(f"{client.hostname} disconnected")
    finally:
        _drop_client(client)
        client.writer.close()
|
|
|
|
|
|
# Strong references to the per-channel dispatcher tasks. The event loop only
# keeps weak references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_channel_tasks: Set[asyncio.Task] = set()


async def handle_irc_msg(client: Client, msg: List[str]) -> None:
    """Apply one parsed IRC message sent by *client*.

    Supported commands are CAP (logged as unimplemented), NICK, USER, JOIN
    and PRIVMSG; anything else is logged as unsupported.

    Args:
        client: The connection the message arrived on.
        msg: The tokenised message; ``msg[0]`` is the command name.

    Raises:
        RuntimeError: On a wrong argument count, a nickname that is already
            taken by another client, USER after registration, JOIN/PRIVMSG
            before registration, an invalid channel name, or an unknown
            PRIVMSG recipient.
    """
    command = msg[0]

    if command == "CAP":
        # https://ircv3.net/specs/extensions/capability-negotiation.html
        logging.warning("TODO: implement support for client capability negotiation")

    elif command == "NICK":
        assert_argc(msg, 2)
        new_nick = msg[1]

        # Refuse to take over a nickname owned by a different connection;
        # otherwise the registry entry would be silently hijacked and a later
        # nick change by either side would delete the wrong entry.
        owner = clients_by_nick.get(new_nick)
        if owner is not None and owner is not client:
            raise RuntimeError(f"nickname {new_nick} is already in use")

        if client.nickname:
            clients_by_nick.pop(client.nickname, None)
        client.nickname = new_nick
        clients_by_nick[new_nick] = client

        # Only announce the transition into the registered state, not every
        # later nick change.
        if not client.registered and client.username and client.realname:
            client.registered = True
            logging.info(f"{client.hostname} ({client.id()}) registered")

    elif command == "USER":
        assert_argc(msg, 5)
        if client.registered:
            raise RuntimeError("USER command issued after registration")
        client.username = msg[1]
        client.realname = msg[4]
        if client.nickname:
            client.registered = True
            logging.info(f"{client.hostname} ({client.id()}) registered")

    elif command == "JOIN":
        assert_argc(msg, 2)
        assert_registered(client, msg)
        channel_name = msg[1]
        if not channel_name.startswith("#"):
            raise RuntimeError("invalid channel name")

        channel = all_channels.get(channel_name)
        if channel is None:
            # First join creates the channel and starts its dispatcher.
            channel = Channel(name=channel_name)
            all_channels[channel_name] = channel
            task = asyncio.create_task(process_channel(channel))
            _channel_tasks.add(task)
            task.add_done_callback(_channel_tasks.discard)

        client.channels.add(channel.name)
        # NOTE(review): keyed by hostname, so a second client from the same
        # host replaces the first in this channel's member map.
        channel.clients_by_host[client.hostname] = client
        logging.info(f"{client.hostname} ({client.id()}) joined {msg[1]}")
        await channel.msg_queue.put(
            f":{client.nickname} JOIN {msg[1]}\r\n".encode("utf-8")
        )

    elif command == "PRIVMSG":
        assert_argc(msg, 3)
        assert_registered(client, msg)
        await privmsg(client, msg[1], msg[2])

    else:
        logging.warning(f"unsupported message {msg[0]}")
|
|
|
|
|
|
async def privmsg(client: Client, recipient: str, raw_msg: str) -> None:
    """Queue a PRIVMSG from *client* for delivery to *recipient*.

    *recipient* is looked up first as a nickname and then as a channel name;
    the encoded message is put on the matching target's ``msg_queue``.

    Args:
        client: The sender; its nickname becomes the message prefix.
        recipient: A nickname or channel name.
        raw_msg: The message text.

    Raises:
        RuntimeError: If *recipient* is neither a known nickname nor a known
            channel.
    """
    msg = f":{client.nickname} PRIVMSG {recipient} :{raw_msg}\r\n".encode("utf-8")

    # Direct dict lookups instead of scanning every entry for an equal key.
    # Nicknames take precedence over channels, as before.
    target = clients_by_nick.get(recipient)
    if target is None:
        target = all_channels.get(recipient)
    if target is None:
        raise RuntimeError(f"unknown recipient {recipient}")

    target.msg_queue.put_nowait(msg)
|
|
|
|
|
|
async def process_channel(channel: Channel) -> None:
    """Forever fan out messages queued on *channel* to its members.

    Each message taken from ``channel.msg_queue`` is put on every member's
    own queue, except that a PRIVMSG is not echoed back to the member whose
    nickname appears as its prefix.
    """
    while True:
        outgoing = await channel.msg_queue.get()
        for member in channel.clients_by_host.values():
            own_privmsg_prefix = f":{member.nickname} PRIVMSG".encode("utf-8")
            if not outgoing.startswith(own_privmsg_prefix):
                member.msg_queue.put_nowait(outgoing)
|
|
|
|
|
|
def assert_registered(client: Client, msg: List[str]) -> None:
    """Raise ``RuntimeError`` unless *client* has completed registration.

    ``msg[0]`` (the command name) is only used to build the error message.
    """
    if not client.registered:
        raise RuntimeError(f"{msg[0]} command issued before client fully registered")
|
|
|
|
|
|
def assert_argc(xs: List[str], i: int) -> None:
    """Raise ``RuntimeError`` unless the message *xs* has exactly *i* tokens.

    The count includes the command name itself, which is ``xs[0]``.
    """
    actual = len(xs)
    if actual != i:
        raise RuntimeError(f"{xs[0]} had {actual} arguments (expected {i})")
|
|
|
|
|
|
def parse_irc_msg(raw_msg: bytes) -> List[str]:
    """Split one raw IRC line into its command and parameters.

    The line is decoded as UTF-8 and split on single spaces. The first token
    that starts with ``:`` opens the trailing parameter: the colon is dropped
    and the rest of the line, spaces included, becomes one final token.
    Surrounding whitespace (including the CRLF terminator) is stripped from
    the last token.

    Args:
        raw_msg: The undecoded line as read from the socket.

    Returns:
        The tokens, command first.

    Raises:
        RuntimeError: If no tokens were produced. ``str.split(" ")`` always
            yields at least one element, so this is a defensive guard.
    """
    raw_tokens: List[str] = raw_msg.decode("utf-8").split(" ")
    tokens: List[str] = []

    for i, token in enumerate(raw_tokens):
        if token.startswith(":"):
            # Re-join with the space that split() removed, so the first word
            # of the trailing parameter is not glued to the second.
            tokens.append(" ".join([token[1:], *raw_tokens[i + 1 :]]))
            break
        tokens.append(token)

    if not tokens:
        raise RuntimeError("empty message")

    tokens[-1] = tokens[-1].strip()
    return tokens
|
|
|
|
|
|
async def handle_writer(client: Client) -> None:
    """Forever drain *client*'s outgoing queue onto its socket.

    Each queued message is written to the client's stream writer, and the
    writer is drained before the next message is taken.
    """
    while True:
        outgoing = await client.msg_queue.get()
        stream = client.writer
        stream.write(outgoing)
        await stream.drain()
|
|
|
|
|
|
# Strong references to the per-connection reader/writer tasks. The event loop
# only keeps weak references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_client_tasks: Set[asyncio.Task] = set()


async def register_client(
    reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
    """Connection callback: create a ``Client`` and start its I/O tasks.

    Builds the client state for a freshly accepted connection (the hostname is
    the first element of the socket's ``peername``) and starts one task that
    reads its commands and one that writes its queued replies. Returns as soon
    as both tasks are scheduled.

    Args:
        reader: Stream for the new connection's incoming data.
        writer: Stream for the new connection's outgoing data.
    """
    client = Client(
        hostname=writer.get_extra_info("peername")[0],
        reader=reader,
        writer=writer,
        msg_queue=asyncio.Queue(),
    )

    # Reader first, then writer: same scheduling order as before.
    for worker in (handle_reader, handle_writer):
        task = asyncio.create_task(worker(client))
        _client_tasks.add(task)
        # Drop the reference once the task is done so the set cannot grow
        # without bound.
        task.add_done_callback(_client_tasks.discard)
|
|
|
|
|
|
async def main():
    """Start the IRC server and serve connections until cancelled.

    The listen address comes from the ``BIND_ADDR`` environment variable
    (default ``0.0.0.0``) and the port from ``PORT`` (default ``6667``).
    """
    environ = os.environ
    bind_addr = environ.get("BIND_ADDR") or "0.0.0.0"
    # NOTE(review): when PORT is set this stays a str; only the default is an int.
    port = environ.get("PORT") or 6667

    server = await asyncio.start_server(
        register_client, host=bind_addr, port=port, reuse_port=True
    )

    logger.info(f"Listening on {bind_addr}:{port}...")
    async with server:
        await server.serve_forever()
|
|
|
|
|
|
# Script entry point: start the server only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())
|