|
@ -1,10 +1,12 @@ |
|
|
from asyncio import StreamReader, StreamWriter, Queue |
|
|
|
|
|
|
|
|
from asyncio import StreamReader, StreamWriter, TimeoutError, Queue |
|
|
|
|
|
from asyncio.tasks import wait_for |
|
|
from dataclasses import dataclass, field |
|
|
from dataclasses import dataclass, field |
|
|
from datetime import datetime |
|
|
from datetime import datetime |
|
|
from logging import log, INFO |
|
|
from logging import log, INFO |
|
|
from paircd.reply import RPL_CREATED, RPL_MYINFO, RPL_WELCOME, RPL_YOURHOST |
|
|
|
|
|
|
|
|
from typing import Any, Set |
|
|
|
|
|
|
|
|
|
|
|
from paircd.reply import QUIT, RPL_CREATED, RPL_MYINFO, RPL_WELCOME, RPL_YOURHOST |
|
|
from paircd.message import Message |
|
|
from paircd.message import Message |
|
|
from typing import Set |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
@dataclass |
|
@ -12,6 +14,7 @@ class Client: |
|
|
hostname: str |
|
|
hostname: str |
|
|
reader: StreamReader |
|
|
reader: StreamReader |
|
|
writer: StreamWriter |
|
|
writer: StreamWriter |
|
|
|
|
|
|
|
|
msg_queue: Queue = field(default_factory=Queue) |
|
|
msg_queue: Queue = field(default_factory=Queue) |
|
|
|
|
|
|
|
|
nickname: str = "" |
|
|
nickname: str = "" |
|
@ -19,6 +22,7 @@ class Client: |
|
|
realname: str = "" |
|
|
realname: str = "" |
|
|
registered: bool = False |
|
|
registered: bool = False |
|
|
away: str = "" |
|
|
away: str = "" |
|
|
|
|
|
closed = False |
|
|
|
|
|
|
|
|
modes: Set[str] = field(default_factory=set) |
|
|
modes: Set[str] = field(default_factory=set) |
|
|
channels: Set[str] = field(default_factory=set) |
|
|
channels: Set[str] = field(default_factory=set) |
|
@ -31,11 +35,43 @@ class Client: |
|
|
def log(self, msg: str, level: int = INFO) -> None: |
|
|
def log(self, msg: str, level: int = INFO) -> None: |
|
|
log(level, f"{self.hostname} ({self.id()}) {msg}") |
|
|
log(level, f"{self.hostname} ({self.id()}) {msg}") |
|
|
|
|
|
|
|
|
async def write_forever(self) -> None: |
|
|
|
|
|
while True: |
|
|
|
|
|
msg = await self.msg_queue.get() |
|
|
|
|
|
self.writer.write(msg) |
|
|
|
|
|
await self.writer.drain() |
|
|
|
|
|
|
|
|
async def write_until_closed(self, server: Any) -> None: |
|
|
|
|
|
while not self.closed: |
|
|
|
|
|
msg = None |
|
|
|
|
|
try: |
|
|
|
|
|
msg = await wait_for(self.msg_queue.get(), timeout=1.0) |
|
|
|
|
|
except TimeoutError: |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
if msg is not None: |
|
|
|
|
|
self.writer.write(msg) |
|
|
|
|
|
try: |
|
|
|
|
|
await self.writer.drain() |
|
|
|
|
|
except ConnectionResetError: |
|
|
|
|
|
await self.quit(server, "Connection reset by peer") |
|
|
|
|
|
|
|
|
|
|
|
async def quit(self, server: Any, msg: str) -> None: |
|
|
|
|
|
if self.closed: |
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
quit_msg = QUIT(msg, prefix=self.id()) |
|
|
|
|
|
for client in server.clients_by_nick.values(): |
|
|
|
|
|
client.write_message(quit_msg) |
|
|
|
|
|
for channel_name in self.channels: |
|
|
|
|
|
channel = server.get_channel_by_name(channel_name, create=False) |
|
|
|
|
|
if channel is None: |
|
|
|
|
|
continue |
|
|
|
|
|
channel.remove_client_by_nick(self.nickname) |
|
|
|
|
|
server.remove_client_by_name(self.nickname) |
|
|
|
|
|
await self.close() |
|
|
|
|
|
|
|
|
|
|
|
async def close(self) -> None: |
|
|
|
|
|
if self.closed: |
|
|
|
|
|
return |
|
|
|
|
|
self.closed = True |
|
|
|
|
|
if not self.writer.is_closing(): |
|
|
|
|
|
self.writer.close() |
|
|
|
|
|
await self.writer.wait_closed() |
|
|
|
|
|
|
|
|
def write_message(self, message: Message) -> None: |
|
|
def write_message(self, message: Message) -> None: |
|
|
self.msg_queue.put_nowait(message.encode()) |
|
|
self.msg_queue.put_nowait(message.encode()) |
|
|