from asyncio import StreamReader, StreamWriter, TimeoutError, Queue
|
|
from asyncio.tasks import wait_for
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from logging import log, INFO
|
|
from typing import Any, Set
|
|
|
|
from paircd.reply import RPL_CREATED, RPL_MYINFO, RPL_WELCOME, RPL_YOURHOST
|
|
from paircd.message import Message
|
|
|
|
|
|
@dataclass
class Client:
    """State for one connected client, plus its queue of outgoing messages.

    Attributes:
        hostname: Host part used by ``id()`` and as the log-line prefix.
        reader: Stream the connection is read from (not used by this class).
        writer: Stream that queued messages are written to.
        msg_queue: Encoded messages waiting to be sent by
            ``write_until_closed``.
        nickname: Client nickname; empty until set.
        username: Client username; empty until set.
        realname: Client real name; empty until set.
        registered: Set to True by ``register()`` once nickname, username
            and realname are all non-empty.
        away: Away text (presumably empty when the client is not away;
            not used by this class).
        closed: True once ``close()`` has been called; stops the write loop.
        modes: Mode strings currently set on this client.
        channels: Set of strings, presumably the names of joined channels
            (not used by this class).
    """

    hostname: str
    reader: StreamReader
    writer: StreamWriter

    msg_queue: Queue = field(default_factory=Queue)

    nickname: str = ""
    username: str = ""
    realname: str = ""
    registered: bool = False
    away: str = ""
    # Declared as a real dataclass field (it was an un-annotated class
    # attribute before) so it appears in fields()/repr like its siblings.
    # init=False and compare=False keep the constructor signature and
    # equality semantics exactly as they were.
    closed: bool = field(default=False, init=False, compare=False)

    modes: Set[str] = field(default_factory=set)
    channels: Set[str] = field(default_factory=set)

    def id(self) -> str:
        """Return ``nickname!username@hostname``.

        An empty nickname or username is rendered as ``<unknown>``.
        """
        nickname = self.nickname or "<unknown>"
        username = self.username or "<unknown>"
        return f"{nickname}!{username}@{self.hostname}"

    def log(self, msg: str, level: int = INFO) -> None:
        """Log ``msg`` at ``level``, prefixed with the hostname and ``id()``."""
        # Inside a method body the bare name `log` resolves to the
        # module-level logging.log import, not to this method.
        log(level, f"{self.hostname} ({self.id()}) {msg}")

    async def write_until_closed(self, server: Any) -> None:
        """Send queued messages to the writer until the client is closed.

        Args:
            server: Object providing an awaitable
                ``disconnect_client(nickname, reason)``; it is called when
                draining the writer raises ``ConnectionResetError``.
        """
        while not self.closed:
            try:
                # Wait at most one second so that `self.closed` is re-checked
                # regularly even when nothing is queued.
                msg = await wait_for(self.msg_queue.get(), timeout=1.0)
            except TimeoutError:
                continue

            if msg is None:
                continue

            self.writer.write(msg)
            try:
                await self.writer.drain()
            except ConnectionResetError:
                await server.disconnect_client(
                    self.nickname, "Connection reset by peer"
                )

    async def close(self) -> None:
        """Mark the client closed and close its writer.

        Calling this again after the first call does nothing.
        """
        if self.closed:
            return
        self.closed = True
        if not self.writer.is_closing():
            self.writer.close()
            await self.writer.wait_closed()

    def write_message(self, message: Message) -> None:
        """Encode ``message`` and queue it for sending without blocking."""
        self.msg_queue.put_nowait(message.encode())

    def register(self) -> None:
        """Mark the client registered and queue the welcome replies.

        Does nothing if the client is already registered, or if any of
        nickname, username or realname is still empty.
        """
        if self.registered:
            return

        if not (self.nickname and self.username and self.realname):
            return

        self.registered = True
        self.log("registered")

        self.write_message(RPL_WELCOME(self.nickname, self.id()))
        self.write_message(RPL_YOURHOST(self.nickname, "localhost", "paircd-0.0.1"))
        # TODO: Pull timestamp from server instance
        self.write_message(RPL_CREATED(self.nickname, datetime.utcnow()))
        # TODO: Display list of supported user & channel modes
        self.write_message(
            RPL_MYINFO(self.nickname, "localhost", "paircd-0.0.1", "", "")
        )

    def add_mode(self, mode: str) -> None:
        """Set ``mode`` on this client; setting it twice has no extra effect."""
        self.modes.add(mode)

    def get_mode_settings(self) -> str:
        """Return the set modes as ``+`` followed by the sorted mode strings."""
        return f"+{''.join(sorted(self.modes))}"

    def remove_mode(self, mode: str) -> None:
        """Unset ``mode`` on this client; a mode that is not set is ignored."""
        # discard() rather than remove(): clearing a mode that was never set
        # is a no-op instead of a KeyError, mirroring add_mode's idempotence.
        self.modes.discard(mode)
|