diff --git a/debug_gym/agents/base_agent.py b/debug_gym/agents/base_agent.py index 4af8b0b0..577894cf 100644 --- a/debug_gym/agents/base_agent.py +++ b/debug_gym/agents/base_agent.py @@ -2,14 +2,14 @@ import os import uuid from dataclasses import MISSING, asdict, dataclass, field, fields -from typing import Any, Dict +from typing import Any, Dict, List from jinja2 import Environment, Template from debug_gym.agents.history_tracker import HistoryTracker from debug_gym.gym.envs.env import EnvInfo, RepoEnv from debug_gym.gym.utils import filter_non_utf8 -from debug_gym.llms.base import LLM +from debug_gym.llms.base import LLM, LLMResponse from debug_gym.llms.utils import trim from debug_gym.logger import DebugGymLogger @@ -27,8 +27,8 @@ def register_agent(cls): @dataclass class AgentArgs: - system_prompt: str | None = None - instance_prompt: str | None = None + system_prompt: str = "" + instance_prompt: str = "Instructions: {{ info.instructions }}" max_steps: int = 100 max_history_token_cutoff: int = -1 max_history_steps_cutoff: int = -1 @@ -83,8 +83,6 @@ def to_dict(self) -> Dict[str, Any]: class BaseAgent: name: str = None args_class = AgentArgs - system_prompt: str = "" - instance_prompt: str = "Instructions: {{ info.instructions }}" def __init__( self, @@ -95,14 +93,10 @@ def __init__( self.args = self.args_class.make(agent_args or {}) self.history = HistoryTracker() self.logger = logger or DebugGymLogger("debug-gym") - self.llm = None + self.llm = llm self.env = None - - # Override prompts if provided in args - if self.args.system_prompt is not None: - self.system_prompt = str(self.args.system_prompt) - if self.args.instance_prompt is not None: - self.instance_prompt = str(self.args.instance_prompt) + self.system_prompt = str(self.args.system_prompt) + self.instance_prompt = str(self.args.instance_prompt) @staticmethod def to_pretty_json(value): @@ -238,17 +232,89 @@ def should_stop(self, step: int, info: EnvInfo): reason = "max_steps reached" return should_stop, reason - def run(self, env: RepoEnv, llm: LLM, debug=False): - self.env = env - self.llm = llm + def init(self, info: EnvInfo) -> None: + """Initialize the agent with environment + + Args: + info: The environment info to interact with. + """ + self.history.init( + self.build_system_prompt(info), self.build_instance_prompt(info), info + ) + + self.logger.info( + "Available tools (in LLM's tool calling format):\n" + f"{json.dumps(self.llm.define_tools(info.tools), indent=4)}\n" + ) + + def step(self, info: EnvInfo) -> LLMResponse | List[LLMResponse]: + """Execute a single agent step (LLM decision only). + + Args: + info: Current environment info. + + Returns: + LLMResponse with the agent's decision. + """ + messages = self.build_prompt(info) + return self.llm(messages, info.tools) + + def execute_action(self, llm_response: LLMResponse | List[LLMResponse]) -> EnvInfo: + next_info = self.env.step( + llm_response.tool, + llm_response.response, + llm_response.reasoning_response, + ) + self.history.step(next_info, llm_response) + return next_info + + def build_trajectory(self) -> Dict[str, Any]: + """Return the trajectory as a JSON-serializable dict without writing it.""" + tools = [f"{tool.name}({tool.arguments})" for tool in self.env.tools] + json_output = { + "problem": self.env.task_name, + "config": self.args.to_dict(), + "tools": self.llm.define_tools(self.env.tools) if self.llm else tools, + "uuid": self.args.uuid, + "success": self.env.resolved, + "log": [], + "agent_type": self.__class__.__name__, + "logger": str(self.logger.log_file), + } + for step_id in range(len(self.history)): + step_json = self.history.json(step_id) + json_output["log"].append(step_json) + return json_output + + def run( + self, + env: RepoEnv, + debug: bool = False, + reset_env: bool = True, + ) -> Dict[str, Any]: + """Run the agent loop until termination or max steps. + + Args: + env: The environment to interact with. + debug: Whether to drop into debugger after each LLM call. + reset_env: Whether to reset the environment (default True). + + Returns: + The trajectory as a JSON-serializable dict. + """ info = None step = 0 + # assign the env + self.env = env + try: - info = self.env.reset() - self.history.init( - self.build_system_prompt(info), self.build_instance_prompt(info), info - ) + if reset_env: + info = env.reset() + else: + info = env.info + + self.init(info) if info.resolved: self.logger.report_progress( @@ -259,12 +325,7 @@ def run(self, env: RepoEnv, llm: LLM, debug=False): max_score=info.max_score, status="resolved", ) - return self._build_trajectory() - - self.logger.info( - "Available tools (in LLM's tool calling format):\n" - f"{json.dumps(self.llm.define_tools(info.tools), indent=4)}\n" - ) + return self.build_trajectory() highscore = info.score should_stop = False @@ -273,18 +334,12 @@ def run(self, env: RepoEnv, llm: LLM, debug=False): while not should_stop: self.logger.info(f"\n{'='*20} STEP {step} {'='*20}\n") - messages = self.build_prompt(info) - llm_response = self.llm(messages, info.tools) + agent_response = self.step(info) + info = self.execute_action(agent_response) if debug: breakpoint() - info = self.env.step( - llm_response.tool, - llm_response.response, - llm_response.reasoning_response, - ) - self.history.step(info, llm_response) should_stop, reason = self.should_stop(step + 1, info) status = ( "resolved" @@ -299,7 +354,6 @@ def run(self, env: RepoEnv, llm: LLM, debug=False): self.logger.info(msg) step += 1 - # keep progress bar running until max_steps is reached self.logger.report_progress( problem_id=env.task_name, step=step, @@ -308,9 +362,8 @@ def run(self, env: RepoEnv, llm: LLM, debug=False): max_score=info.max_score, status=status, ) - return self._build_trajectory() + return self.build_trajectory() except Exception as e: - # report any error that happens during the run self.logger.report_progress( problem_id=env.task_name, step=step, @@ -321,24 +374,6 @@ def run(self, env: RepoEnv, llm: LLM, debug=False): ) raise e - def _build_trajectory(self) -> Dict[str, Any]: - """Return the trajectory as a JSON-serializable dict without writing it.""" - tools = [f"{tool.name}({tool.arguments})" for tool in self.env.tools] - json_output = { - "problem": self.env.task_name, - "config": self.args.to_dict(), - "tools": self.llm.define_tools(self.env.tools) if self.llm else tools, - "uuid": self.args.uuid, - "success": self.env.resolved, - "log": [], - "agent_type": self.__class__.__name__, - "logger": str(self.logger.log_file), - } - for step_id in range(len(self.history)): - step_json = self.history.json(step_id) - json_output["log"].append(step_json) - return json_output - def create_agent(config: Dict[str, Any], **kwargs) -> BaseAgent: """Create an agent from the config dictionary.""" diff --git a/debug_gym/agents/froggy_agent.py b/debug_gym/agents/froggy_agent.py index 3e531524..69d81f5f 100644 --- a/debug_gym/agents/froggy_agent.py +++ b/debug_gym/agents/froggy_agent.py @@ -12,13 +12,13 @@ @dataclass class FroggyAgentArgs(AgentArgs): show_current_breakpoints: bool = False + system_prompt: str = "{{ agent._default_system_prompt(info) }}" @register_agent class FroggyAgent(BaseAgent): name: str = "froggy_agent" args_class = FroggyAgentArgs - system_prompt: str = "{{ agent._default_system_prompt(info) }}" def shortcut_features(self): features = [] diff --git a/debug_gym/agents/solution_agent.py b/debug_gym/agents/solution_agent.py index 66f1d11e..87f97317 100644 --- a/debug_gym/agents/solution_agent.py +++ b/debug_gym/agents/solution_agent.py @@ -1,88 +1,66 @@ +from typing import Any, Dict + from debug_gym.agents.base_agent import BaseAgent, register_agent +from debug_gym.gym.envs.env import EnvInfo, RepoEnv from debug_gym.gym.tools.tool import ToolCall +from debug_gym.llms.base import LLM, LLMResponse @register_agent class AgentSolution(BaseAgent): + """Agent that applies the gold patch and submits - used for testing environments.""" + name: str = "solution_agent" - def _report_progress(self, task_name, info, status): - self.logger.report_progress( - problem_id=task_name, - step=1, - total_steps=1, - score=getattr(info, "score", 0), - max_score=getattr(info, "max_score", 0), - status=status, - ) + def __init__( + self, + llm: LLM | None = None, + **kwargs, + ): + super().__init__(llm=llm, **kwargs) def _env_implements_apply_gold_patch(self): """Fail early if the environment does not implement apply_gold_patch.""" return hasattr(self.env, "apply_gold_patch") - def run(self, env, llm=None, debug=False): - self.env = env - info = None - try: - if not self._env_implements_apply_gold_patch(): - raise NotImplementedError( - f"The environment {type(self.env)} is not compatible with SolutionAgent." - " Check the README.md to see which environments are compatible." - ) - - info = self.env.reset() - - if info.resolved is True: - self._report_progress(env.task_name, info, "resolved") - return True - - self.logger.info(f"Score: {info.score}/{info.max_score or '-'}") - - if env.has_tool("pdb"): - # Make a simple pdb call to make sure it is working. - action = ToolCall( - name="pdb", id="pdb", arguments={"command": "help help"} - ) - pdb_help_info = self.env.step(action, None, None) - assert "h(elp)" in pdb_help_info.step_observation.observation, ( - "PDB command did not return expected help message.\n" - f"{pdb_help_info.step_observation.observation}" - ) - - # Send a pdb continue command, and check the output matches the one from env.reset. - action = ToolCall( - name="pdb", id="pdb", arguments={"command": "continue"} - ) - pdb_continue_info = self.env.step(action, None, None) - - pdb_observation = pdb_continue_info.step_observation.observation - expected_messages = [ - "Reached the end of the program. Restarting the debugging session.", - "Uncaught exception. Entering post mortem debugging", - ] - reset_observation = info.step_observation.observation - if reset_observation.splitlines(): - expected_messages.append(reset_observation.splitlines()[-1]) - - assert any( - msg in pdb_observation for msg in expected_messages - ), f"PDB command did not return expected continue message.\n{pdb_observation}" - - self.env.apply_gold_patch() - - if debug: - breakpoint() - - action = ToolCall(name="submit", id="submit", arguments={}) - info = self.env.step(action, None, None) + def _run_pdb_sanity_checks(self, info: EnvInfo): + """Run PDB sanity checks if PDB tool is available.""" + if not self.env.has_tool("pdb"): + return + + # Make a simple pdb call to make sure it is working. + action = ToolCall(name="pdb", id="pdb", arguments={"command": "help help"}) + pdb_help_info = self.env.step(action, None, None) + assert "h(elp)" in pdb_help_info.step_observation.observation, ( + "PDB command did not return expected help message.\n" + f"{pdb_help_info.step_observation.observation}" + ) - self.logger.info(f"Score: {info.score}/{info.max_score or '-'}") - assert info.resolved, ( - "The task is not done after applying the gold patch.\n" - f"{info.step_observation.observation}" - ) - self._report_progress(env.task_name, info, "resolved") - except Exception: - self._report_progress(env.task_name, info, "error") - raise - return info.resolved + # Send a pdb continue command, and check the output matches the one from env.reset. + action = ToolCall(name="pdb", id="pdb", arguments={"command": "continue"}) + pdb_continue_info = self.env.step(action, None, None) + + pdb_observation = pdb_continue_info.step_observation.observation + expected_messages = [ + "Reached the end of the program. Restarting the debugging session.", + "Uncaught exception. Entering post mortem debugging", + ] + reset_observation = info.step_observation.observation + if reset_observation.splitlines(): + expected_messages.append(reset_observation.splitlines()[-1]) + + assert any( + msg in pdb_observation for msg in expected_messages + ), f"PDB command did not return expected continue message.\n{pdb_observation}" + + def step(self, info: EnvInfo) -> EnvInfo: + tool_call = ToolCall(name="submit", id="submit", arguments={}) + return LLMResponse([], tool=tool_call) + + def execute_action(self, llm_response, **kwargs): + self.env.apply_gold_patch() + info = self.env.step(llm_response.tool, None, None) + return info + + def init(self, info: EnvInfo) -> None: + self._run_pdb_sanity_checks(info) diff --git a/debug_gym/agents/utils.py b/debug_gym/agents/utils.py index 9dc54ba8..d95ee7e0 100644 --- a/debug_gym/agents/utils.py +++ b/debug_gym/agents/utils.py @@ -139,7 +139,7 @@ def save_patch(env, problem_path: Path, logger: DebugGymLogger): def save_trajectory(agent, problem_path: Path, logger: DebugGymLogger): """Persist the agent trajectory to disk.""" problem_path.mkdir(parents=True, exist_ok=True) - trajectory = agent._build_trajectory() + trajectory = agent.build_trajectory() json_file = problem_path / "trajectory.json" with open(json_file, "w") as f: json.dump(trajectory, f, indent=4) diff --git a/debug_gym/gym/terminals/kubernetes.py b/debug_gym/gym/terminals/kubernetes.py index 0fd9e333..8c9a51ac 100644 --- a/debug_gym/gym/terminals/kubernetes.py +++ b/debug_gym/gym/terminals/kubernetes.py @@ -706,11 +706,9 @@ def copy_content(self, src: str | Path, target: str | Path | None = None) -> Non cmd = ["kubectl"] if self.kube_config: cmd.extend(["--kubeconfig", self.kube_config]) - # restore previous behavior if os.path.isdir(src): src = f"{src}/." - cmd.extend( [ "cp", diff --git a/debug_gym/llms/azure_openai.py b/debug_gym/llms/azure_openai.py index ad318b5c..ef32e802 100644 --- a/debug_gym/llms/azure_openai.py +++ b/debug_gym/llms/azure_openai.py @@ -22,11 +22,11 @@ class AzureOpenAILLM(OpenAILLM): def __init__( self, model_name, + llm_config, logger=None, - llm_config=None, runtime_generate_kwargs=None, ): - super().__init__(model_name, logger, llm_config, runtime_generate_kwargs) + super().__init__(model_name, llm_config, logger, runtime_generate_kwargs) self._client = None self._client_created_at = 0 diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index 6a7a891d..40254d8e 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -196,13 +196,13 @@ class LLM(ABC): def __init__( self, model_name: str, + llm_config: LLMConfig, logger: DebugGymLogger | None = None, - llm_config: LLMConfig | None = None, runtime_generate_kwargs: dict | None = None, ): self.model_name = model_name self.logger = logger or DebugGymLogger("debug-gym") - self.config = llm_config or LLMConfigRegistry.from_file()[model_name] + self.config = llm_config self.tokenizer_name = self.config.tokenizer self.context_length = self.config.context_limit * 1000 self.apply_chat_template = self.config.apply_chat_template @@ -291,8 +291,8 @@ def instantiate( llm = klass( name, - logger=logger, llm_config=llm_config, + logger=logger, runtime_generate_kwargs=runtime_generate_kwargs, ) return llm diff --git a/debug_gym/llms/copilot.py b/debug_gym/llms/copilot.py index f8fc21fa..078e1fd1 100644 --- a/debug_gym/llms/copilot.py +++ b/debug_gym/llms/copilot.py @@ -29,11 +29,16 @@ class CopilotLLM(OpenAILLM): def __init__( self, model_name, + llm_config, logger=None, - llm_config=None, runtime_generate_kwargs=None, ): - super().__init__(model_name, logger, llm_config, runtime_generate_kwargs) + super().__init__( + model_name, + llm_config=llm_config, + logger=logger, + runtime_generate_kwargs=runtime_generate_kwargs, + ) self._client = None self._token_cache = None self._token_expires_at = 0 diff --git a/scripts/run.py b/scripts/run.py index a6fc3da9..8c255e07 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -87,10 +87,10 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): env = create_env(config, task_data, task_logger) llm = LLM.instantiate(**config.get("llm", {}), logger=task_logger) - agent = create_agent(config.get("agent", {}), logger=task_logger) + agent = create_agent(config.get("agent", {}), llm=llm, logger=task_logger) try: - success = agent.run(env, llm, debug=args.debug) + success = agent.run(env, debug=args.debug) except KeyboardInterrupt: task_logger.error("Agent run was interrupted by user.") task_logger.report_progress( diff --git a/tests/agents/conftest.py b/tests/agents/conftest.py index f3881702..8b4df548 100644 --- a/tests/agents/conftest.py +++ b/tests/agents/conftest.py @@ -52,8 +52,7 @@ def _agent_setup(agent_class, *, config_override=None): llm.context_length = 4096 llm.count_tokens = _length llm.define_tools = lambda x: x - agent = agent_class(config_dict) - agent.llm = llm + agent = agent_class(llm=llm, agent_args=config_dict) agent.env = env yield agent, env, llm diff --git a/tests/agents/test_froggy_agents.py b/tests/agents/test_froggy_agents.py index a937053b..3222a803 100644 --- a/tests/agents/test_froggy_agents.py +++ b/tests/agents/test_froggy_agents.py @@ -149,7 +149,7 @@ def test_run(agent_setup, build_env_info): tool=ToolCall(id="tool_id", name="tool_name", arguments={}), token_usage=TokenUsage(2, 4), ) - result = agent.run(env, llm, debug=False) + result = agent.run(env, debug=False) assert result @@ -243,7 +243,7 @@ def test_run_early_completion(agent_setup, build_env_info): step_observation="Test last run obs", ) - result = agent.run(env, llm) + result = agent.run(env) assert result["success"] is True env.step.assert_not_called() # Should not step if already done @@ -282,7 +282,7 @@ def test_run_stops_at_max_steps(agent_setup, build_env_info): response_token_count=4, ) - result = agent.run(env, llm) + result = agent.run(env) assert result["success"] is False assert env.step.call_count == 1 @@ -305,7 +305,7 @@ def test_run_exception_handling(agent_setup, build_env_info): llm.side_effect = RuntimeError("Test error") with pytest.raises(RuntimeError, match="Test error"): - agent.run(env, llm) + agent.run(env) def test_save_patch(agent_setup, tmp_path): @@ -350,7 +350,7 @@ def json(self, step_id): {"name": tool.name, "args": tool.arguments} for tool in tools ] - trajectory = agent._build_trajectory() + trajectory = agent.build_trajectory() assert trajectory["problem"] == env.task_name assert trajectory["uuid"] == "test-uuid-123" assert len(trajectory["log"]) == 2 diff --git a/tests/conftest.py b/tests/conftest.py index bb9e027c..28f57d1f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,23 @@ def emit(self, record): @pytest.fixture def llm_class_mock(): class LLMMock(LLM): + def __init__( + self, + model_name: str, + llm_config=None, + logger=None, + runtime_generate_kwargs=None, + ): + # If llm_config is not provided, fetch it from the registry + if llm_config is None: + llm_config = LLMConfigRegistry.from_file()[model_name] + super().__init__( + model_name, + llm_config, + logger=logger, + runtime_generate_kwargs=runtime_generate_kwargs, + ) + def generate(self, messages, tools, **kwargs): self.called_messages = messages self.called_tools = tools diff --git a/tests/gym/envs/test_r2egym.py b/tests/gym/envs/test_r2egym.py index 909a3a83..d2ae3e2a 100644 --- a/tests/gym/envs/test_r2egym.py +++ b/tests/gym/envs/test_r2egym.py @@ -245,9 +245,8 @@ def test_running_solution_agent(get_r2egym_env, tmp_path): for tool_name in ["pdb", "eval", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset() - success = agent.run(env) - assert success + result = agent.run(env) + assert result["success"] @pytest.if_docker_running diff --git a/tests/gym/envs/test_swe_bench.py b/tests/gym/envs/test_swe_bench.py index 07632d33..9f133fa1 100644 --- a/tests/gym/envs/test_swe_bench.py +++ b/tests/gym/envs/test_swe_bench.py @@ -247,9 +247,8 @@ def test_running_solution_agent(get_swe_bench_env, tmp_path): for tool_name in ["pdb", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset() - success = agent.run(env) - assert success + result = agent.run(env) + assert result["success"] @pytest.if_docker_running @@ -288,6 +287,5 @@ def test_running_solution_agent_in_debug_mode(get_swe_bench_debug_env, tmp_path) for tool_name in ["pdb", "eval", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset() - success = agent.run(env) - assert success + result = agent.run(env) + assert result["success"] diff --git a/tests/gym/envs/test_swe_smith.py b/tests/gym/envs/test_swe_smith.py index 2c6f566d..5a33a397 100644 --- a/tests/gym/envs/test_swe_smith.py +++ b/tests/gym/envs/test_swe_smith.py @@ -246,14 +246,12 @@ def test_running_solution_agent(get_swe_smith_env, tmp_path): "output_path": str(tmp_path), "random_seed": 0, "max_steps": 1, - "env": env, } for tool_name in ["pdb", "eval", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset() - success = agent.run(env) - assert success + result = agent.run(env) + assert result["success"] @pytest.if_docker_running