from asyncio import StreamReader, StreamWriter, TimeoutError, Queue
from asyncio.tasks import wait_for
from dataclasses import dataclass, field
from datetime import datetime
from logging import log, INFO
from typing import Any, Set

from paircd.reply import RPL_CREATED, RPL_MYINFO, RPL_WELCOME, RPL_YOURHOST
from paircd.message import Message


@dataclass
class Client:
    """State for one connected IRC client plus its outbound message queue.

    Outgoing messages are encoded and queued by ``write_message`` and flushed
    to the socket by the ``write_until_closed`` coroutine.
    """

    hostname: str
    reader: StreamReader
    writer: StreamWriter
    # Encoded outbound messages (bytes produced by Message.encode()).
    msg_queue: Queue = field(default_factory=Queue)
    nickname: str = ""
    username: str = ""
    realname: str = ""
    registered: bool = False
    away: str = ""
    # Per-instance "connection closed" flag. It is kept out of __init__,
    # __repr__ and __eq__, matching the previous unannotated class attribute,
    # while being an explicit dataclass field.
    closed: bool = field(default=False, init=False, repr=False, compare=False)
    modes: Set[str] = field(default_factory=set)
    channels: Set[str] = field(default_factory=set)

    def id(self) -> str:
        """Return the client's ``nickname!username@hostname`` identifier."""
        nickname = self.nickname or ""
        username = self.username or ""
        return f"{nickname}!{username}@{self.hostname}"

    def log(self, msg: str, level: int = INFO) -> None:
        """Log *msg* at *level*, prefixed with the client's host and identifier."""
        # `log` here resolves to the module-level logging.log, not this method.
        log(level, f"{self.hostname} ({self.id()}) {msg}")

    async def write_until_closed(self, server: Any) -> None:
        """Flush queued messages to the socket until the client is closed.

        The queue is polled with a 1-second timeout so that ``closed`` is
        re-checked periodically even when nothing is being sent.

        Args:
            server: Object providing an awaitable
                ``disconnect_client(nickname, reason)``. It is called once if
                the peer drops the connection, after which this coroutine
                returns.
        """
        while not self.closed:
            try:
                msg = await wait_for(self.msg_queue.get(), timeout=1.0)
            except TimeoutError:
                continue
            if msg is None:
                continue
            self.writer.write(msg)
            try:
                await self.writer.drain()
            except ConnectionError:
                # Covers ConnectionResetError and BrokenPipeError. The socket
                # is unusable, so report it once and stop. Continuing would
                # write to a dead socket and re-report on every message.
                await server.disconnect_client(
                    self.nickname, "Connection reset by peer"
                )
                return

    async def close(self) -> None:
        """Mark the client closed and close its writer; safe to call twice."""
        if self.closed:
            return
        self.closed = True
        if not self.writer.is_closing():
            self.writer.close()
            try:
                await self.writer.wait_closed()
            except ConnectionError:
                # wait_closed() re-raises the error the connection was lost
                # with. The peer is already gone, so closing is done.
                pass

    def write_message(self, message: Message) -> None:
        """Encode *message* and enqueue it for ``write_until_closed``."""
        self.msg_queue.put_nowait(message.encode())

    def register(self) -> None:
        """Complete registration once nickname, username and realname are set.

        This does nothing if the client is already registered or any of the
        three names is still empty. Otherwise it sets ``registered`` and
        queues the RPL_WELCOME / RPL_YOURHOST / RPL_CREATED / RPL_MYINFO
        greeting.
        """
        if self.registered:
            return
        if not (self.nickname and self.username and self.realname):
            return
        self.registered = True
        self.log("registered")
        self.write_message(RPL_WELCOME(self.nickname, self.id()))
        self.write_message(RPL_YOURHOST(self.nickname, "localhost", "paircd-0.0.1"))
        # TODO: Pull timestamp from server instance
        self.write_message(RPL_CREATED(self.nickname, datetime.utcnow()))
        # TODO: Display list of supported user & channel modes
        self.write_message(
            RPL_MYINFO(self.nickname, "localhost", "paircd-0.0.1", "", "")
        )

    def add_mode(self, mode: str) -> None:
        """Set user mode *mode* (no-op if already set)."""
        self.modes.add(mode)

    def get_mode_settings(self) -> str:
        """Return the active modes as ``+`` followed by the sorted mode letters."""
        return f"+{''.join(sorted(self.modes))}"

    def remove_mode(self, mode: str) -> None:
        """Unset user mode *mode*; raises KeyError if it is not set."""
        self.modes.remove(mode)