from dataclasses import dataclass
|
|
from typing import List, Optional
|
|
|
|
# Largest raw message, in bytes and counting the terminating CR-LF, that
# parse_message accepts (the 512-byte line limit of RFC 1459 section 2.3).
MAX_MESSAGE_SIZE: int = 512
|
|
|
|
|
|
class ParsingError(Exception):
    """Raised when raw bytes cannot be parsed as a protocol message.

    Attributes:
        message: Human-readable reason for the failure, or ``None`` when no
            reason was supplied.
    """

    message: Optional[str]

    def __init__(self, message: Optional[str] = None) -> None:
        # Forward the reason to Exception so ``str(err)``, ``err.args`` and
        # tracebacks show it, including when it is passed as ``message=``.
        if message is None:
            super().__init__()
        else:
            super().__init__(message)
        self.message = message
|
|
|
|
|
|
@dataclass
class Message:
    """A single protocol message: optional prefix, command and arguments.

    Attributes:
        cmd: The command word, emitted verbatim by ``encode``.
        args: The command's parameters, in order.
        prefix: Origin emitted as ``:<prefix>`` ahead of the command; an empty
            string omits the prefix entirely.
    """

    cmd: str
    args: List[str]
    prefix: str = "localhost"

    def encode(self) -> bytes:
        """Serialize the message as UTF-8 bytes terminated by CR-LF.

        Fields are joined by single spaces, and no space is emitted when
        ``args`` is empty.

        The last argument is marked as a trailing parameter (``:`` prepended)
        when it is empty or contains a space. Without the marker the receiver
        would split it into several parameters or lose it. A last argument
        that already starts with ``:`` is assumed to be pre-marked by the
        caller and is passed through unchanged. As a consequence, a bare
        parameter whose content genuinely begins with ``:`` cannot be
        expressed.

        Returns:
            The encoded message, including the trailing ``b"\\r\\n"``.
        """
        parts: List[str] = []
        if self.prefix != "":
            parts.append(f":{self.prefix}")
        parts.append(self.cmd)
        parts.extend(self.args)

        if self.args:
            last = self.args[-1]
            if not last.startswith(":") and (last == "" or " " in last):
                parts[-1] = f":{last}"

        # TODO: Raise exception if formatted message exceeds 512 bytes
        return (" ".join(parts) + "\r\n").encode("utf-8")
|
|
|
|
|
|
def parse_message(raw: bytes) -> Message:
    """Parse one raw protocol line into a :class:`Message`.

    ``raw`` must be a single line of at most ``MAX_MESSAGE_SIZE`` bytes
    (CR-LF included). It must be UTF-8 encoded and terminated by CR-LF.

    Tokens are separated by one or more spaces. A token that starts with
    ``:`` begins the trailing parameter. That parameter runs verbatim, spaces
    and surrounding whitespace included, to the end of the line, and its
    leading ``:`` is removed. The first token is the command and is
    upper-cased.

    An optional leading ``:<prefix>`` token is skipped. It is not copied into
    the result, whose ``prefix`` keeps the ``Message`` default.

    Args:
        raw: The bytes of one message as received.

    Returns:
        The parsed message.

    Raises:
        ParsingError: If the message is too long, does not end in CR-LF,
            is not valid UTF-8, contains CR, LF or NUL before the final
            CR-LF, or has no command.
    """
    if len(raw) > MAX_MESSAGE_SIZE:
        raise ParsingError(
            f"Message is {len(raw)} bytes, larger than allowed {MAX_MESSAGE_SIZE}"
        )

    if not raw.endswith(b"\r\n"):
        raise ParsingError("Message does not terminate in CRLF")

    # Decode without the terminator so it can never end up inside the command
    # or the last parameter. Report undecodable input as the module's own error.
    try:
        line = raw[:-2].decode("utf-8")
    except UnicodeDecodeError as err:
        raise ParsingError("Message is not valid UTF-8") from err

    # Extra CR/LF means several lines were passed at once. Keeping them in a
    # parameter would let a later encode() emit them as an additional line.
    if any(ch in line for ch in "\r\n\0"):
        raise ParsingError("Message contains CR, LF or NUL before the final CRLF")

    # Skip the sender-supplied prefix instead of treating it as a parameter.
    # NOTE(review): the prefix is discarded, not validated; confirm no caller
    # needs it.
    if line.startswith(":"):
        _, _, line = line.partition(" ")

    raw_tokens = line.split(" ")
    tokens: List[str] = []
    for i, token in enumerate(raw_tokens):
        if not token:
            # Produced by consecutive spaces, which act as a single separator.
            continue
        if token.startswith(":"):
            # Re-joining with " " restores the rest of the line exactly,
            # then [1:] drops the ':' marker.
            tokens.append(" ".join(raw_tokens[i:])[1:])
            break
        tokens.append(token)

    if not tokens:
        raise ParsingError("Message has no command")

    return Message(cmd=tokens[0].upper(), args=tokens[1:])
|