Python IRCd using asyncio
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

188 lines
5.5 KiB

  1. import asyncio
  2. import dataclasses
  3. import logging
  4. import os
  5. from typing import Dict, List, Set
  6. clients_by_nick: Dict[str, "Client"] = {}
  7. all_channels: Dict[str, "Channel"] = {}
  8. @dataclasses.dataclass
  9. class Client:
  10. hostname: str
  11. reader: asyncio.StreamReader
  12. writer: asyncio.StreamWriter
  13. msg_queue: asyncio.Queue
  14. nickname: str = ""
  15. username: str = ""
  16. realname: str = ""
  17. registered: bool = False
  18. channels: Set[str] = dataclasses.field(default_factory=set)
  19. def id(self) -> str:
  20. return f"{self.nickname}!{self.username}@{self.hostname}"
  21. @dataclasses.dataclass
  22. class Channel:
  23. name: str
  24. clients_by_host: Dict[str, Client] = dataclasses.field(default_factory=dict)
  25. msg_queue: asyncio.Queue = dataclasses.field(default_factory=asyncio.Queue)
  26. logging.basicConfig(
  27. format="%(asctime)s [%(levelname)s] - %(message)s", level=logging.INFO
  28. )
  29. logger = logging.getLogger()
  30. async def handle_reader(client: Client) -> None:
  31. while True:
  32. raw_msg = await client.reader.readuntil(b"\r\n")
  33. if not raw_msg.endswith(b"\r\n"):
  34. raise RuntimeError("malformed message")
  35. msg = parse_irc_msg(raw_msg)
  36. await handle_irc_msg(client, msg)
  37. async def handle_irc_msg(client: Client, msg: List[str]) -> None:
  38. if msg[0] == "CAP":
  39. # https://ircv3.net/specs/extensions/capability-negotiation.html
  40. logging.warning("TODO: implement support for client capability negotiation")
  41. elif msg[0] == "NICK":
  42. assert_argc(msg, 2)
  43. if client.nickname:
  44. del clients_by_nick[client.nickname]
  45. client.nickname = msg[1]
  46. clients_by_nick[client.nickname] = client
  47. if client.username and client.realname:
  48. client.registered = True
  49. logging.info(f"{client.hostname} ({client.id()}) registered")
  50. elif msg[0] == "USER":
  51. assert_argc(msg, 5)
  52. if client.registered:
  53. raise RuntimeError("USER command issued after registration")
  54. client.username = msg[1]
  55. client.realname = msg[4]
  56. if client.nickname:
  57. client.registered = True
  58. logging.info(f"{client.hostname} ({client.id()}) registered")
  59. elif msg[0] == "JOIN":
  60. assert_argc(msg, 2)
  61. assert_registered(client, msg)
  62. if not msg[1].startswith("#"):
  63. raise RuntimeError("invalid channel name")
  64. if msg[1] not in all_channels:
  65. all_channels[msg[1]] = Channel(name=msg[1])
  66. asyncio.create_task(process_channel(all_channels[msg[1]]))
  67. channel = all_channels[msg[1]]
  68. client.channels.add(channel.name)
  69. channel.clients_by_host[client.hostname] = client
  70. logging.info(f"{client.hostname} ({client.id()}) joined {msg[1]}")
  71. await channel.msg_queue.put(
  72. f":{client.nickname} JOIN {msg[1]}\r\n".encode("utf-8")
  73. )
  74. elif msg[0] == "PRIVMSG":
  75. assert_argc(msg, 3)
  76. assert_registered(client, msg)
  77. await privmsg(client, msg[1], msg[2])
  78. else:
  79. logging.warning(f"unsupported message {msg[0]}")
  80. async def privmsg(client: Client, recipient: str, raw_msg: str) -> None:
  81. msg = f":{client.nickname} PRIVMSG {recipient} :{raw_msg}\r\n".encode("utf-8")
  82. for name, other_client in clients_by_nick.items():
  83. if name == recipient:
  84. other_client.msg_queue.put_nowait(msg)
  85. return
  86. for name, channel in all_channels.items():
  87. if name == recipient:
  88. channel.msg_queue.put_nowait(msg)
  89. return
  90. raise RuntimeError(f"unknown recipient {recipient}")
  91. async def process_channel(channel: Channel) -> None:
  92. while True:
  93. msg = await channel.msg_queue.get()
  94. for client in channel.clients_by_host.values():
  95. if msg.startswith(f":{client.nickname} PRIVMSG".encode("utf-8")):
  96. continue
  97. client.msg_queue.put_nowait(msg)
  98. def assert_registered(client: Client, msg: List[str]) -> None:
  99. if client.registered:
  100. return
  101. raise RuntimeError(f"{msg[0]} command issued before client fully registered")
  102. def assert_argc(xs: List[str], i: int) -> None:
  103. if len(xs) == i:
  104. return
  105. raise RuntimeError(f"{xs[0]} had {len(xs)} arguments (expected {i})")
  106. def parse_irc_msg(raw_msg: str) -> List[str]:
  107. tokens: List[str] = []
  108. raw_tokens: List[str] = raw_msg.decode("utf-8").split(" ")
  109. for i, token in enumerate(raw_tokens):
  110. if token.startswith(":"):
  111. trailing = token[1:] + " ".join(raw_tokens[i + 1 :])
  112. tokens.append(trailing)
  113. break
  114. tokens.append(token)
  115. if len(tokens) > 0:
  116. tokens[-1] = tokens[-1].strip()
  117. else:
  118. raise RuntimeError("empty message")
  119. return tokens
  120. async def handle_writer(client: Client) -> None:
  121. while True:
  122. msg = await client.msg_queue.get()
  123. client.writer.write(msg)
  124. await client.writer.drain()
  125. async def register_client(reader, writer):
  126. client = Client(
  127. hostname=writer.get_extra_info("peername")[0],
  128. reader=reader,
  129. writer=writer,
  130. msg_queue=asyncio.Queue(),
  131. )
  132. asyncio.create_task(handle_reader(client))
  133. asyncio.create_task(handle_writer(client))
  134. async def main():
  135. bind_addr = os.getenv("BIND_ADDR") or "0.0.0.0"
  136. port = os.getenv("PORT") or 6667
  137. server = await asyncio.start_server(
  138. register_client,
  139. bind_addr,
  140. port,
  141. reuse_port=True,
  142. )
  143. logger.info(f"Listening on {bind_addr}:{port}...")
  144. async with server:
  145. await server.serve_forever()
  146. if __name__ == "__main__":
  147. asyncio.run(main())