Browse Source

Factorize remainder of code

master
Forest Belton 3 years ago
parent
commit
c4878ff5e7
11 changed files with 238 additions and 167 deletions
  1. +24
    -0
      paircd/channel.py
  2. +27
    -0
      paircd/client.py
  3. +21
    -0
      paircd/handle.py
  4. +0
    -0
      paircd/handler/__init__.py
  5. +26
    -0
      paircd/handler/join.py
  6. +20
    -0
      paircd/handler/nick.py
  7. +22
    -0
      paircd/handler/privmsg.py
  8. +17
    -0
      paircd/handler/user.py
  9. +27
    -134
      paircd/main.py
  10. +33
    -33
      paircd/message.py
  11. +21
    -0
      paircd/server.py

+ 24
- 0
paircd/channel.py View File

@ -0,0 +1,24 @@
from asyncio import Queue
from dataclasses import dataclass, field
from typing import Dict
from paircd.client import Client
@dataclass
class Channel:
name: str
clients_by_nick: Dict[str, Client] = field(default_factory=dict)
msg_queue: Queue = field(default_factory=Queue)
def add_client(self, client: Client) -> None:
self.clients_by_nick[client.nickname] = client
async def process(self) -> None:
while True:
msg = await self.msg_queue.get()
for client in self.clients_by_nick.values():
# Don't broadcast client's messages back to themselves
if msg.startswith(f":{client.id()} PRIVMSG".encode("utf-8")):
continue
client.msg_queue.put_nowait(msg)

+ 27
- 0
paircd/client.py View File

@ -0,0 +1,27 @@
from asyncio import StreamReader, StreamWriter, Queue
from dataclasses import dataclass, field
from typing import Any, Set
@dataclass
class Client:
hostname: str
reader: StreamReader
writer: StreamWriter
msg_queue: Queue = field(default_factory=Queue)
nickname: str = ""
username: str = ""
realname: str = ""
registered: bool = False
channels: Set[str] = field(default_factory=set)
def id(self) -> str:
return f"{self.nickname}!{self.username}@{self.hostname}"
async def write_forever(self) -> None:
while True:
msg = await self.msg_queue.get()
self.writer.write(msg)
await self.writer.drain()

+ 21
- 0
paircd/handle.py View File

@ -0,0 +1,21 @@
import logging
from typing import Any, Dict
from paircd.client import Client
from paircd.message import Message
from paircd.server import Server
CMD_HANDLERS: Dict[str, Any] = {}
CMD_EXPECTED_ARGC: Dict[str, int] = {}
def register_cmd_handler(cmd: str, argc: int, handler) -> None:
CMD_EXPECTED_ARGC[cmd] = argc
CMD_HANDLERS[cmd] = handler
async def handle_cmd(server: Server, client: Client, msg: Message) -> None:
if msg.cmd not in CMD_HANDLERS:
logging.warning(f"Unknown command: {msg.cmd}")
return
await CMD_HANDLERS[msg.cmd](server, client, msg)

+ 0
- 0
paircd/handler/__init__.py View File


+ 26
- 0
paircd/handler/join.py View File

@ -0,0 +1,26 @@
from asyncio import create_task
import logging
from paircd.channel import Channel
from paircd.client import Client
from paircd.message import Message
from paircd.server import Server
async def handle_join(server: Server, client: Client, msg: Message) -> None:
channel_name = msg.args[0]
if not client.registered:
raise RuntimeError("JOIN: not registered")
if not channel_name.startswith("#"):
raise RuntimeError("invalid channel name")
channel = server.get_channel_by_name(channel_name)
channel.add_client(client)
client.channels.add(channel_name)
logging.info(f"{client.hostname} ({client.id()}) joined {channel_name}")
await channel.msg_queue.put(
Message(cmd="JOIN", args=[channel_name], prefix=client.id()).encode()
)

+ 20
- 0
paircd/handler/nick.py View File

@ -0,0 +1,20 @@
import logging
from paircd.client import Client
from paircd.message import Message
from paircd.server import Server
async def handle_nick(server: Server, client: Client, msg: Message) -> None:
nickname = msg.args[0]
if client.nickname:
del server.clients_by_nick[client.nickname]
# TODO: Update all channel references
client.nickname = nickname
server.add_client(client)
if client.username and client.realname:
client.registered = True
logging.info(f"{client.hostname} ({client.id()}) registered")

+ 22
- 0
paircd/handler/privmsg.py View File

