Python IRCd using asyncio
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

168 lines
5.0 KiB

import asyncio
import dataclasses
import logging
import os
from typing import Dict, List, Set

from paircd.message import IRCMessage
  7. clients_by_nick: Dict[str, "Client"] = {}
  8. all_channels: Dict[str, "Channel"] = {}
  9. @dataclasses.dataclass
  10. class Client:
  11. hostname: str
  12. reader: asyncio.StreamReader
  13. writer: asyncio.StreamWriter
  14. msg_queue: asyncio.Queue
  15. nickname: str = ""
  16. username: str = ""
  17. realname: str = ""
  18. registered: bool = False
  19. channels: Set[str] = dataclasses.field(default_factory=set)
  20. def id(self) -> str:
  21. return f"{self.nickname}!{self.username}@{self.hostname}"
  22. @dataclasses.dataclass
  23. class Channel:
  24. name: str
  25. clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict)
  26. msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)
  27. logging.basicConfig(
  28. format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO
  29. )
  30. logger = logging.getLogger()
  31. async def handle_reader(client: Client) -> None:
  32. while True:
  33. raw_msg = await client.reader.readuntil(b"\r\n")
  34. msg = IRCMessage.parse(raw_msg)
  35. await handle_irc_msg(client, msg)
  36. async def handle_irc_msg(client: Client, msg: IRCMessage) -> None:
  37. if msg.cmd == "CAP":
  38. # https://ircv3.net/specs/extensions/capability-negotiation.html
  39. logging.warning("TODO: implement support for client capability negotiation")
  40. elif msg.cmd == "NICK":
  41. if client.nickname:
  42. del clients_by_nick[client.nickname]
  43. client.nickname = msg.args[0]
  44. clients_by_nick[client.nickname] = client
  45. if client.username and client.realname:
  46. client.registered = True
  47. logging.info(f"{client.hostname} ({client.id()}) registered")
  48. elif msg.cmd == "USER":
  49. if client.registered:
  50. raise RuntimeError("USER command issued after registration")
  51. client.username = msg.args[0]
  52. client.realname = msg.args[3]
  53. if client.nickname:
  54. client.registered = True
  55. logging.info(f"{client.hostname} ({client.id()}) registered")
  56. elif msg.cmd == "JOIN":
  57. channel_name = msg.args[0]
  58. assert_registered(client, msg)
  59. if not channel_name.startswith("#"):
  60. raise RuntimeError("invalid channel name")
  61. if channel_name not in all_channels:
  62. all_channels[channel_name] = Channel(name=channel_name)
  63. asyncio.create_task(process_channel(all_channels[channel_name]))
  64. client.channels.add(channel_name)
  65. channel = all_channels[channel_name]
  66. channel.clients_by_host[client.hostname] = client
  67. logging.info(f"{client.hostname} ({client.id()}) joined {channel_name}")
  68. await channel.msg_queue.put(
  69. IRCMessage(cmd="JOIN", args=[channel_name], prefix=client.nickname).encode()
  70. )
  71. elif msg.cmd == "PRIVMSG":
  72. assert_registered(client, msg)
  73. await privmsg(client, msg.args[0], msg.args[1])
  74. else:
  75. logging.warning(f"unsupported message {msg.cmd}")
  76. async def privmsg(client: Client, recipient: str, raw_msg: str) -> None:
  77. msg = IRCMessage(
  78. "PRIVMSG", [recipient, f":{raw_msg}"], prefix=client.nickname
  79. ).encode()
  80. for name, other_client in clients_by_nick.items():
  81. if name == recipient:
  82. other_client.msg_queue.put_nowait(msg)
  83. return
  84. for name, channel in all_channels.items():
  85. if name == recipient:
  86. channel.msg_queue.put_nowait(msg)
  87. return
  88. raise RuntimeError(f"unknown recipient {recipient}")
  89. async def process_channel(channel: Channel) -> None:
  90. while True:
  91. msg = await channel.msg_queue.get()
  92. for client in channel.clients_by_host.values():
  93. if msg.startswith(f":{client.nickname} PRIVMSG".encode("utf-8")):
  94. continue
  95. client.msg_queue.put_nowait(msg)
  96. def assert_registered(client: Client, msg: IRCMessage) -> None:
  97. if client.registered:
  98. return
  99. raise RuntimeError(f"{msg.cmd} issued before client fully registered")
  100. def assert_argc(xs: List[str], i: int) -> None:
  101. if len(xs) == i:
  102. return
  103. raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})")
  104. async def handle_writer(client: Client) -> None:
  105. while True:
  106. msg = await client.msg_queue.get()
  107. client.writer.write(msg)
  108. await client.writer.drain()
  109. async def register_client(reader, writer):
  110. client = Client(
  111. hostname=writer.get_extra_info("peername")[0],
  112. reader=reader,
  113. writer=writer,
  114. msg_queue=asyncio.Queue(),
  115. )
  116. asyncio.create_task(handle_reader(client))
  117. asyncio.create_task(handle_writer(client))
  118. async def main():
  119. bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0"
  120. port = os.getenv("PORT") or 6667
  121. server = await asyncio.start_server(
  122. register_client,
  123. bind_addr,
  124. port,
  125. reuse_port=True,
  126. )
  127. logger.info(f"Listening on {bind_addr}:{port}...")
  128. async with server:
  129. await server.serve_forever()
  130. if __name__ == "__main__":
  131. asyncio.run(main())