Browse Source

Implement basic IRCd

master
Forest Belton 2 years ago
parent
commit
e4a2c3405d
3 changed files with 202 additions and 0 deletions
  1. +14
    -0
      README.md
  2. +0
    -0
      README.rst
  3. +188
    -0
      paircd/main.py

+ 14
- 0
README.md View File

@ -0,0 +1,14 @@
# paircd
a **P**ython **A**synchronous **IRC** **D**aemon
## setup
```
$ poetry install
```
## run
```
$ poetry run python
```

+ 0
- 0
README.rst View File


+ 188
- 0
paircd/main.py View File

@ -0,0 +1,188 @@
import asyncio
import dataclasses
import logging
import os
from typing import Dict, List, Set
clients_by_nick: Dict[str, "Client"] = {}
all_channels: Dict[str, "Channel"] = {}
@dataclasses.dataclass
class Client:
hostname: str
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
msg_queue: asyncio.Queue
nickname: str = ""
username: str = ""
realname: str = ""
registered: bool = False
channels: Set[str] = dataclasses.field(default_factory=set)
def id(self) -> str:
return f"{self.nickname}!{self.username}@{self.hostname}"
@dataclasses.dataclass
class Channel:
name: str
clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict)
msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)
logging.basicConfig(
format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO
)
logger = logging.getLogger()
async def handle_reader(client: Client) -> None:
while True:
raw_msg = await client.reader.readuntil(b"\r\n")
if not raw_msg.endswith(b"\r\n"):
raise RuntimeError("malformed message")
msg = parse_irc_msg(raw_msg)
await handle_irc_msg(client, msg)
async def handle_irc_msg(client: Client, msg: List[str]) -> None:
if msg[0] == "CAP":
# https://ircv3.net/specs/extensions/capability-negotiation.html
logging.warning("TODO: implement support for client capability negotiation")
elif msg[0] == "NICK":
assert_argc(msg, 2)
if client.nickname:
del clients_by_nick[client.nickname]
client.nickname = msg[1]
clients_by_nick[client.nickname] = client
if client.username and client.realname:
client.registered = True
logging.info(f"{client.hostname} ({client.id()}) registered")
elif msg[0] == "USER":
assert_argc(msg, 5)
if client.registered:
raise RuntimeError("USER command issued after registration")
client.username = msg[1]
client.realname = msg[4]
if client.nickname:
client.registered = True
logging.info(f"{client.hostname} ({client.id()}) registered")
elif msg[0] == "JOIN":
assert_argc(msg, 2)
assert_registered(client, msg)
if not msg[1].startswith("#"):
raise RuntimeError("invalid channel name")
if msg[1] not in all_channels:
all_channels[msg[1]] = Channel(name=msg[1])
asyncio.create_task(process_channel(all_channels[msg[1]]))
channel = all_channels[msg[1]]
client.channels.add(channel.name)
channel.clients_by_host[client.hostname] = client
logging.info(f"{client.hostname} ({client.id()}) joined {msg[1]}")
await channel.msg_queue.put(
f":{client.nickname} JOIN {msg[1]}\r\n".encode("utf-8")
)
elif msg[0] == "PRIVMSG":
assert_argc(msg, 3)
assert_registered(client, msg)
await privmsg(client, msg[1], msg[2])
else:
logging.warning(f"unsupported message {msg[0]}")
async def privmsg(client: Client, recipient: str, raw_msg: str) -> None:
msg = f":{client.nickname} PRIVMSG {recipient} :{raw_msg}\r\n".encode("utf-8")
for name, other_client in clients_by_nick.items():
if name == recipient:
other_client.msg_queue.put_nowait(msg)
return
for name, channel in all_channels.items():
if name == recipient:
channel.msg_queue.put_nowait(msg)
return
raise RuntimeError(f"unknown recipient {recipient}")
async def process_channel(channel: Channel) -> None:
while True:
msg = await channel.msg_queue.get()
for client in channel.clients_by_host.values():
if msg.startswith(f":{client.nickname} PRIVMSG".encode("utf-8")):
continue
client.msg_queue.put_nowait(msg)
def assert_registered(client: Client, msg: List[str]) -> None:
if client.registered:
return
raise RuntimeError(f"{msg[0]} command issued before client fully registered")
def assert_argc(xs: List[str], i: int) -> None:
if len(xs) == i:
return
raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})")
def parse_irc_msg(raw_msg: str) -> List[str]:
tokens: List[str] = []
raw_tokens: List[str] = raw_msg.decode("utf-8").split(" ")
for i, token in enumerate(raw_tokens):
if token.startswith(":"):
trailing = token[1:] + " ".join(raw_tokens[i + 1 :])
tokens.append(trailing)
break
tokens.append(token)
if len(tokens) > 0:
tokens[-1] = tokens[-1].strip()
else:
raise RuntimeError("empty message")
return tokens
async def handle_writer(client: Client) -> None:
while True:
msg = await client.msg_queue.get()
client.writer.write(msg)
await client.writer.drain()
async def register_client(reader, writer):
client = Client(
hostname=writer.get_extra_info("peername")[0],
reader=reader,
writer=writer,
msg_queue=asyncio.Queue(),
)
asyncio.create_task(handle_reader(client))
asyncio.create_task(handle_writer(client))
async def main():
bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0"
port = os.getenv("PORT") or 6667
server = await asyncio.start_server(
register_client,
bind_addr,
port,
reuse_port=True,
)
logger.info(f"Listening on {bind_addr}:{port}...")
async with server:
await server.serve_forever()
if __name__ == "__main__":
asyncio.run(main())

Loading…
Cancel
Save