# RLVE_Gym / server / RLVE_Gym_environment.py
# burtenshaw
# return to non package install
# 9567311
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
RLVE-Gym Environment Implementation.
"""
from typing import Optional, Tuple
import random
from openenv_core.env_server.interfaces import Environment
from models import RlveGymState, RlveGymAction, RlveGymObservation
from server.Gym.environment import VerifiableEnvironment
from server.Gym.parameter_controller import ParameterController
from server.Gym.environments import identifier2environment
from server.Gym.parameter_controllers import identifier2controller
class RlveGymEnvironment(Environment):
    """
    Wrap any verifiable environment from RLVE-Gym behind the OpenEnv ``Environment`` API.

    ``reset`` builds a fresh problem for ``environment_identifier`` at the configured
    ``difficulty`` from the current seed. ``step`` verifies a model output against
    the problem produced by the last successful ``reset``.

    Failures are reported through the returned observation (``success=False`` plus a
    ``message``) rather than raised.
    """

    def __init__(
        self,
        environment_identifier: str = "Multiplication",
        difficulty: int = 0,
        answer_markers: Optional[Tuple[str, str]] = None,
        initial_seed: int = 0,
    ):
        """Initialize the RLVE_Gym environment.

        Args:
            environment_identifier: Key into ``identifier2environment`` and
                ``identifier2controller`` selecting which RLVE-Gym task to serve.
            difficulty: Number of times the task's parameter controller is advanced
                (``controller.update()``) before a generation parameter is sampled.
            answer_markers: Forwarded unchanged to the task constructor.
            initial_seed: Seed used by the first ``reset``.
        """
        self._state = RlveGymState(
            seed=initial_seed,
            problem_input=None,
            num_samples=0,
            sum_accuracy=0,
        )
        self.environment_identifier = environment_identifier
        self.difficulty = difficulty
        self.answer_markers = answer_markers
        # The active problem; None until a reset succeeds, and cleared again
        # whenever a reset fails.
        self.problem = None

    def _validate_configuration(self) -> Optional[str]:
        """Return an error message if the identifier, difficulty or seed is invalid, else None."""
        if (self.environment_identifier not in identifier2environment) or (
            self.environment_identifier not in identifier2controller
        ):
            return "Invalid environment identifier."
        if not (isinstance(self.difficulty, int) and self.difficulty >= 0):
            return "Difficulty should be a non-negative integer."
        if not (isinstance(self._state.seed, int) and self._state.seed >= 0):
            return "Seed should be a non-negative integer."
        return None

    def _fail_reset(self, message: str) -> RlveGymObservation:
        """Discard any active problem and build the observation for a failed ``reset``."""
        # Without this, a later step() would verify against a problem left over
        # from an earlier (possibly different) configuration.
        self.problem = None
        self._state.problem_input = None
        return RlveGymObservation(
            problem_input=None,
            verifier_result=None,
            success=False,
            message=message,
            reward=None,
        )

    def reset(self) -> RlveGymObservation:
        """
        Reset the environment by generating a new problem.

        Once validation and initialization succeed, the seed is advanced by one
        and the sample statistics are zeroed, whether or not generation succeeds.
        On any failure the previously active problem is discarded.

        Returns:
            problem_input: The generated problem input string (or None if generation failed)
            verifier_result: None
            success: Boolean indicating if the reset was successful
            message: Message indicating the result of the reset
            reward: None
        """
        error = self._validate_configuration()
        if error is not None:
            return self._fail_reset(error)

        try:
            problem: VerifiableEnvironment = identifier2environment[self.environment_identifier](
                answer_markers=self.answer_markers
            )
        except Exception as e:
            return self._fail_reset(f"Failed to initialize environment: {e}")

        try:
            controller: ParameterController = identifier2controller[self.environment_identifier]()
            for _ in range(self.difficulty):
                controller.update()
        except Exception as e:
            return self._fail_reset(f"Failed to initialize parameter controller: {e}")

        seed = self._state.seed
        # Advance the seed for every generation attempt so repeated resets
        # produce distinct problems, and start fresh statistics.
        self._state.seed += 1
        self._state.num_samples = self._state.sum_accuracy = 0

        try:
            # NOTE(review): this reseeds the process-wide `random` module. Kept as-is
            # because problem.generator may rely on the global RNG being seeded;
            # confirm before switching to a private random.Random(seed).
            random.seed(seed)
            parameter = random.choice(controller.get_parameter_list())
            generated = problem.generator(seed=seed, parameter=parameter)
            problem_input = problem.prompt_generator() if generated else None
        except Exception as e:
            return self._fail_reset(
                f"Problem generation failed with error: {e}. "
                "Please try decreasing difficulty or changing seed."
            )

        if not generated:
            return self._fail_reset(
                "Problem generation failed. Please try decreasing difficulty or changing seed."
            )

        self._state.problem_input = problem_input
        self.problem = problem
        return RlveGymObservation(
            problem_input=self._state.problem_input,
            verifier_result=None,
            success=True,
            message="Problem generated successfully.",
            reward=None,
        )

    def step(self, action: RlveGymAction) -> RlveGymObservation:  # type: ignore[override]
        """
        Execute a step in the environment by verifying the model output.

        Args:
            action: RlveGymAction containing the output to verify

        Returns:
            problem_input: The problem input string from the current state
            verifier_result: Result of the verification containing accuracy and other metrics
            success: Boolean indicating if the step was successful
            message: Message indicating the result of the step
            reward: The verifier's "reward" value on success, else None
        """
        if self.problem is None:
            return RlveGymObservation(
                problem_input=None,
                verifier_result=None,
                success=False,
                message="Problem not ready. Please reset the environment.",
                reward=None,
            )

        try:
            verifier_result = self.problem.verifier(action.output)
            # Extract both fields before touching state so a malformed result
            # cannot leave num_samples and sum_accuracy out of sync.
            accuracy = verifier_result["accuracy"]
            reward = verifier_result["reward"]
        except Exception as e:
            return RlveGymObservation(
                problem_input=self._state.problem_input,
                verifier_result=None,
                success=False,
                message=f"Verification failed with error: {e}",
                reward=None,
            )

        self._state.num_samples += 1
        self._state.sum_accuracy += accuracy
        return RlveGymObservation(
            problem_input=self._state.problem_input,
            verifier_result=verifier_result,
            success=True,
            message="Verification completed.",
            reward=reward,
        )

    @property
    def state(self) -> RlveGymState:
        """
        Get the current environment state.

        Returns:
            seed: The seed the next reset will use for problem generation
            problem_input: The generated problem input string (or None if generation failed)
            num_samples: Number of successful verifications since the last generation attempt
            sum_accuracy: Sum of accuracies from those verifications
        """
        return self._state