ZhiyuanZeng's picture
Upload folder using huggingface_hub
3bf8430 verified
import random
from typing import Optional, List, Dict
from ...environment import VerifiableEnvironment
class Circuit_Environment(VerifiableEnvironment) :
    """
    Environment asking for a satisfying assignment of a random Boolean circuit.

    A random binary expression tree with M leaves over variables x[0..N-1] is
    sampled; internal nodes use `&`, `|` or `^` drawn from `binary_ops_probs`.
    The tree is resampled until the hidden assignment used to build it makes
    the expression evaluate to 1, so at least one solution always exists.
    """

    prompt_template = \
r"""There are {N} boolean (0/1) values x[0], x[1], ..., x[{N_minus_1}].
Given a Boolean expression (where `&` is bitwise AND, `|` is bitwise OR, and `^` is bitwise XOR): {expression}
Please find any solution x[0], x[1], ..., x[{N_minus_1}] that makes the expression evaluate to 1.
Output Format: Your final answer should be a single line containing x[0], x[1], ..., x[{N_minus_1}], separated by **spaces**.
Example: `{N_boolean}` (do **NOT** include quotes or backticks)."""

    def __init__(self,
                 binary_ops_probs : Optional[Dict[str, float]] = None,
                 wrong_format : float = -1.0, invalid_solution : float = -0.5, correct_solution : float = +1.0, wrong_solution : float = 0.0,
                 **kwargs) :
        """
        Initialize the Circuit_Environment instance.

        Args:
            binary_ops_probs: sampling probability of each binary operator
                ("&", "|", "^"); values must sum to 1. Defaults to
                {"&": 0.25, "|": 0.25, "^": 0.5}. The mapping is copied, so
                later changes made by the caller do not affect this instance.
            wrong_format: reward when the answer cannot be parsed.
            invalid_solution: reward when the answer has the wrong length or
                contains a value other than 0/1.
            correct_solution: reward when the assignment makes the expression 1.
            wrong_solution: reward when a well-formed assignment makes it 0.
            **kwargs: forwarded to VerifiableEnvironment.__init__.
        """
        super().__init__(**kwargs)
        if binary_ops_probs is None :
            binary_ops_probs = {
                "&" : 0.25,
                "|" : 0.25,
                "^" : 0.5,
            }
        assert abs(sum(binary_ops_probs.values()) - 1.0) < 1E-8, "binary_ops_probs values should sum to 1"
        # Copy so the distribution validated above cannot be mutated afterwards
        # through the caller's reference.
        self.binary_ops_probs = dict(binary_ops_probs)
        self.rewards = {
            "wrong_format" : wrong_format,
            "invalid_solution" : invalid_solution,
            "correct_solution" : correct_solution,
            "wrong_solution" : wrong_solution,
        }

    @staticmethod
    def _apply_op(op : str, left_value : int, right_value : int) -> int :
        """
        Apply the binary operator `op` ("&", "|" or "^") to two 0/1 values.

        Shared by tree generation and scoring so both evaluate identically.

        Raises:
            ValueError: if `op` is not one of the supported operators.
        """
        if op == "&" :
            return left_value & right_value
        if op == "|" :
            return left_value | right_value
        if op == "^" :
            return left_value ^ right_value
        raise ValueError("Invalid operator")

    def _generate(self) -> None :
        """
        Sample a satisfiable circuit and store it in self.parameter.

        Reads self.parameter["N"] (number of variables, must be >= 2) and
        self.parameter["M"] (number of leaves, must be >= N). Writes
        self.parameter["tree"] (nested (left, op, right) tuples whose leaves are
        int variable indices) and self.parameter["reference_answer"] (the
        space-separated assignment under which the stored tree evaluates to 1).
        """
        assert "N" in self.parameter, "N is required in parameter"
        N = self.parameter["N"]
        assert N >= 2, "N should be greater than or equal to 2"
        assert "M" in self.parameter, "M is required in parameter"
        M = self.parameter["M"]
        assert M >= N, "M should be greater than or equal to N"

        binary_ops, binary_probs = zip(*self.binary_ops_probs.items())

        # Rejection sampling: draw a hidden assignment and a random tree, and
        # keep the pair only if the tree evaluates to 1 under that assignment.
        while True :
            x = [random.randint(0, 1) for _ in range(N)]

            def build_tree(n) :
                # Return (tree, value) for a random tree with exactly n leaves,
                # where value is the tree evaluated under x. Leaf indices are
                # drawn independently, so not every variable need appear.
                if n == 1 :
                    index = random.randint(0, N - 1)
                    return index, x[index]
                left_n = random.randint(1, n - 1)
                left_tree, left_value = build_tree(left_n)
                right_tree, right_value = build_tree(n - left_n)
                # The operator is drawn after both subtrees, as before, so the
                # sequence of random draws (and seeded output) is unchanged.
                op = random.choices(binary_ops, weights = binary_probs, k = 1)[0]
                return (left_tree, op, right_tree), self._apply_op(op, left_value, right_value)

            tree, value = build_tree(M)
            if value == 1 :
                self.parameter["reference_answer"] = " ".join(map(str, x))
                self.parameter["tree"] = tree
                break

    def build_expression(self, tree) -> str :
        """
        Render `tree` as a fully parenthesized infix expression.

        An int leaf i becomes "x[i]"; an internal node (left, op, right)
        becomes "(<left> <op> <right>)".
        """
        if isinstance(tree, int) :
            return "x[{}]".format(tree)
        left_tree, op, right_tree = tree
        return "({} {} {})".format(self.build_expression(left_tree), op, self.build_expression(right_tree))

    def _prompt_generate(self) -> str :
        """Fill `prompt_template` with the current problem's N and expression."""
        N = self.parameter["N"]
        return self.prompt_template.format(
            N = N,
            N_minus_1 = N - 1,
            # M >= N >= 2 guarantees the root is an operator node, so the
            # outermost pair of parentheses can always be stripped.
            expression = self.build_expression(self.parameter["tree"])[1 : -1],
            N_boolean = " ".join(str(i % 2) for i in range(N)),
        )

    def _process(self, answer : Optional[str]) -> Optional[List] :
        """
        Parse a whitespace-separated list of integers.

        Returns the list of ints, or None when `answer` is None or any token is
        not an integer. Length and 0/1 range checks are left to `scorer`.
        """
        if answer is None :
            return None # Invalid answer format
        try :
            # str.split() with no argument already ignores surrounding whitespace.
            return list(map(int, answer.split()))
        except ValueError :
            return None # Invalid answer format

    def scorer(self, output : str) -> float :
        """
        Score a model output against the stored circuit.

        `self.processor` is provided outside this file (presumably it wraps
        `_process` — TODO confirm). Returns one of the configured rewards:
        wrong_format if parsing fails, invalid_solution for a wrong length or
        non-0/1 values, otherwise correct_solution / wrong_solution depending
        on whether the assignment makes the expression evaluate to 1.
        """
        processed_result = self.processor(output)
        if processed_result is None :
            return self.rewards["wrong_format"]

        assert isinstance(processed_result, list), "processed_result should be a list"
        x = processed_result
        if len(x) != self.parameter["N"] :
            return self.rewards["invalid_solution"]
        if not all(xi in (0, 1) for xi in x) :
            return self.rewards["invalid_solution"]

        def compute(tree) :
            # Evaluate the stored tree under the candidate assignment x.
            if isinstance(tree, int) :
                return x[tree]
            left_tree, op, right_tree = tree
            return self._apply_op(op, compute(left_tree), compute(right_tree))

        if compute(self.parameter["tree"]) == 1 :
            return self.rewards["correct_solution"]
        return self.rewards["wrong_solution"]