from math import log
from random import random
from typing import List, Optional, Tuple

from gbso.program.test_case import Output, TestCase, eq_on_testcase
from gbso.program.mutate import mutate_program
from gbso.program.program import Program

EPSILON = 0.00001

DEFAULT_NUM_ITERS = 1_000_000

DEFAULT_PROB_OPCODE = 0.25
DEFAULT_PROB_OPERAND = 0.25
DEFAULT_PROB_SWAP = 0.25
DEFAULT_PROB_INSN = 0.25

DEFAULT_PROB_INSN_UNUSED = 0.1


def cost(
    orig_prgm: Program,
    test_cases: List[TestCase],
    outputs: List[Output],
    prgm: Program,
) -> Tuple[float, bool]:
    """Score ``prgm`` against ``orig_prgm``; lower is better.

    The score is the sum of a performance term and an equivalence term:

    * performance: ``(prgm.perf() - orig_prgm.perf()) / 4.0``;
    * equivalence: the sum of ``eq_on_testcase`` over every test case.

    Args:
        orig_prgm: The reference program.
        test_cases: Test cases on which the two programs are compared.
        outputs: Passed through unchanged to ``eq_on_testcase``.
        prgm: The program being scored.

    Returns:
        A ``(cost, is_equivalent)`` pair, where ``is_equivalent`` is true
        exactly when the equivalence term summed to zero.
    """
    # Since each instruction executes in 4*k cycles (for some k), this can have
    # the undesirable effect of performance improvements being weighted much
    # higher than correctness. This hurts convergence pretty badly, so we scale
    # by 1/4 to compensate.
    perf = (prgm.perf() - orig_prgm.perf()) / 4.0

    eq = sum(
        eq_on_testcase(orig_prgm, prgm, test_case, outputs)
        for test_case in test_cases
    )

    return perf + eq, eq == 0


def optimize(
    target_prgm: Program,
    max_size: int,
    test_cases: List[TestCase],
    outputs: List[Output],
    beta: float = 0.5,  # How far away in cost you are allowed to search
    init_prgm: Optional[Program] = None,
    num_iters: int = DEFAULT_NUM_ITERS,
    prob_opcode: float = DEFAULT_PROB_OPCODE,
    prob_operand: float = DEFAULT_PROB_OPERAND,
    prob_swap: float = DEFAULT_PROB_SWAP,
    prob_insn: float = DEFAULT_PROB_INSN,
    prob_insn_unused: float = DEFAULT_PROB_INSN_UNUSED,
) -> Program:
    """Randomly search for a cheaper program that matches ``target_prgm``.

    Each iteration mutates the current program with ``mutate_program`` and
    scores the result with ``cost``. A candidate replaces the current program
    when its cost is below ``last_cost - log(u) / beta`` for a fresh uniform
    draw ``u``; the slack ``-log(u) / beta`` shrinks as ``beta`` grows, so a
    smaller ``beta`` tolerates larger cost increases. Independently, the best
    candidate seen so far that is equivalent on every test case is remembered.

    Args:
        target_prgm: The reference program to optimize.
        max_size: Size the starting programs are padded to via ``pad``.
        test_cases: Test cases used by ``cost``.
        outputs: Passed through to ``cost``.
        beta: Divisor of the random acceptance slack (see above).
        init_prgm: Optional starting point for the search; ``target_prgm`` is
            used when this is ``None`` (or otherwise falsy).
        num_iters: Number of mutation iterations to run.
        prob_opcode: Forwarded to ``mutate_program``.
        prob_operand: Forwarded to ``mutate_program``.
        prob_swap: Forwarded to ``mutate_program``.
        prob_insn: Forwarded to ``mutate_program``.
        prob_insn_unused: Forwarded to ``mutate_program``.

    Returns:
        The lowest-cost equivalent program found, or ``target_prgm`` padded to
        ``max_size`` if no candidate scored below a cost of 0.
    """
    padded_prgm = (init_prgm or target_prgm).pad(max_size)

    last_prgm = padded_prgm
    last_cost, _last_eq = cost(target_prgm, test_cases, outputs, last_prgm)

    # The baseline "best" is the target itself, whose cost is taken to be 0.
    best_prgm = target_prgm.pad(max_size)
    best_cost = 0

    # Counts how many times a new best program was recorded.
    num_candidates = 0

    for _ in range(num_iters):
        candidate_prgm = mutate_program(
            last_prgm,
            prob_opcode,
            prob_operand,
            prob_swap,
            prob_insn,
            prob_insn_unused,
        )
        candidate_cost, candidate_eq = cost(
            target_prgm, test_cases, outputs, candidate_prgm
        )

        if candidate_cost < best_cost and candidate_eq:
            best_prgm = candidate_prgm
            best_cost = candidate_cost
            num_candidates += 1

        # random() draws from [0.0, 1.0); log(0.0) raises ValueError, so a
        # draw of exactly 0.0 is treated as an unbounded threshold (accept).
        draw = random()
        if draw == 0.0 or candidate_cost < last_cost - log(draw) / beta:
            last_prgm = candidate_prgm
            last_cost = candidate_cost

    print(f"Optimization complete. Total candidates: {num_candidates}")

    return best_prgm