@ -0,0 +1,22 @@
from paircd.client import Client
from paircd.message import Message
from paircd.server import Server
async def handle_privmsg(server: Server, client: Client, msg: Message) -> None:
recipient = msg.args[0]
raw_msg = msg.args[1]
msg = Message("PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.id()).encode()
for name, other_client in server.clients_by_nick.items():
if name == recipient:
other_client.msg_queue.put_nowait(msg)
return
for name, channel in server.channels_by_name.items():
if name == recipient:
channel.msg_queue.put_nowait(msg)
return
raise RuntimeError(f"Unknown recipient {recipient}")

+ 17
- 0
paircd/handler/user.py View File

@ -0,0 +1,17 @@
import logging
from paircd.client import Client
from paircd.message import Message
from paircd.server import Server
async def handle_user(server: Server, client: Client, msg: Message) -> None:
if client.registered:
raise RuntimeError("USER command issued after registration")
client.username = msg.args[0]
client.realname = msg.args[3]
if client.nickname:
client.registered = True
logging.info(f"{client.hostname} ({client.id()}) registered")

+ 27
- 134
paircd/main.py View File

@ -1,39 +1,15 @@
import asyncio import asyncio
import dataclasses
import logging import logging
import os import os
from typing import Dict, List, Set
from paircd.message import IRCMessage
clients_by_nick: Dict[str, "Client"] = {}
all_channels: Dict[str, "Channel"] = {}
@dataclasses.dataclass
class Client:
hostname: str
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
msg_queue: asyncio.Queue
nickname: str = ""
username: str = ""
realname: str = ""
registered: bool = False
channels: Set[str] = dataclasses.field(default_factory=set)
def id(self) -> str:
return f"{self.nickname}!{self.username}@{self.hostname}"
@dataclasses.dataclass
class Channel:
name: str
clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict)
msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)
from paircd.client import Client
from paircd.handle import handle_cmd, register_cmd_handler
from paircd.handler.join import handle_join
from paircd.handler.nick import handle_nick
from paircd.handler.privmsg import handle_privmsg
from paircd.handler.user import handle_user
from paircd.message import parse_message
from paircd.server import Server
logging.basicConfig( logging.basicConfig(
format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO
@ -42,116 +18,33 @@ logging.basicConfig(
logger = logging.getLogger() logger = logging.getLogger()
async def handle_reader(client: Client) -> None:
async def read_forever(server: Server, client: Client) -> None:
while True: while True:
raw_msg = await client.reader.readuntil(b"\r\n") raw_msg = await client.reader.readuntil(b"\r\n")
msg = IRCMessage.parse(raw_msg)
await handle_irc_msg(client, msg)
async def handle_irc_msg(client: Client, msg: IRCMessage) -> None:
if msg.cmd == "CAP":
# https://ircv3.net/specs/extensions/capability-negotiation.html
logging.warning("TODO: implement support for client capability negotiation")
elif msg.cmd == "NICK":
if client.nickname:
del clients_by_nick[client.nickname]
client.nickname = msg.args[0]
clients_by_nick[client.nickname] = client
if client.username and client.realname:
client.registered = True
logging.info(f"{client.hostname} ({client.id()}) registered")
elif msg.cmd == "USER":
if client.registered:
raise RuntimeError("USER command issued after registration")
client.username = msg.args[0]
client.realname = msg.args[3]
if client.nickname:
client.registered = True
logging.info(f"{client.hostname} ({client.id()}) registered")
elif msg.cmd == "JOIN":
channel_name = msg.args[0]
assert_registered(client, msg)
if not channel_name.startswith("#"):
raise RuntimeError("invalid channel name")
if channel_name not in all_channels:
all_channels[channel_name] = Channel(name=channel_name)
asyncio.create_task(process_channel(all_channels[channel_name]))
client.channels.add(channel_name)
channel = all_channels[channel_name]
channel.clients_by_host[client.hostname] = client
logging.info(f"{client.hostname} ({client.id()}) joined {channel_name}")
await channel.msg_queue.put(
IRCMessage(cmd="JOIN", args=[channel_name], prefix=client.nickname).encode()
)
elif msg.cmd == "PRIVMSG":
assert_registered(client, msg)
await privmsg(client, msg.args[0], msg.args[1])
else:
logging.warning(f"unsupported message {msg.cmd}")
async def privmsg(client: Client, recipient: str, raw_msg: str) -> None:
msg = IRCMessage(
"PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.nickname
).encode()
for name, other_client in clients_by_nick.items():
if name == recipient:
other_client.msg_queue.put_nowait(msg)
return
for name, channel in all_channels.items():
if name == recipient:
channel.msg_queue.put_nowait(msg)
return
raise RuntimeError(f"unknown recipient {recipient}")
async def process_channel(channel: Channel) -> None:
while True:
msg = await channel.msg_queue.get()
for client in channel.clients_by_host.values():
if msg.startswith(f":{client.nickname} PRIVMSG".encode("utf-8")):
continue
client.msg_queue.put_nowait(msg)
def assert_registered(client: Client, msg: IRCMessage) -> None:
if client.registered:
return
raise RuntimeError(f"{msg.cmd} issued before client fully registered")
def assert_argc(xs: List[str], i: int) -> None:
if len(xs) == i:
return
raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})")
async def handle_writer(client: Client) -> None:
while True:
msg = await client.msg_queue.get()
client.writer.write(msg)
await client.writer.drain()
async def register_client(reader, writer):
client = Client(
hostname=writer.get_extra_info("peername")[0],
reader=reader,
writer=writer,
msg_queue=asyncio.Queue(),
)
asyncio.create_task(handle_reader(client))
asyncio.create_task(handle_writer(client))
msg = parse_message(raw_msg)
await handle_cmd(server, client, msg)
async def main(): async def main():
bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0" bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0"
port = os.getenv("PORT") or 6667 port = os.getenv("PORT") or 6667
register_cmd_handler("JOIN", 1, handle_join)
register_cmd_handler("NICK", 1, handle_nick)
register_cmd_handler("PRIVMSG", 2, handle_privmsg)
register_cmd_handler("USER", 4, handle_user)
irc_server = Server()
async def register_client(reader, writer):
client = Client(
hostname=writer.get_extra_info("peername")[0],
reader=reader,
writer=writer,
)
asyncio.create_task(read_forever(irc_server, client))
asyncio.create_task(client.write_forever())
server = await asyncio.start_server( server = await asyncio.start_server(
register_client, register_client,
bind_addr, bind_addr,

+ 33
- 33
paircd/message.py View File

@ -11,7 +11,7 @@ EXPECTED_ARG_COUNT = {
} }
class IRCParsingError(Exception):
class ParsingError(Exception):
message: str message: str
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
@ -19,7 +19,7 @@ class IRCParsingError(Exception):
@dataclass @dataclass
class IRCMessage:
class Message:
cmd: str cmd: str
args: List[str] args: List[str]
prefix: str = "" prefix: str = ""
@ -32,34 +32,34 @@ class IRCMessage:
# TODO: Raise exception if formatted message exceeds 512 bytes # TODO: Raise exception if formatted message exceeds 512 bytes
return f"{prefix}{self.cmd} {' '.join(self.args)}\r\n".encode("utf-8") return f"{prefix}{self.cmd} {' '.join(self.args)}\r\n".encode("utf-8")
@staticmethod
def parse(raw: bytes) -> >"IRCMessage":
if len(raw) > MAX_MESSAGE_SIZE:
raise IRCParsingError(
f"Message is {len(raw)} bytes, larger than allowed {MAX_MESSAGE_SIZE}"
)
if not raw.endswith(b"\r\n"):
raise IRCParsingError("Message does not terminate in CRLF")
tokens: List[str] = []
raw_tokens: List[str] = raw.decode("utf-8").split(" ")
for i, token in enumerate(raw_tokens):
if token.startswith(":"):
trailing = token[1:] + " " + " ".join(raw_tokens[i + 1 :])
tokens.append(trailing)
break
tokens.append(token)
if len(tokens) == 0:
raise IRCParsingError("Message has no command")
cmd = tokens[0]
if cmd in EXPECTED_ARG_COUNT and EXPECTED_ARG_COUNT[cmd] != len(tokens) - 1:
raise IRCParsingError(
f"{cmd} had {len(tokens)-1} arguments, expected {EXPECTED_ARG_COUNT[cmd]}"
)
tokens[-1] = tokens[-1].strip()
return IRCMessage(cmd=tokens[0], args=tokens[1:])
def parse_message(raw: bytes) -> Message:
if len(raw) > MAX_MESSAGE_SIZE:
raise ParsingError(
f"Message is {len(raw)} bytes, larger than allowed {MAX_MESSAGE_SIZE}"
)
if not raw.endswith(b"\r\n"):
raise ParsingError("Message does not terminate in CRLF")
tokens: List[str] = []
raw_tokens: List[str] = raw.decode("utf-8").split(" ")
for i, token in enumerate(raw_tokens):
if token.startswith(":"):
trailing = token[1:] + " " + " ".join(raw_tokens[i + 1 :])
tokens.append(trailing)
break
tokens.append(token)
if len(tokens) == 0:
raise ParsingError("Message has no command")
cmd = tokens[0]
if cmd in EXPECTED_ARG_COUNT and EXPECTED_ARG_COUNT[cmd] != len(tokens) - 1:
raise ParsingError(
f"{cmd} had {len(tokens)-1} arguments, expected {EXPECTED_ARG_COUNT[cmd]}"
)
tokens[-1] = tokens[-1].strip()
return Message(cmd=tokens[0], args=tokens[1:])

+ 21
- 0
paircd/server.py View File

@ -0,0 +1,21 @@
from asyncio import create_task
from dataclasses import dataclass, field
from typing import Dict
from paircd.client import Client
from paircd.channel import Channel
@dataclass
class Server:
clients_by_nick: Dict[str, Client] = field(default_factory=dict)
channels_by_name: Dict[str, Channel] = field(default_factory=dict)
def add_client(self, client: Client) -> None:
self.clients_by_nick[client.nickname] = client
def get_channel_by_name(self, name: str) -> Channel:
if name not in self.channels_by_name:
self.channels_by_name[name] = Channel(name=name)
create_task(self.channels_by_name[name].process())
return self.channels_by_name[name]

Loading…
Cancel
Save