|
|
- from dataclasses import dataclass, field, replace
- from typing import List, Union
-
- from gbso.cpu.cpu import CPU
- from gbso.cpu.regs import R8
- from gbso.program.program import Program
- from gbso.cpu.state import CPUState
-
-
# An observable result on which two programs are compared: either an R8
# register (compared via `state.reg8[o]`) or an int memory address
# (compared via `state.memory[o]`). See eq_on_output for the dispatch.
# TODO: Support 16-bit memory/register outputs
Output = Union[R8, int]
-
-
@dataclass
class TestCase:
    """One equivalence-test input: the CPU state both programs start from.

    Each program under comparison is executed on its own copy of ``state``
    (see eq_on_testcase), so the stored state itself is left untouched.
    """

    state: CPUState = field(default_factory=CPUState)

    # The class name starts with "Test"; this plain (non-annotated) class
    # attribute tells pytest not to collect it as a test class. Being
    # unannotated, it is not a dataclass field.
    __test__ = False
-
-
# TODO: Allow p_cpu to be computed AOT
# TODO: Add penalty for undefined behavior (uninitialized register/memory reads)
def eq_on_testcase(
    p: Program, q: Program, case: TestCase, outputs: List[Output]
) -> int:
    """Run ``p`` and ``q`` from the same start state and score their disagreement.

    Args:
        p: First program.
        q: Second program.
        case: Supplies the start state; each program executes on its own
            ``case.state.copy()``.
        outputs: Registers/addresses whose final values are compared.

    Returns:
        The sum over ``outputs`` of eq_on_output, i.e. the total number of
        differing bits; 0 means the programs agree on every listed output.
    """
    start = case.state
    # Separate copies so each execution begins from the identical state.
    p_cpu = p.execute(start.copy())
    q_cpu = q.execute(start.copy())
    return sum(eq_on_output(out, p_cpu, q_cpu) for out in outputs)
-
-
def eq_on_output(o: Output, p_cpu: CPU, q_cpu: CPU) -> int:
    """Return how many bits of output ``o`` differ between two CPUs.

    Args:
        o: An ``R8`` register (read from ``state.reg8``) or an ``int``
            memory address (read from ``state.memory``).
        p_cpu: CPU after executing the first program.
        q_cpu: CPU after executing the second program.

    Returns:
        The differing-bit count from eq_8bit for the selected value.

    Raises:
        TypeError: If ``o`` is neither an ``R8`` nor an ``int``.
    """
    # Check R8 first: if R8 happens to be int-based (its definition is not
    # visible here), its members would otherwise match the address branch.
    if isinstance(o, R8):
        return eq_8bit(p_cpu.state.reg8[o], q_cpu.state.reg8[o])
    # isinstance accepts int subclasses used as addresses; bool is an int
    # subclass but is rejected, as the previous exact-type check did.
    if isinstance(o, int) and not isinstance(o, bool):
        return eq_8bit(p_cpu.state.memory[o], q_cpu.state.memory[o])
    raise TypeError(f"unknown output type {type(o)}")
-
-
def eq_8bit(x: int, y: int) -> int:
    """Return the number of differing bits in the low 8 bits of ``x`` and ``y``.

    Bits above bit 7 are ignored, so the result is the Hamming distance
    between ``x & 0xFF`` and ``y & 0xFF`` and lies in ``0..8``.

    >>> eq_8bit(0x0F, 0xF0)
    8
    """
    # XOR sets a bit exactly where x and y differ; mask to one byte and
    # count the set bits (bin().count works on all Python 3 versions).
    return bin((x ^ y) & 0xFF).count("1")
|