|
@ -1,39 +1,15 @@ |
|
|
import asyncio |
|
|
import asyncio |
|
|
import dataclasses |
|
|
|
|
|
import logging |
|
|
import logging |
|
|
import os |
|
|
import os |
|
|
from typing import Dict, List, Set |
|
|
|
|
|
|
|
|
|
|
|
from paircd.message import IRCMessage |
|
|
|
|
|
|
|
|
|
|
|
clients_by_nick: Dict[str, "Client"] = {} |
|
|
|
|
|
all_channels: Dict[str, "Channel"] = {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
|
|
|
|
class Client: |
|
|
|
|
|
hostname: str |
|
|
|
|
|
reader: asyncio.StreamReader |
|
|
|
|
|
writer: asyncio.StreamWriter |
|
|
|
|
|
msg_queue: asyncio.Queue |
|
|
|
|
|
|
|
|
|
|
|
nickname: str = "" |
|
|
|
|
|
username: str = "" |
|
|
|
|
|
realname: str = "" |
|
|
|
|
|
registered: bool = False |
|
|
|
|
|
|
|
|
|
|
|
channels: Set[str] = dataclasses.field(default_factory=set) |
|
|
|
|
|
|
|
|
|
|
|
def id(self) -> str: |
|
|
|
|
|
return f"{self.nickname}!{self.username}@{self.hostname}" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
|
|
|
|
class Channel: |
|
|
|
|
|
name: str |
|
|
|
|
|
clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict) |
|
|
|
|
|
msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from paircd.client import Client |
|
|
|
|
|
from paircd.handle import handle_cmd, register_cmd_handler |
|
|
|
|
|
from paircd.handler.join import handle_join |
|
|
|
|
|
from paircd.handler.nick import handle_nick |
|
|
|
|
|
from paircd.handler.privmsg import handle_privmsg |
|
|
|
|
|
from paircd.handler.user import handle_user |
|
|
|
|
|
from paircd.message import parse_message |
|
|
|
|
|
from paircd.server import Server |
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
|
logging.basicConfig( |
|
|
format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO |
|
|
format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO |
|
@ -42,116 +18,33 @@ logging.basicConfig( |
|
|
logger = logging.getLogger() |
|
|
logger = logging.getLogger() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def handle_reader(client: Client) -> None: |
|
|
|
|
|
|
|
|
async def read_forever(server: Server, client: Client) -> None: |
|
|
while True: |
|
|
while True: |
|
|
raw_msg = await client.reader.readuntil(b"\r\n") |
|
|
raw_msg = await client.reader.readuntil(b"\r\n") |
|
|
msg = IRCMessage.parse(raw_msg) |
|
|
|
|
|
await handle_irc_msg(client, msg) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def handle_irc_msg(client: Client, msg: IRCMessage) -> None: |
|
|
|
|
|
if msg.cmd == "CAP": |
|
|
|
|
|
# https://ircv3.net/specs/extensions/capability-negotiation.html |
|
|
|
|
|
logging.warning("TODO: implement support for client capability negotiation") |
|
|
|
|
|
elif msg.cmd == "NICK": |
|
|
|
|
|
if client.nickname: |
|
|
|
|
|
del clients_by_nick[client.nickname] |
|
|
|
|
|
client.nickname = msg.args[0] |
|
|
|
|
|
clients_by_nick[client.nickname] = client |
|
|
|
|
|
if client.username and client.realname: |
|
|
|
|
|
client.registered = True |
|
|
|
|
|
logging.info(f"{client.hostname} ({client.id()}) registered") |
|
|
|
|
|
elif msg.cmd == "USER": |
|
|
|
|
|
if client.registered: |
|
|
|
|
|
raise RuntimeError("USER command issued after registration") |
|
|
|
|
|
client.username = msg.args[0] |
|
|
|
|
|
client.realname = msg.args[3] |
|
|
|
|
|
if client.nickname: |
|
|
|
|
|
client.registered = True |
|
|
|
|
|
logging.info(f"{client.hostname} ({client.id()}) registered") |
|
|
|
|
|
elif msg.cmd == "JOIN": |
|
|
|
|
|
channel_name = msg.args[0] |
|
|
|
|
|
assert_registered(client, msg) |
|
|
|
|
|
if not channel_name.startswith("#"): |
|
|
|
|
|
raise RuntimeError("invalid channel name") |
|
|
|
|
|
if channel_name not in all_channels: |
|
|
|
|
|
all_channels[channel_name] = Channel(name=channel_name) |
|
|
|
|
|
asyncio.create_task(process_channel(all_channels[channel_name])) |
|
|
|
|
|
client.channels.add(channel_name) |
|
|
|
|
|
channel = all_channels[channel_name] |
|
|
|
|
|
channel.clients_by_host[client.hostname] = client |
|
|
|
|
|
logging.info(f"{client.hostname} ({client.id()}) joined {channel_name}") |
|
|
|
|
|
await channel.msg_queue.put( |
|
|
|
|
|
IRCMessage(cmd="JOIN", args=[channel_name], prefix=client.nickname).encode() |
|
|
|
|
|
) |
|
|
|
|
|
elif msg.cmd == "PRIVMSG": |
|
|
|
|
|
assert_registered(client, msg) |
|
|
|
|
|
await privmsg(client, msg.args[0], msg.args[1]) |
|
|
|
|
|
else: |
|
|
|
|
|
logging.warning(f"unsupported message {msg.cmd}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def privmsg(client: Client, recipient: str, raw_msg: str) -> None: |
|
|
|
|
|
msg = IRCMessage( |
|
|
|
|
|
"PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.nickname |
|
|
|
|
|
).encode() |
|
|
|
|
|
|
|
|
|
|
|
for name, other_client in clients_by_nick.items(): |
|
|
|
|
|
if name == recipient: |
|
|
|
|
|
other_client.msg_queue.put_nowait(msg) |
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
for name, channel in all_channels.items(): |
|
|
|
|
|
if name == recipient: |
|
|
|
|
|
channel.msg_queue.put_nowait(msg) |
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
raise RuntimeError(f"unknown recipient {recipient}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def process_channel(channel: Channel) -> None: |
|
|
|
|
|
while True: |
|
|
|
|
|
msg = await channel.msg_queue.get() |
|
|
|
|
|
for client in channel.clients_by_host.values(): |
|
|
|
|
|
if msg.startswith(f":{client.nickname} PRIVMSG".encode("utf-8")): |
|
|
|
|
|
continue |
|
|
|
|
|
client.msg_queue.put_nowait(msg) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def assert_registered(client: Client, msg: IRCMessage) -> None: |
|
|
|
|
|
if client.registered: |
|
|
|
|
|
return |
|
|
|
|
|
raise RuntimeError(f"{msg.cmd} issued before client fully registered") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def assert_argc(xs: List[str], i: int) -> None: |
|
|
|
|
|
if len(xs) == i: |
|
|
|
|
|
return |
|
|
|
|
|
raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def handle_writer(client: Client) -> None: |
|
|
|
|
|
while True: |
|
|
|
|
|
msg = await client.msg_queue.get() |
|
|
|
|
|
client.writer.write(msg) |
|
|
|
|
|
await client.writer.drain() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def register_client(reader, writer): |
|
|
|
|
|
client = Client( |
|
|
|
|
|
hostname=writer.get_extra_info("peername")[0], |
|
|
|
|
|
reader=reader, |
|
|
|
|
|
writer=writer, |
|
|
|
|
|
msg_queue=asyncio.Queue(), |
|
|
|
|
|
) |
|
|
|
|
|
asyncio.create_task(handle_reader(client)) |
|
|
|
|
|
asyncio.create_task(handle_writer(client)) |
|
|
|
|
|
|
|
|
msg = parse_message(raw_msg) |
|
|
|
|
|
await handle_cmd(server, client, msg) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def main(): |
|
|
async def main(): |
|
|
bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0" |
|
|
bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0" |
|
|
port = os.getenv("PORT") or 6667 |
|
|
port = os.getenv("PORT") or 6667 |
|
|
|
|
|
|
|
|
|
|
|
register_cmd_handler("JOIN", 1, handle_join) |
|
|
|
|
|
register_cmd_handler("NICK", 1, handle_nick) |
|
|
|
|
|
register_cmd_handler("PRIVMSG", 2, handle_privmsg) |
|
|
|
|
|
register_cmd_handler("USER", 4, handle_user) |
|
|
|
|
|
|
|
|
|
|
|
irc_server = Server() |
|
|
|
|
|
|
|
|
|
|
|
async def register_client(reader, writer): |
|
|
|
|
|
client = Client( |
|
|
|
|
|
hostname=writer.get_extra_info("peername")[0], |
|
|
|
|
|
reader=reader, |
|
|
|
|
|
writer=writer, |
|
|
|
|
|
) |
|
|
|
|
|
asyncio.create_task(read_forever(irc_server, client)) |
|
|
|
|
|
asyncio.create_task(client.write_forever()) |
|
|
|
|
|
|
|
|
server = await asyncio.start_server( |
|
|
server = await asyncio.start_server( |
|
|
register_client, |
|
|
register_client, |
|
|
bind_addr, |
|
|
bind_addr, |
|
|