diff --git a/gbso/cpu.py b/gbso/cpu.py index 3f9b39d..1cef3a3 100644 --- a/gbso/cpu.py +++ b/gbso/cpu.py @@ -1,54 +1,61 @@ from collections import defaultdict -from typing import Dict +from dataclasses import dataclass, field +from typing import Dict, Optional from gbso.regs import R16_HI, R16_LO, R8, R16 +@dataclass +class CPUState: + carry: int = 0 + cycles: int = 0 + sp: int = 0 + reg8: Dict[R8, int] = field(default_factory=lambda: defaultdict(lambda: 0)) + memory: Dict[int, int] = field(default_factory=lambda: defaultdict(lambda: 0)) + + class CPU: - carry: int - cycles: int - reg8: Dict[R8, int] - sp: int - memory: bytearray - - def __init__(self) -> None: - self.carry = 0 - self.cycles = 0 - self.reg8 = defaultdict(lambda: 0) - self.memory = bytearray(0xFFFF + 1) + state: CPUState + + def __init__(self, state: Optional[CPUState] = None) -> None: + if state is None: + state = CPUState() + self.state = state def get_reg8(self, r: R8) -> int: - return self.reg8[r] + return self.state.reg8[r] def get_reg16(self, rr: R16) -> int: if rr == R16.SP: - return self.sp + return self.state.sp - return (self.reg8[R16_HI[rr]] << 8) | self.reg8[R16_LO[rr]] + return (self.state.reg8[R16_HI[rr]] << 8) | self.state.reg8[R16_LO[rr]] def set_reg8(self, r: R8, n: int) -> None: - self.reg8[r] = n & 0xFF + self.state.reg8[r] = n & 0xFF def set_reg16(self, rr: R16, nn: int) -> None: if rr == R16.SP: - self.sp = nn & 0xFFFF + self.state.sp = nn & 0xFFFF return - self.reg8[R16_HI[rr]] = (nn >> 8) & 0xFF - self.reg8[R16_LO[rr]] = nn & 0xFF + self.state.reg8[R16_HI[rr]] = (nn >> 8) & 0xFF + self.state.reg8[R16_LO[rr]] = nn & 0xFF def get_mem8(self, nn: int) -> int: - return self.memory[nn & 0xFFFF] + return self.state.memory[nn & 0xFFFF] def get_mem16(self, nn: int) -> int: - return (self.memory[nn & 0xFFFF] << 8) | self.memory[(nn + 1) & 0xFFFF] + return (self.state.memory[nn & 0xFFFF] << 8) | self.state.memory[ + (nn + 1) & 0xFFFF + ] def set_mem8(self, nn: int, n: int) -> None: - self.memory[nn & 0xFFFF] = n & 0xFF + self.state.memory[nn & 0xFFFF] = n & 0xFF def set_mem16(self, nn: int, nn1: int) -> None: - self.memory[nn & 0xFFFF] = (nn1 >> 8) & 0xFF - self.memory[(nn + 1) & 0xFFFF] = nn1 & 0xFF + self.state.memory[nn & 0xFFFF] = (nn1 >> 8) & 0xFF + self.state.memory[(nn + 1) & 0xFFFF] = nn1 & 0xFF def deref_hl(self) -> int: return self.get_mem8(self.get_reg16(R16.HL))