Spaces:
Sleeping
Sleeping
| import random | |
| from typing import Optional, List, Dict | |
| from ...environment import VerifiableEnvironment | |
class Circuit_Environment(VerifiableEnvironment) :
    """Task: find a 0/1 assignment that makes a random Boolean expression evaluate to 1.

    An instance is described by two entries read from ``self.parameter``:
    ``N`` (number of variables) and ``M`` (number of leaves of the expression
    tree). Generation stores ``tree`` and ``reference_answer`` back into
    ``self.parameter``.

    The expression tree is a nested structure: a leaf is an ``int`` (index of
    a variable), an inner node is a tuple ``(left_tree, op, right_tree)`` with
    ``op`` one of ``"&"``, ``"|"``, ``"^"``.
    """

    prompt_template = \
r"""There are {N} boolean (0/1) values x[0], x[1], ..., x[{N_minus_1}].
Given a Boolean expression (where `&` is bitwise AND, `|` is bitwise OR, and `^` is bitwise XOR): {expression}
Please find any solution x[0], x[1], ..., x[{N_minus_1}] that makes the expression evaluate to 1.
Output Format: Your final answer should be a single line containing x[0], x[1], ..., x[{N_minus_1}], separated by **spaces**.
Example: `{N_boolean}` (do **NOT** include quotes or backticks)."""

    # Operators understood by both the generator and the scorer.
    _SUPPORTED_OPERATORS = ("&", "|", "^")

    @staticmethod
    def _apply_operator(op : str, left_value : int, right_value : int) -> int :
        """Evaluate ``left_value <op> right_value`` for one supported binary operator.

        Raises:
            ValueError: if ``op`` is not one of ``&``, ``|``, ``^``.
        """
        if op == "&" :
            return left_value & right_value
        if op == "|" :
            return left_value | right_value
        if op == "^" :
            return left_value ^ right_value
        raise ValueError("Invalid operator")

    def __init__(self,
                 binary_ops_probs : Optional[Dict[str, float]] = None,
                 wrong_format : float = -1.0, invalid_solution : float = -0.5, correct_solution : float = +1.0, wrong_solution : float = 0.0,
                 **kwargs) :
        """
        Initialize the Circuit_Environment instance.

        Args:
            binary_ops_probs: mapping from operator symbol to its sampling
                probability; values must be non-negative and sum to 1.
                Defaults to ``{"&": 0.25, "|": 0.25, "^": 0.5}``.
            wrong_format: reward when the output cannot be parsed.
            invalid_solution: reward when the parsed answer has the wrong
                length or contains values other than 0/1.
            correct_solution: reward when the expression evaluates to 1.
            wrong_solution: reward when the expression does not evaluate to 1.
            **kwargs: forwarded unchanged to the parent constructor.

        Raises:
            AssertionError: if the probabilities are negative or do not sum to 1.
            ValueError: if an operator with positive probability is unsupported.
        """
        super().__init__(**kwargs)

        if binary_ops_probs is None :
            binary_ops_probs = {
                "&" : 0.25,
                "|" : 0.25,
                "^" : 0.5,
            }
        assert abs(sum(binary_ops_probs.values()) - 1.0) < 1E-8, "binary_ops_probs values should sum to 1"
        # The probabilities are later used as weights for random.choices, which
        # assumes non-negative weights; a negative entry can still sum to 1.
        assert all(prob >= 0 for prob in binary_ops_probs.values()), "binary_ops_probs values should be non-negative"
        # Fail at construction time instead of in the middle of generation.
        # Zero-probability entries can never be sampled, so they are tolerated.
        unsupported = [op for op, prob in binary_ops_probs.items() if prob > 0 and op not in self._SUPPORTED_OPERATORS]
        if unsupported :
            raise ValueError("Invalid operator(s) in binary_ops_probs: {}".format(unsupported))
        self.binary_ops_probs = binary_ops_probs

        self.rewards = {
            "wrong_format" : wrong_format,
            "invalid_solution" : invalid_solution,
            "correct_solution" : correct_solution,
            "wrong_solution" : wrong_solution,
        }

    def _generate(self) -> None :
        """Sample an expression tree together with an assignment that satisfies it.

        Reads ``N`` and ``M`` from ``self.parameter`` and writes ``tree`` and
        ``reference_answer`` (the assignment as space-separated 0/1 values).
        A fresh assignment and a fresh tree are sampled until the tree
        evaluates to 1 under that assignment.
        """
        assert "N" in self.parameter, "N is required in parameter"
        N = self.parameter["N"]
        assert N >= 2, "N should be greater than or equal to 2"

        assert "M" in self.parameter, "M is required in parameter"
        M = self.parameter["M"]
        assert M >= N, "M should be greater than or equal to N"

        binary_ops, binary_probs = zip(*self.binary_ops_probs.items())

        def build_tree(n, x) :
            """Return ``(tree, value)`` for a random tree with ``n`` leaves evaluated on ``x``."""
            if n == 1 :
                index = random.randint(0, N - 1)
                return index, x[index]
            # Split the leaves between the two children; both sides get at least one.
            left_n = random.randint(1, n - 1)
            right_n = n - left_n
            left_tree, left_value = build_tree(left_n, x)
            right_tree, right_value = build_tree(right_n, x)
            op = random.choices(binary_ops, weights = binary_probs, k = 1)[0]
            value = self._apply_operator(op, left_value, right_value)
            return (left_tree, op, right_tree), value

        while True :
            x = [random.randint(0, 1) for _ in range(N)]
            tree, value = build_tree(M, x)
            if value == 1 :
                self.parameter["reference_answer"] = " ".join(map(str, x))
                self.parameter["tree"] = tree
                break

    def build_expression(self, tree) :
        """Render ``tree`` as a fully parenthesized infix string such as ``(x[0] & x[1])``."""
        if isinstance(tree, int) :
            return "x[{}]".format(tree)
        left_tree, op, right_tree = tree
        return "({} {} {})".format(self.build_expression(left_tree), op, self.build_expression(right_tree))

    def _prompt_generate(self) -> str :
        """Fill ``prompt_template`` with the instance stored in ``self.parameter``."""
        N = self.parameter["N"]
        return self.prompt_template.format(
            N = N,
            N_minus_1 = N - 1,
            # [1 : -1] drops the outermost pair of parentheses; the root is
            # always an operator node because _generate enforces M >= N >= 2.
            expression = self.build_expression(self.parameter["tree"])[1 : -1],
            N_boolean = " ".join(str(i % 2) for i in range(N)),
        )

    def _process(self, answer : Optional[str]) -> Optional[List[int]] :
        """Parse a whitespace-separated list of integers.

        Returns:
            The list of parsed integers, or ``None`` when ``answer`` is
            ``None`` or contains a token that is not an integer.
        """
        if answer is None :
            return None # Invalid answer format
        try :
            return list(map(int, answer.strip().split()))
        except ValueError :
            return None # Invalid answer format

    def scorer(self, output : str) -> float :
        """Score ``output`` against the stored expression tree.

        ``output`` is first passed through ``self.processor`` (defined outside
        this class; presumably it extracts the answer and calls ``_process``
        -- confirm in the base class).

        Returns:
            ``rewards["wrong_format"]`` if nothing could be parsed,
            ``rewards["invalid_solution"]`` if the answer does not consist of
            exactly N values from {0, 1}, otherwise
            ``rewards["correct_solution"]`` or ``rewards["wrong_solution"]``
            depending on whether the expression evaluates to 1.
        """
        processed_result = self.processor(output)
        if processed_result is None :
            return self.rewards["wrong_format"]

        assert isinstance(processed_result, list), "processed_result should be a list"
        x = processed_result
        if len(x) != self.parameter["N"] :
            return self.rewards["invalid_solution"]
        if not all(xi in (0, 1) for xi in x) :
            return self.rewards["invalid_solution"]

        def compute(tree) :
            """Evaluate ``tree`` under the candidate assignment ``x``."""
            if isinstance(tree, int) :
                return x[tree]
            left_tree, op, right_tree = tree
            left_value = compute(left_tree)
            right_value = compute(right_tree)
            return self._apply_operator(op, left_value, right_value)

        if compute(self.parameter["tree"]) == 1 :
            return self.rewards["correct_solution"]
        return self.rewards["wrong_solution"]