|
|
@ -0,0 +1,188 @@ |
|
|
|
import asyncio |
|
|
|
import dataclasses |
|
|
|
import logging |
|
|
|
import os |
|
|
|
from typing import Dict, List, Set |
|
|
|
|
|
|
|
clients_by_nick: Dict[str, "Client"] = {} |
|
|
|
all_channels: Dict[str, "Channel"] = {} |
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
|
|
class Client: |
|
|
|
hostname: str |
|
|
|
reader: asyncio.StreamReader |
|
|
|
writer: asyncio.StreamWriter |
|
|
|
msg_queue: asyncio.Queue |
|
|
|
|
|
|
|
nickname: str = "" |
|
|
|
username: str = "" |
|
|
|
realname: str = "" |
|
|
|
registered: bool = False |
|
|
|
|
|
|
|
channels: Set[str] = dataclasses.field(default_factory=set) |
|
|
|
|
|
|
|
def id(self) -> str: |
|
|
|
return f"{self.nickname}!{self.username}@{self.hostname}" |
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
|
|
class Channel: |
|
|
|
name: str |
|
|
|
clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict) |
|
|
|
msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue) |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
|
|
format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO |
|
|
|
) |
|
|
|
|
|
|
|
logger = logging.getLogger() |
|
|
|
|
|
|
|
|
|
|
|
async def handle_reader(client: Client) -> None: |
|
|
|
while True: |
|
|
|
raw_msg = await client.reader.readuntil(b"\r\n") |
|
|
|
if not raw_msg.endswith(b"\r\n"): |
|
|
|
raise RuntimeError("malformed message") |
|
|
|
msg = parse_irc_msg(raw_msg) |
|
|
|
await handle_irc_msg(client, msg) |
|
|
|
|
|
|
|
|
|
|
|
async def handle_irc_msg(client: Client, msg: List[str]) -> None: |
|
|
|
if msg[0] == "CAP": |
|
|
|
# https://ircv3.net/specs/extensions/capability-negotiation.html |
|
|
|
logging.warning("TODO: implement support for client capability negotiation") |
|
|
|
elif msg[0] == "NICK": |
|
|
|
assert_argc(msg, 2) |
|
|
|
if client.nickname: |
|
|
|
del clients_by_nick[client.nickname] |
|
|
|
client.nickname = msg[1] |
|
|
|
clients_by_nick[client.nickname] = client |
|
|
|
if client.username and client.realname: |
|
|
|
client.registered = True |
|
|
|
logging.info(f"{client.hostname} ({client.id()}) registered") |
|
|
|
elif msg[0] == "USER": |
|
|
|
assert_argc(msg, 5) |
|
|
|
if client.registered: |
|
|
|
raise RuntimeError("USER command issued after registration") |
|
|
|
client.username = msg[1] |
|
|
|
client.realname = msg[4] |
|
|
|
if client.nickname: |
|
|
|
client.registered = True |
|
|
|
logging.info(f"{client.hostname} ({client.id()}) registered") |
|
|
|
elif msg[0] == "JOIN": |
|
|
|
assert_argc(msg, 2) |
|
|
|
assert_registered(client, msg) |
|
|
|
if not msg[1].startswith("#"): |
|
|
|
raise RuntimeError("invalid channel name") |
|
|
|
if msg[1] not in all_channels: |
|
|
|
all_channels[msg[1]] = Channel(name=msg[1]) |
|
|
|
asyncio.create_task(process_channel(all_channels[msg[1]])) |
|
|
|
channel = all_channels[msg[1]] |
|
|
|
client.channels.add(channel.name) |
|
|
|
channel.clients_by_host[client.hostname] = client |
|
|
|
logging.info(f"{client.hostname} ({client.id()}) joined {msg[1]}") |
|
|
|
await channel.msg_queue.put( |
|
|
|
f":{client.nickname} JOIN {msg[1]}\r\n".encode("utf-8") |
|
|
|
) |
|
|
|
elif msg[0] == "PRIVMSG": |
|
|
|
assert_argc(msg, 3) |
|
|
|
assert_registered(client, msg) |
|
|
|
await privmsg(client, msg[1], msg[2]) |
|
|
|
else: |
|
|
|
logging.warning(f"unsupported message {msg[0]}") |
|
|
|
|
|
|
|
|
|
|
|
async def privmsg(client: Client, recipient: str, raw_msg: str) -> None: |
|
|
|
msg = f":{client.nickname} PRIVMSG {recipient} :{raw_msg}\r\n".encode("utf-8") |
|
|
|
|
|
|
|
for name, other_client in clients_by_nick.items(): |
|
|
|
if name == recipient: |
|
|
|
other_client.msg_queue.put_nowait(msg) |
|
|
|
return |
|
|
|
|
|
|
|
for name, channel in all_channels.items(): |
|
|
|
if name == recipient: |
|
|
|
channel.msg_queue.put_nowait(msg) |
|
|
|
return |
|
|
|
|
|
|
|
raise RuntimeError(f"unknown recipient {recipient}") |
|
|
|
|
|
|
|
|
|
|
|
async def process_channel(channel: Channel) -> None: |
|
|
|
while True: |
|
|
|
msg = await channel.msg_queue.get() |
|
|
|
for client in channel.clients_by_host.values(): |
|
|
|
if msg.startswith(f":{client.nickname} PRIVMSG".encode("utf-8")): |
|
|
|
continue |
|
|
|
client.msg_queue.put_nowait(msg) |
|
|
|
|
|
|
|
|
|
|
|
def assert_registered(client: Client, msg: List[str]) -> None: |
|
|
|
if client.registered: |
|
|
|
return |
|
|
|
raise RuntimeError(f"{msg[0]} command issued before client fully registered") |
|
|
|
|
|
|
|
|
|
|
|
def assert_argc(xs: List[str], i: int) -> None: |
|
|
|
if len(xs) == i: |
|
|
|
return |
|
|
|
raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})") |
|
|
|
|
|
|
|
|
|
|
|
def parse_irc_msg(raw_msg: str) -> List[str]: |
|
|
|
tokens: List[str] = [] |
|
|
|
raw_tokens: List[str] = raw_msg.decode("utf-8").split(" ") |
|
|
|
|
|
|
|
for i, token in enumerate(raw_tokens): |
|
|
|
if token.startswith(":"): |
|
|
|
trailing = token[1:] + " ".join(raw_tokens[i + 1 :]) |
|
|
|
tokens.append(trailing) |
|
|
|
break |
|
|
|
tokens.append(token) |
|
|
|
|
|
|
|
if len(tokens) > 0: |
|
|
|
tokens[-1] = tokens[-1].strip() |
|
|
|
else: |
|
|
|
raise RuntimeError("empty message") |
|
|
|
|
|
|
|
return tokens |
|
|
|
|
|
|
|
|
|
|
|
async def handle_writer(client: Client) -> None: |
|
|
|
while True: |
|
|
|
msg = await client.msg_queue.get() |
|
|
|
client.writer.write(msg) |
|
|
|
await client.writer.drain() |
|
|
|
|
|
|
|
|
|
|
|
async def register_client(reader, writer): |
|
|
|
client = Client( |
|
|
|
hostname=writer.get_extra_info("peername")[0], |
|
|
|
reader=reader, |
|
|
|
writer=writer, |
|
|
|
msg_queue=asyncio.Queue(), |
|
|
|
) |
|
|
|
asyncio.create_task(handle_reader(client)) |
|
|
|
asyncio.create_task(handle_writer(client)) |
|
|
|
|
|
|
|
|
|
|
|
async def main(): |
|
|
|
bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0" |
|
|
|
port = os.getenv("PORT") or 6667 |
|
|
|
|
|
|
|
server = await asyncio.start_server( |
|
|
|
register_client, |
|
|
|
bind_addr, |
|
|
|
port, |
|
|
|
reuse_port=True, |
|
|
|
) |
|
|
|
|
|
|
|
logger.info(f"Listening on {bind_addr}:{port}...") |
|
|
|
async with server: |
|
|
|
await server.serve_forever() |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
asyncio.run(main()) |