from dataclasses import dataclass, field, replace
|
|
from typing import List, Union
|
|
|
|
from gbso.cpu.cpu import CPU
|
|
from gbso.cpu.regs import R8
|
|
from gbso.program.program import Program
|
|
from gbso.cpu.state import CPUState
|
|
|
|
|
|
# TODO: Support 16-bit memory/register outputs
|
|
# A location compared by eq_on_output: an R8 selects an 8-bit register
# (cpu.state.reg8), an int selects a memory address (cpu.state.memory).
Output = Union[R8, int]
|
|
|
|
|
|
@dataclass
class TestCase:
    """One input configuration on which two programs are compared.

    Attributes:
        state: Initial CPU state. eq_on_testcase runs each program on its
            own ``state.copy()`` rather than on this object directly.
    """

    # The class name starts with "Test"; this keeps pytest from trying to
    # collect it as a test class.
    __test__ = False

    state: CPUState = field(default_factory=CPUState)
|
|
|
|
|
|
# TODO: Allow p_cpu to be computed AOT
|
|
# TODO: Add penalty for undefined behavior (uninitialized register/memory reads)
|
|
def eq_on_testcase(
    p: Program, q: Program, case: TestCase, outputs: List[Output]
) -> int:
    """Measure how differently two programs behave on a single test case.

    Each program is executed on its own copy of ``case.state`` (via
    ``case.state.copy()``), and the resulting CPUs are compared at every
    requested output location.

    Args:
        p: First program to execute.
        q: Second program to execute.
        case: Test case supplying the initial CPU state.
        outputs: Register/memory locations whose final values are compared.

    Returns:
        Total number of differing bits summed over ``outputs`` (see
        eq_on_output); 0 when the programs agree on every output, including
        when ``outputs`` is empty.
    """
    p_cpu = p.execute(case.state.copy())
    q_cpu = q.execute(case.state.copy())
    # Generator expression: sum() consumes it directly, no intermediate list.
    return sum(eq_on_output(o, p_cpu, q_cpu) for o in outputs)
|
|
|
|
|
|
def eq_on_output(o: Output, p_cpu: CPU, q_cpu: CPU) -> int:
    """Count the bits that differ between two CPUs at one output location.

    Args:
        o: Output location. An ``R8`` selects an 8-bit register
            (``state.reg8``); an ``int`` selects a memory address
            (``state.memory``).
        p_cpu: CPU resulting from executing the first program.
        q_cpu: CPU resulting from executing the second program.

    Returns:
        Number of differing bits (0-8) between the two values, as computed
        by eq_8bit.

    Raises:
        TypeError: If ``o`` is neither an ``R8`` nor an ``int``.
    """
    # Test R8 first: should R8 be an int-based enum, its members would also
    # satisfy isinstance(o, int) and be misread as memory addresses.
    if isinstance(o, R8):
        return eq_8bit(p_cpu.state.reg8[o], q_cpu.state.reg8[o])
    # bool is an int subclass; keep rejecting it as the original exact-type
    # check did, rather than silently treating True/False as address 1/0.
    if isinstance(o, int) and not isinstance(o, bool):
        return eq_8bit(p_cpu.state.memory[o], q_cpu.state.memory[o])
    raise TypeError(f"unknown output type {type(o)}")
|
|
|
|
|
|
def eq_8bit(x: int, y: int) -> int:
    """Return how many of the low 8 bits differ between ``x`` and ``y``.

    Only bits 0-7 are compared; higher bits are ignored, so for example
    ``eq_8bit(0x100, 0) == 0``. The result is always in the range 0-8.

    >>> eq_8bit(0b1010_0000, 0b0000_0101)
    4
    """
    # XOR has a 1 exactly where the operands differ; masking with 0xFF keeps
    # only the 8 bits of interest, and counting the 1s gives the distance.
    # bin().count("1") is used instead of int.bit_count() (3.10+) for
    # compatibility with older interpreters.
    return bin((x ^ y) & 0xFF).count("1")
|