Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| RLVE-Gym Environment Implementation. | |
| """ | |
| from typing import Optional, Tuple | |
| import random | |
| from openenv_core.env_server.interfaces import Environment | |
| from models import RlveGymState, RlveGymAction, RlveGymObservation | |
| from server.Gym.environment import VerifiableEnvironment | |
| from server.Gym.parameter_controller import ParameterController | |
| from server.Gym.environments import identifier2environment | |
| from server.Gym.parameter_controllers import identifier2controller | |
class RlveGymEnvironment(Environment):
    """
    Wrap any verifiable environment from RLVE-Gym behind the OpenEnv ``Environment`` API.

    Each :meth:`reset` builds a fresh problem instance for
    ``environment_identifier`` at ``difficulty`` from the current seed and then
    advances the seed, so consecutive resets produce different problems.
    :meth:`step` scores a model output against the current problem and
    accumulates sample/accuracy statistics in the state.
    """

    def __init__(
        self,
        environment_identifier: str = "Multiplication",
        difficulty: int = 0,
        answer_markers: Optional[Tuple[str, str]] = None,
        initial_seed: int = 0,
    ):
        """
        Initialize the RLVE_Gym environment.

        Args:
            environment_identifier: Key into ``identifier2environment`` and
                ``identifier2controller`` selecting the problem family.
            difficulty: Number of ``controller.update()`` calls applied to the
                parameter controller before sampling problem parameters.
            answer_markers: Forwarded unchanged to the problem constructor.
            initial_seed: Seed used for the first generated problem.

        The arguments are not validated here; :meth:`reset` validates them and
        reports problems through a failed observation.
        """
        super().__init__()
        self._state = RlveGymState(
            seed=initial_seed,
            problem_input=None,
            num_samples=0,
            sum_accuracy=0,
        )
        self.environment_identifier = environment_identifier
        self.difficulty = difficulty
        self.answer_markers = answer_markers
        # Problem currently being verified; None until a reset succeeds.
        self.problem: Optional[VerifiableEnvironment] = None

    def reset(self) -> RlveGymObservation:
        """
        Reset the environment by generating a new problem.

        Any failure discards the previous problem, so a subsequent :meth:`step`
        reports "Problem not ready" instead of verifying against a stale one.

        Returns:
            RlveGymObservation with
            problem_input: The generated problem input string (or None if generation failed)
            verifier_result: None
            success: Boolean indicating if the reset was successful
            message: Message indicating the result of the reset
            reward: None
        """
        error = self._validate_configuration()
        if error is not None:
            return self._reset_failure(error)

        environment_class = identifier2environment[self.environment_identifier]
        try:
            problem: VerifiableEnvironment = environment_class(answer_markers=self.answer_markers)
        except Exception as e:
            return self._reset_failure(f"Failed to initialize environment: {e}")

        seed = self._state.seed
        # The seed is consumed by this generation attempt whether it succeeds,
        # fails, or raises, so the next reset always tries a fresh seed.
        self._state.seed += 1
        try:
            generated = self._generate(problem, seed)
            problem_input = problem.prompt_generator() if generated else None
        except Exception as e:
            return self._reset_failure(f"Problem generation failed with error: {e}")
        if not generated:
            return self._reset_failure(
                "Problem generation failed. Please try decreasing difficulty or changing seed."
            )

        self.problem = problem
        self._state.problem_input = problem_input
        self._state.num_samples = self._state.sum_accuracy = 0
        return RlveGymObservation(
            problem_input=problem_input,
            verifier_result=None,
            success=True,
            message="Problem generated successfully.",
            reward=None,
        )

    def step(self, action: RlveGymAction) -> RlveGymObservation:  # type: ignore[override]
        """
        Execute a step in the environment by verifying the model output.

        Args:
            action: RlveGymAction containing the output to verify

        Returns:
            RlveGymObservation with
            problem_input: The problem input string from the current state
            verifier_result: Result of the verification containing accuracy and other metrics
            success: Boolean indicating if the step was successful
            message: Message indicating the result of the step
            reward: The verifier's ``"reward"`` entry on success, else None

        On success ``num_samples`` is incremented and the verifier's
        ``"accuracy"`` entry is added to ``sum_accuracy``; failed steps leave
        both counters untouched.
        """
        if self.problem is None:
            return RlveGymObservation(
                problem_input=None,
                verifier_result=None,
                success=False,
                message="Problem not ready. Please reset the environment.",
                reward=None,
            )
        try:
            verifier_result = self.problem.verifier(action.output)
            # Read the required fields inside the try so a malformed verifier
            # result is reported as a failed step rather than raising after
            # the counters have been partially updated.
            accuracy = verifier_result["accuracy"]
            reward = verifier_result["reward"]
        except Exception as e:
            return RlveGymObservation(
                problem_input=self._state.problem_input,
                verifier_result=None,
                success=False,
                message=f"Verification failed with error: {e}",
                reward=None,
            )
        self._state.num_samples += 1
        self._state.sum_accuracy += accuracy
        return RlveGymObservation(
            problem_input=self._state.problem_input,
            verifier_result=verifier_result,
            success=True,
            message="Verification completed.",
            reward=reward,
        )

    def state(self) -> RlveGymState:
        """
        Get the current environment state.

        Returns the live state object (not a copy) with
            seed: The seed the next reset will use for problem generation
            problem_input: The generated problem input string (or None if generation failed)
            num_samples: Number of successful verifications for the current problem
            sum_accuracy: Sum of accuracies from those verifications
        """
        # NOTE(review): the OpenEnv ``Environment`` base class may declare
        # ``state`` as a property (read as ``env.state``); if so this should be
        # decorated with ``@property``. Kept as a plain method to avoid breaking
        # callers that invoke ``env.state()`` — confirm against openenv_core.
        return self._state

    @staticmethod
    def _is_non_negative_int(value) -> bool:
        """Return True for non-negative ints; ``bool`` is rejected despite subclassing ``int``."""
        return isinstance(value, int) and not isinstance(value, bool) and value >= 0

    def _validate_configuration(self) -> Optional[str]:
        """Return an error message if the current configuration is unusable, else None."""
        if (self.environment_identifier not in identifier2environment) or (
            self.environment_identifier not in identifier2controller
        ):
            return "Invalid environment identifier."
        if not self._is_non_negative_int(self.difficulty):
            return "Difficulty should be a non-negative integer."
        if not self._is_non_negative_int(self._state.seed):
            return "Seed should be a non-negative integer."
        return None

    def _generate(self, problem: VerifiableEnvironment, seed: int) -> bool:
        """
        Sample controller parameters and run ``problem.generator`` with ``seed``.

        Returns whether ``problem.generator`` reported success (its truthiness).
        May raise whatever the controller, the sampling, or the generator raises.
        """
        controller: ParameterController = identifier2controller[self.environment_identifier]()
        # Difficulty is expressed as the number of controller updates applied.
        for _ in range(self.difficulty):
            controller.update()
        # NOTE(review): this reseeds the process-global ``random`` module. It is
        # kept because problem generators may depend on the global RNG state;
        # confirm before switching to a private ``random.Random(seed)``.
        random.seed(seed)
        parameter = random.choice(controller.get_parameter_list())
        return bool(problem.generator(seed=seed, parameter=parameter))

    def _reset_failure(self, message: str) -> RlveGymObservation:
        """Discard the current problem and its counters, then build a failed-reset observation."""
        self.problem = None
        self._state.problem_input = None
        self._state.num_samples = self._state.sum_accuracy = 0
        return RlveGymObservation(
            problem_input=None,
            verifier_result=None,
            success=False,
            message=message,
            reward=None,
        )