|
@ -4,6 +4,8 @@ import logging |
|
|
import os |
|
|
import os |
|
|
from typing import Dict, List, Set |
|
|
from typing import Dict, List, Set |
|
|
|
|
|
|
|
|
|
|
|
from paircd.message import IRCMessage |
|
|
|
|
|
|
|
|
clients_by_nick: Dict[str, "Client"] = {} |
|
|
clients_by_nick: Dict[str, "Client"] = {} |
|
|
all_channels: Dict[str, "Channel"] = {} |
|
|
all_channels: Dict[str, "Channel"] = {} |
|
|
|
|
|
|
|
@ -43,59 +45,56 @@ logger = logging.getLogger() |
|
|
async def handle_reader(client: Client) -> None: |
|
|
async def handle_reader(client: Client) -> None: |
|
|
while True: |
|
|
while True: |
|
|
raw_msg = await client.reader.readuntil(b"\r\n") |
|
|
raw_msg = await client.reader.readuntil(b"\r\n") |
|
|
if not raw_msg.endswith(b"\r\n"): |
|
|
|
|
|
raise RuntimeError("malformed message") |
|
|
|
|
|
msg = parse_irc_msg(raw_msg) |
|
|
|
|
|
|
|
|
msg = IRCMessage.parse(raw_msg) |
|
|
await handle_irc_msg(client, msg) |
|
|
await handle_irc_msg(client, msg) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def handle_irc_msg(client: Client, msg: List[str]) -> None: |
|
|
|
|
|
if msg[0] == "CAP": |
|
|
|
|
|
|
|
|
async def handle_irc_msg(client: Client, msg: IRCMessage) -> None: |
|
|
|
|
|
if msg.cmd == "CAP": |
|
|
# https://ircv3.net/specs/extensions/capability-negotiation.html |
|
|
# https://ircv3.net/specs/extensions/capability-negotiation.html |
|
|
logging.warning("TODO: implement support for client capability negotiation") |
|
|
logging.warning("TODO: implement support for client capability negotiation") |
|
|
elif msg[0] == "NICK": |
|
|
|
|
|
assert_argc(msg, 2) |
|
|
|
|
|
|
|
|
elif msg.cmd == "NICK": |
|
|
if client.nickname: |
|
|
if client.nickname: |
|
|
del clients_by_nick[client.nickname] |
|
|
del clients_by_nick[client.nickname] |
|
|
client.nickname = msg[1] |
|
|
|
|
|
|
|
|
client.nickname = msg.args[0] |
|
|
clients_by_nick[client.nickname] = client |
|
|
clients_by_nick[client.nickname] = client |
|
|
if client.username and client.realname: |
|
|
if client.username and client.realname: |
|
|
client.registered = True |
|
|
client.registered = True |
|
|
logging.info(f"{client.hostname} ({client.id()}) registered") |
|
|
logging.info(f"{client.hostname} ({client.id()}) registered") |
|
|
elif msg[0] == "USER": |
|
|
|
|
|
assert_argc(msg, 5) |
|
|
|
|
|
|
|
|
elif msg.cmd == "USER": |
|
|
if client.registered: |
|
|
if client.registered: |
|
|
raise RuntimeError("USER command issued after registration") |
|
|
raise RuntimeError("USER command issued after registration") |
|
|
client.username = msg[1] |
|
|
|
|
|
client.realname = msg[4] |
|
|
|
|
|
|
|
|
client.username = msg.args[0] |
|
|
|
|
|
client.realname = msg.args[3] |
|
|
if client.nickname: |
|
|
if client.nickname: |
|
|
client.registered = True |
|
|
client.registered = True |
|
|
logging.info(f"{client.hostname} ({client.id()}) registered") |
|
|
logging.info(f"{client.hostname} ({client.id()}) registered") |
|
|
elif msg[0] == "JOIN": |
|
|
|
|
|
assert_argc(msg, 2) |
|
|
|
|
|
|
|
|
elif msg.cmd == "JOIN": |
|
|
|
|
|
channel_name = msg.args[0] |
|
|
assert_registered(client, msg) |
|
|
assert_registered(client, msg) |
|
|
if not msg[1].startswith("#"): |
|
|
|
|
|
|
|
|
if not channel_name.startswith("#"): |
|
|
raise RuntimeError("invalid channel name") |
|
|
raise RuntimeError("invalid channel name") |
|
|
if msg[1] not in all_channels: |
|
|
|
|
|
all_channels[msg[1]] = Channel(name=msg[1]) |
|
|
|
|
|
asyncio.create_task(process_channel(all_channels[msg[1]])) |
|
|
|
|
|
channel = all_channels[msg[1]] |
|
|
|
|
|
client.channels.add(channel.name) |
|
|
|
|
|
|
|
|
if channel_name not in all_channels: |
|
|
|
|
|
all_channels[channel_name] = Channel(name=channel_name) |
|
|
|
|
|
asyncio.create_task(process_channel(all_channels[channel_name])) |
|
|
|
|
|
client.channels.add(channel_name) |
|
|
|
|
|
channel = all_channels[channel_name] |
|
|
channel.clients_by_host[client.hostname] = client |
|
|
channel.clients_by_host[client.hostname] = client |
|
|
logging.info(f"{client.hostname} ({client.id()}) joined {msg[1]}") |
|
|
|
|
|
|
|
|
logging.info(f"{client.hostname} ({client.id()}) joined {channel_name}") |
|
|
await channel.msg_queue.put( |
|
|
await channel.msg_queue.put( |
|
|
f":{client.nickname} JOIN {msg[1]}\r\n".encode("utf-8") |
|
|
|
|
|
|
|
|
IRCMessage(cmd="JOIN", args=[channel_name], prefix=client.nickname).encode() |
|
|
) |
|
|
) |
|
|
elif msg[0] == "PRIVMSG": |
|
|
|
|
|
assert_argc(msg, 3) |
|
|
|
|
|
|
|
|
elif msg.cmd == "PRIVMSG": |
|
|
assert_registered(client, msg) |
|
|
assert_registered(client, msg) |
|
|
await privmsg(client, msg[1], msg[2]) |
|
|
|
|
|
|
|
|
await privmsg(client, msg.args[0], msg.args[1]) |
|
|
else: |
|
|
else: |
|
|
logging.warning(f"unsupported message {msg[0]}") |
|
|
|
|
|
|
|
|
logging.warning(f"unsupported message {msg.cmd}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def privmsg(client: Client, recipient: str, raw_msg: str) -> None: |
|
|
async def privmsg(client: Client, recipient: str, raw_msg: str) -> None: |
|
|
msg = f":{client.nickname} PRIVMSG {recipient} :{raw_msg}\r\n".encode("utf-8") |
|
|
|
|
|
|
|
|
msg = IRCMessage( |
|
|
|
|
|
"PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.nickname |
|
|
|
|
|
).encode() |
|
|
|
|
|
|
|
|
for name, other_client in clients_by_nick.items(): |
|
|
for name, other_client in clients_by_nick.items(): |
|
|
if name == recipient: |
|
|
if name == recipient: |
|
@ -119,10 +118,10 @@ async def process_channel(channel: Channel) -> None: |
|
|
client.msg_queue.put_nowait(msg) |
|
|
client.msg_queue.put_nowait(msg) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def assert_registered(client: Client, msg: List[str]) -> None: |
|
|
|
|
|
|
|
|
def assert_registered(client: Client, msg: IRCMessage) -> None: |
|
|
if client.registered: |
|
|
if client.registered: |
|
|
return |
|
|
return |
|
|
raise RuntimeError(f"{msg[0]} command issued before client fully registered") |
|
|
|
|
|
|
|
|
raise RuntimeError(f"{msg.cmd} issued before client fully registered") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def assert_argc(xs: List[str], i: int) -> None: |
|
|
def assert_argc(xs: List[str], i: int) -> None: |
|
@ -131,25 +130,6 @@ def assert_argc(xs: List[str], i: int) -> None: |
|
|
raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})") |
|
|
raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_irc_msg(raw_msg: str) -> List[str]: |
|
|
|
|
|
tokens: List[str] = [] |
|
|
|
|
|
raw_tokens: List[str] = raw_msg.decode("utf-8").split(" ") |
|
|
|
|
|
|
|
|
|
|
|
for i, token in enumerate(raw_tokens): |
|
|
|
|
|
if token.startswith(":"): |
|
|
|
|
|
trailing = token[1:] + " ".join(raw_tokens[i + 1 :]) |
|
|
|
|
|
tokens.append(trailing) |
|
|
|
|
|
break |
|
|
|
|
|
tokens.append(token) |
|
|
|
|
|
|
|
|
|
|
|
if len(tokens) > 0: |
|
|
|
|
|
tokens[-1] = tokens[-1].strip() |
|
|
|
|
|
else: |
|
|
|
|
|
raise RuntimeError("empty message") |
|
|
|
|
|
|
|
|
|
|
|
return tokens |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def handle_writer(client: Client) -> None: |
|
|
async def handle_writer(client: Client) -> None: |
|
|
while True: |
|
|
while True: |
|
|
msg = await client.msg_queue.get() |
|
|
msg = await client.msg_queue.get() |
|
|