File size: 5,511 Bytes
3bf8430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import random
from typing import Optional, List, Dict
from ...environment import VerifiableEnvironment


class Circuit_Environment(VerifiableEnvironment) :
    """
    Boolean-circuit satisfiability task.

    A random expression tree over N boolean variables x[0..N-1] is generated
    with M variable occurrences (leaves) joined by the binary operators
    `&`, `|` and `^`. The task is to supply any assignment of the variables
    under which the expression evaluates to 1.

    Tree representation (stored in `self.parameter["tree"]`):
      * a leaf is a plain `int` -- the index of the variable it refers to;
      * an inner node is a tuple `(left_tree, op, right_tree)` where `op` is
        one of the operator strings above.
    """

    prompt_template = \
r"""There are {N} boolean (0/1) values x[0], x[1], ..., x[{N_minus_1}].

Given a Boolean expression (where `&` is bitwise AND, `|` is bitwise OR, and `^` is bitwise XOR): {expression}
Please find any solution x[0], x[1], ..., x[{N_minus_1}] that makes the expression evaluate to 1.

Output Format: Your final answer should be a single line containing x[0], x[1], ..., x[{N_minus_1}], separated by **spaces**.
Example: `{N_boolean}` (do **NOT** include quotes or backticks)."""

    def __init__(self,
                 binary_ops_probs : Optional[Dict[str, float]] = None,
                 wrong_format : float = -1.0, invalid_solution : float = -0.5, correct_solution : float = +1.0, wrong_solution : float = 0.0,
                 **kwargs) :
        """
        Initialize the Circuit_Environment instance.

        Args:
            binary_ops_probs: Mapping from operator string to the probability
                of picking it for an inner node. The values must sum to 1.
                Defaults to `&`: 0.25, `|`: 0.25, `^`: 0.5 when omitted.
            wrong_format: Reward when the answer cannot be parsed.
            invalid_solution: Reward when the answer parses but has the wrong
                length or contains values other than 0/1.
            correct_solution: Reward when the expression evaluates to 1.
            wrong_solution: Reward when the expression evaluates to 0.
            **kwargs: Forwarded unchanged to the base-class constructor.

        Raises:
            AssertionError: If the probabilities do not sum to 1.
        """
        super().__init__(**kwargs)

        if binary_ops_probs is None :
            binary_ops_probs = {
                "&" : 0.25,
                "|" : 0.25,
                "^" : 0.5,
            }
        assert abs(sum(binary_ops_probs.values()) - 1.0) < 1E-8, "binary_ops_probs values should sum to 1"
        self.binary_ops_probs = binary_ops_probs

        self.rewards = {
            "wrong_format" : wrong_format,
            "invalid_solution" : invalid_solution,
            "correct_solution" : correct_solution,
            "wrong_solution" : wrong_solution,
        }

    @staticmethod
    def _apply_binary_op(op : str, left_value : int, right_value : int) -> int :
        """
        Apply the binary operator `op` to two 0/1 values.

        Shared by instance generation and scoring so that both interpret the
        operator strings in exactly the same way.

        Raises:
            ValueError: If `op` is not one of `&`, `|`, `^`.
        """
        if op == "&" :
            return left_value & right_value
        if op == "|" :
            return left_value | right_value
        if op == "^" :
            return left_value ^ right_value
        raise ValueError("Invalid operator")

    def _build_random_tree(self, n : int, x : List[int], binary_ops, binary_probs) :
        """
        Build a random expression tree with exactly `n` leaves.

        Args:
            n: Number of leaves (variable occurrences) the subtree must have.
            x: Current candidate assignment; used to evaluate the tree while
                it is being built.
            binary_ops: Sequence of operator strings to sample from.
            binary_probs: Sampling weights aligned with `binary_ops`.

        Returns:
            A pair `(tree, value)` where `value` is the result of evaluating
            `tree` under the assignment `x`.
        """
        if n == 1 :
            # Leaf: reference a uniformly chosen variable.
            index = random.randint(0, len(x) - 1)
            return index, x[index]

        # Split the leaves between the two children; each side gets at least one.
        left_n = random.randint(1, n - 1)
        left_tree, left_value = self._build_random_tree(left_n, x, binary_ops, binary_probs)
        right_tree, right_value = self._build_random_tree(n - left_n, x, binary_ops, binary_probs)

        op = random.choices(binary_ops, weights = binary_probs, k = 1)[0]
        return (left_tree, op, right_tree), self._apply_binary_op(op, left_value, right_value)

    def _evaluate_tree(self, tree, x : List[int]) -> int :
        """Evaluate an expression tree under the assignment `x`."""
        if isinstance(tree, int) :
            return x[tree]
        left_tree, op, right_tree = tree
        return self._apply_binary_op(
            op,
            self._evaluate_tree(left_tree, x),
            self._evaluate_tree(right_tree, x),
        )

    def _generate(self) -> None :
        """
        Generate a satisfiable instance and store it in `self.parameter`.

        Reads `N` (number of variables, >= 2) and `M` (number of leaves,
        >= N) from `self.parameter`. Writes `reference_answer` (a
        space-separated satisfying assignment) and `tree`.

        Raises:
            AssertionError: If `N` or `M` is missing or out of range.
        """
        assert "N" in self.parameter, "N is required in parameter"
        N = self.parameter["N"]
        assert N >= 2, "N should be greater than or equal to 2"

        assert "M" in self.parameter, "M is required in parameter"
        M = self.parameter["M"]
        assert M >= N, "M should be greater than or equal to N"

        binary_ops, binary_probs = zip(*self.binary_ops_probs.items())

        # Rejection sampling: draw an assignment and a tree together and keep
        # the first pair for which the tree evaluates to 1, which guarantees
        # that the stored instance has at least one solution.
        while True :
            x = [random.randint(0, 1) for _ in range(N)]
            tree, value = self._build_random_tree(M, x, binary_ops, binary_probs)

            if value == 1 :
                self.parameter["reference_answer"] = " ".join(map(str, x))
                self.parameter["tree"] = tree
                break

    def build_expression(self, tree) :
        """
        Render an expression tree as text.

        Leaves become `x[i]`; every inner node becomes a fully parenthesized
        `(left op right)`.
        """
        if isinstance(tree, int) :
            return "x[{}]".format(tree)
        left_tree, op, right_tree = tree
        return "({} {} {})".format(self.build_expression(left_tree), op, self.build_expression(right_tree))

    def _prompt_generate(self) -> str :
        """Fill the prompt template from the stored `N` and `tree`."""
        N = self.parameter["N"]
        return self.prompt_template.format(
            N = N,
            N_minus_1 = N - 1,
            # Drop the redundant outermost parentheses. A tree produced by
            # `_generate` always has an operator at the root (M >= N >= 2),
            # so the rendered text always starts with "(" and ends with ")".
            expression = self.build_expression(self.parameter["tree"])[1 : -1],
            # Example answer of the right length: 0 1 0 1 ...
            N_boolean = " ".join(str(i % 2) for i in range(N)),
        )

    def _process(self, answer : Optional[str]) -> Optional[List] :
        """
        Parse a whitespace-separated answer into a list of integers.

        Returns:
            The parsed list, or `None` when `answer` is `None` or contains a
            token that is not an integer.
        """
        if answer is None :
            return None # Invalid answer format

        try :
            return list(map(int, answer.strip().split()))
        except ValueError :
            return None # Invalid answer format

    def scorer(self, output : str) -> float :
        """
        Score a raw model output against the stored expression.

        Returns:
            `rewards["wrong_format"]` if `self.processor` yields `None`;
            `rewards["invalid_solution"]` if the answer does not have exactly
            N entries or contains a value other than 0/1;
            `rewards["correct_solution"]` if the expression evaluates to 1;
            `rewards["wrong_solution"]` otherwise.
        """
        # NOTE(review): `processor` is provided by the base class; presumably
        # it extracts the answer and delegates to `_process` -- confirm there.
        processed_result = self.processor(output)
        if processed_result is None :
            return self.rewards["wrong_format"]
        assert isinstance(processed_result, list), "processed_result should be a list"

        x = processed_result
        if len(x) != self.parameter["N"] :
            return self.rewards["invalid_solution"]
        if not all(xi in (0, 1) for xi in x) :
            return self.rewards["invalid_solution"]

        if self._evaluate_tree(self.parameter["tree"], x) == 1 :
            return self.rewards["correct_solution"]
        return self.rewards["wrong_solution"]