from dataclasses import dataclass, field, replace
from typing import List, Union

from gbso.cpu.cpu import CPU
from gbso.cpu.regs import R8
from gbso.cpu.state import CPUState
from gbso.program.program import Program

# A location whose final value is compared between two program runs: either an
# 8-bit register, or an int used as an index into CPU memory.
# TODO: Support 16-bit memory/register outputs
Output = Union[R8, int]


@dataclass
class TestCase:
    """One input for comparing two programs: the CPU state both runs start from."""

    # Keep pytest from collecting this class as a test suite just because its
    # name starts with "Test".
    __test__ = False

    state: CPUState = field(default_factory=CPUState)


# TODO: Allow p_cpu to be computed AOT
# TODO: Add penalty for undefined behavior (uninitialized register/memory reads)
def eq_on_testcase(
    p: Program, q: Program, case: TestCase, outputs: List[Output]
) -> int:
    """Return how many output bits differ between *p* and *q* on *case*.

    Each program is executed on its own copy of ``case.state``, so neither run
    observes the other's writes and the test case itself is left untouched.
    The result is a distance, not a boolean: 0 means the two programs agree on
    every location listed in *outputs*.

    Args:
        p: First program to run.
        q: Second program to run.
        case: Test case supplying the initial CPU state.
        outputs: Registers / memory addresses whose final values are compared.

    Returns:
        Total number of differing bits summed over all *outputs*.
    """
    # NOTE(review): assumes Program.execute returns the CPU after execution,
    # as eq_on_output's CPU annotations suggest — confirm against Program.
    p_cpu = p.execute(case.state.copy())
    q_cpu = q.execute(case.state.copy())
    return sum(eq_on_output(o, p_cpu, q_cpu) for o in outputs)


def eq_on_output(o: Output, p_cpu: CPU, q_cpu: CPU) -> int:
    """Return how many bits of output location *o* differ between two CPUs.

    Args:
        o: An ``R8`` register, or an int memory address.
        p_cpu: CPU after running the first program.
        q_cpu: CPU after running the second program.

    Returns:
        Number of differing bits (0-8) in the compared byte.

    Raises:
        TypeError: if *o* is neither an ``R8`` nor a (non-bool) int.
    """
    # Check R8 first: should R8 ever be an int-based enum, its members would
    # otherwise be misread as memory addresses by the int branch below.
    if isinstance(o, R8):
        return eq_8bit(p_cpu.state.reg8[o], q_cpu.state.reg8[o])
    # bool is an int subclass; reject it explicitly as the old exact-type
    # check did, so True/False are never treated as addresses 1/0.
    if isinstance(o, int) and not isinstance(o, bool):
        return eq_8bit(p_cpu.state.memory[o], q_cpu.state.memory[o])
    raise TypeError(f"unknown output type {type(o)}")


def eq_8bit(x: int, y: int) -> int:
    """Return the number of differing bits in the low 8 bits of *x* and *y*.

    Bits above bit 7 are ignored.

    >>> eq_8bit(0b1010_0000, 0b0000_0101)
    4
    >>> eq_8bit(0x1FF, 0x0FF)
    0
    """
    # XOR marks every differing bit; mask to one byte and count the ones.
    return bin((x ^ y) & 0xFF).count("1")