Browse Source

Minor improvements

master
Forest Belton 2 years ago
parent
commit
97b70e2ef0
3 changed files with 34 additions and 15 deletions
  1. +12
    -1
      gbso/cpu/state.py
  2. +19
    -14
      gbso/optimize.py
  3. +3
    -0
      gbso/program/test_case.py

+ 12
- 1
gbso/cpu/state.py View File

@ -1,9 +1,11 @@
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict
from typing import Dict, List, Tuple, Union
from gbso.cpu.regs import R8
Loc = Union[R8, int]
@dataclass
class CPUState:
@ -19,3 +21,12 @@ class CPUState:
reg8=self.reg8.copy(),
memory=self.memory.copy(),
)
def with_vals(self, *init_vals: List[Tuple[Loc, int]]) -> "CPUState":
cpu = self.copy()
for loc, val in init_vals:
if type(loc) == R8:
cpu.reg8[loc] = val
else:
cpu.memory[loc] = val
return cpu

+ 19
- 14
gbso/optimize.py View File

@ -53,6 +53,7 @@ def cost_noperf(
class OptimizationParameters:
max_size: int
beta: float = DEFAULT_ANNEALING_CONSTANT
synthesize: bool = True
synthesis_iters: int = DEFAULT_SYNTHESIS_ITERS
optimize_iters: int = DEFAULT_OPTIMIZE_ITERS
num_candidates: int = DEFAULT_NUM_CANDIDATES
@ -118,21 +119,25 @@ def optimize(
outputs: List[Output],
params: OptimizationParameters,
) -> Program:
print("Synthesizing candidates...")
candidates = [
_optimize(
target_prgm,
test_cases,
outputs,
replace(params, cost_fn=cost_noperf),
num_iters=params.synthesis_iters,
init_prgm=create_random_program(params.max_size),
best_candidate = target_prgm
if params.synthesize:
print("Synthesizing candidates...")
candidates = [
_optimize(
target_prgm,
test_cases,
outputs,
replace(params, cost_fn=cost_noperf),
num_iters=params.synthesis_iters,
init_prgm=create_random_program(params.max_size),
)
for _ in range(params.num_candidates)
]
best_candidate = min(
candidates, key=lambda p: cost(target_prgm, test_cases, outputs, p)[0]
)
for _ in range(params.num_candidates)
]
best_candidate = min(
candidates, key=lambda p: cost(target_prgm, test_cases, outputs, p)[0]
)
print("Optimizing...")
return _optimize(
target_prgm,

+ 3
- 0
gbso/program/test_case.py View File

@ -39,6 +39,9 @@ def eq_on_output(o: Output, p_cpu: CPU, q_cpu: CPU) -> int:
# Counts differing bits between two 8-bit values
def eq_8bit(x: int, y: int) -> int:
if x == y:
return 0
delta = 0
for i in range(8):
mask = 1 << i

Loading…
Cancel
Save