diff --git a/configs/simple.yaml b/configs/simple.yaml new file mode 100644 index 00000000..a4896f8c --- /dev/null +++ b/configs/simple.yaml @@ -0,0 +1,68 @@ +# Configuration for standalone FreeEnv + FreeAgent runs. +task_name: free-session +output_path: exps/free_env + +llm: + name: frogboss + +# Tools to load into the environment toolbox. +tools: + - bash + - submit: + eval_on_submit: False # Here we only terminate after submission, no auto-eval. + +task_data: + env_type: FreeEnv + image: ubuntu:22.04 + local_path: /home/macote/src/debug-gym/data/mini_nightmare/pandas_dataframe + workspace_dir: /testbed + +terminal: + type: docker + +agent: + type: simple_agent + max_steps: 20 + system_prompt: |- + You are a helpful assistant that can interact with a computer to solve tasks. + + * If user provides a path, you should NOT assume it's relative to the current working directory. Instead, you should explore the file system to find the file before working on it. + + + You have access to the following functions: + + ---- BEGIN FUNCTION #1: + bash ---- + Description: Execute a bash command in the terminal. + + Parameters: + (1) command (string, required): The bash command to execute. Can be empty to view additional logs when previous exit code is `-1`. Can be `ctrl+c` to interrupt the currently running process. + ---- END FUNCTION #1 ---- + + ---- BEGIN FUNCTION #2: submit ---- + Description: Finish the interaction when the task is complete OR if the assistant cannot proceed further with the task. + No parameters are required for this function. + ---- END FUNCTION #2 ---- + + If you choose to call a function ONLY reply in the following format with NO suffix: + + Provide any reasoning for the function call here. + + value_1 + + This is the value for the second parameter + that can span + multiple lines + + + + + Reminder: + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Always provide reasoning for your function call in natural language BEFORE the function call (not after) + + + instance_prompt: >- + Look at the codebase check that everything is working properly. 
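A minimal sketch of how a config like this is consumed, mirroring the updated `scripts/run.py` further down in this patch; the YAML path and the environment-creation step are illustrative assumptions, not part of the patch:

```python
# Sketch only: mirrors the wiring in scripts/run.py. The YAML path is hypothetical,
# and environment creation (from `task_data` / `terminal`) is left as a comment.
import yaml

from debug_gym.agents.base_agent import create_agent
from debug_gym.llms.base import LLM
from debug_gym.logger import DebugGymLogger

logger = DebugGymLogger("debug-gym")
with open("configs/simple.yaml") as f:  # hypothetical path to this config
    config = yaml.safe_load(f)

# The agent now receives its LLM at construction time instead of in run().
llm = LLM.instantiate(**config.get("llm", {}), logger=logger)          # name: frogboss
agent = create_agent(config.get("agent", {}), llm=llm, logger=logger)  # type: simple_agent

# env = create_env(config, task_data, logger)  # built by scripts/run.py's helper
# trajectory = agent.run(env)                  # returns a JSON-serializable dict
```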
diff --git a/debug_gym/agents/__init__.py b/debug_gym/agents/__init__.py index caf48ec6..8eafff17 100644 --- a/debug_gym/agents/__init__.py +++ b/debug_gym/agents/__init__.py @@ -1,5 +1,6 @@ from debug_gym.agents.base_agent import BaseAgent, register_agent from debug_gym.agents.froggy_agent import FroggyAgent +from debug_gym.agents.simple_agent import SimpleAgent from debug_gym.agents.solution_agent import AgentSolution __all__ = [ @@ -7,4 +8,5 @@ "register_agent", "FroggyAgent", "AgentSolution", + "SimpleAgent", ] diff --git a/debug_gym/agents/base_agent.py b/debug_gym/agents/base_agent.py index 4af8b0b0..577894cf 100644 --- a/debug_gym/agents/base_agent.py +++ b/debug_gym/agents/base_agent.py @@ -2,14 +2,14 @@ import os import uuid from dataclasses import MISSING, asdict, dataclass, field, fields -from typing import Any, Dict +from typing import Any, Dict, List from jinja2 import Environment, Template from debug_gym.agents.history_tracker import HistoryTracker from debug_gym.gym.envs.env import EnvInfo, RepoEnv from debug_gym.gym.utils import filter_non_utf8 -from debug_gym.llms.base import LLM +from debug_gym.llms.base import LLM, LLMResponse from debug_gym.llms.utils import trim from debug_gym.logger import DebugGymLogger @@ -27,8 +27,8 @@ def register_agent(cls): @dataclass class AgentArgs: - system_prompt: str | None = None - instance_prompt: str | None = None + system_prompt: str = "" + instance_prompt: str = "Instructions: {{ info.instructions }}" max_steps: int = 100 max_history_token_cutoff: int = -1 max_history_steps_cutoff: int = -1 @@ -83,8 +83,6 @@ def to_dict(self) -> Dict[str, Any]: class BaseAgent: name: str = None args_class = AgentArgs - system_prompt: str = "" - instance_prompt: str = "Instructions: {{ info.instructions }}" def __init__( self, @@ -95,14 +93,10 @@ def __init__( self.args = self.args_class.make(agent_args or {}) self.history = HistoryTracker() self.logger = logger or DebugGymLogger("debug-gym") - self.llm = None + self.llm = llm self.env = None - - # Override prompts if provided in args - if self.args.system_prompt is not None: - self.system_prompt = str(self.args.system_prompt) - if self.args.instance_prompt is not None: - self.instance_prompt = str(self.args.instance_prompt) + self.system_prompt = str(self.args.system_prompt) + self.instance_prompt = str(self.args.instance_prompt) @staticmethod def to_pretty_json(value): @@ -238,17 +232,89 @@ def should_stop(self, step: int, info: EnvInfo): reason = "max_steps reached" return should_stop, reason - def run(self, env: RepoEnv, llm: LLM, debug=False): - self.env = env - self.llm = llm + def init(self, info: EnvInfo) -> None: + """Initialize the agent with environment + + Args: + info: The environment info to interact with. + """ + self.history.init( + self.build_system_prompt(info), self.build_instance_prompt(info), info + ) + + self.logger.info( + "Available tools (in LLM's tool calling format):\n" + f"{json.dumps(self.llm.define_tools(info.tools), indent=4)}\n" + ) + + def step(self, info: EnvInfo) -> LLMResponse | List[LLMResponse]: + """Execute a single agent step (LLM decision only). + + Args: + info: Current environment info. + + Returns: + LLMResponse with the agent's decision. 
+ """ + messages = self.build_prompt(info) + return self.llm(messages, info.tools) + + def execute_action(self, llm_response: LLMResponse | List[LLMResponse]) -> EnvInfo: + next_info = self.env.step( + llm_response.tool, + llm_response.response, + llm_response.reasoning_response, + ) + self.history.step(next_info, llm_response) + return next_info + + def build_trajectory(self) -> Dict[str, Any]: + """Return the trajectory as a JSON-serializable dict without writing it.""" + tools = [f"{tool.name}({tool.arguments})" for tool in self.env.tools] + json_output = { + "problem": self.env.task_name, + "config": self.args.to_dict(), + "tools": self.llm.define_tools(self.env.tools) if self.llm else tools, + "uuid": self.args.uuid, + "success": self.env.resolved, + "log": [], + "agent_type": self.__class__.__name__, + "logger": str(self.logger.log_file), + } + for step_id in range(len(self.history)): + step_json = self.history.json(step_id) + json_output["log"].append(step_json) + return json_output + + def run( + self, + env: RepoEnv, + debug: bool = False, + reset_env: bool = True, + ) -> Dict[str, Any]: + """Run the agent loop until termination or max steps. + + Args: + env: The environment to interact with. + debug: Whether to drop into debugger after each LLM call. + reset_env: Whether to reset the environment (default True). + + Returns: + The trajectory as a JSON-serializable dict. + """ info = None step = 0 + # assign the env + self.env = env + try: - info = self.env.reset() - self.history.init( - self.build_system_prompt(info), self.build_instance_prompt(info), info - ) + if reset_env: + info = env.reset() + else: + info = env.info + + self.init(info) if info.resolved: self.logger.report_progress( @@ -259,12 +325,7 @@ def run(self, env: RepoEnv, llm: LLM, debug=False): max_score=info.max_score, status="resolved", ) - return self._build_trajectory() - - self.logger.info( - "Available tools (in LLM's tool calling format):\n" - f"{json.dumps(self.llm.define_tools(info.tools), indent=4)}\n" - ) + return self.build_trajectory() highscore = info.score should_stop = False @@ -273,18 +334,12 @@ def run(self, env: RepoEnv, llm: LLM, debug=False): while not should_stop: self.logger.info(f"\n{'='*20} STEP {step} {'='*20}\n") - messages = self.build_prompt(info) - llm_response = self.llm(messages, info.tools) + agent_response = self.step(info) + info = self.execute_action(agent_response) if debug: breakpoint() - info = self.env.step( - llm_response.tool, - llm_response.response, - llm_response.reasoning_response, - ) - self.history.step(info, llm_response) should_stop, reason = self.should_stop(step + 1, info) status = ( "resolved" @@ -299,7 +354,6 @@ def run(self, env: RepoEnv, llm: LLM, debug=False): self.logger.info(msg) step += 1 - # keep progress bar running until max_steps is reached self.logger.report_progress( problem_id=env.task_name, step=step, @@ -308,9 +362,8 @@ def run(self, env: RepoEnv, llm: LLM, debug=False): max_score=info.max_score, status=status, ) - return self._build_trajectory() + return self.build_trajectory() except Exception as e: - # report any error that happens during the run self.logger.report_progress( problem_id=env.task_name, step=step, @@ -321,24 +374,6 @@ def run(self, env: RepoEnv, llm: LLM, debug=False): ) raise e - def _build_trajectory(self) -> Dict[str, Any]: - """Return the trajectory as a JSON-serializable dict without writing it.""" - tools = [f"{tool.name}({tool.arguments})" for tool in self.env.tools] - json_output = { - "problem": self.env.task_name, - 
"config": self.args.to_dict(), - "tools": self.llm.define_tools(self.env.tools) if self.llm else tools, - "uuid": self.args.uuid, - "success": self.env.resolved, - "log": [], - "agent_type": self.__class__.__name__, - "logger": str(self.logger.log_file), - } - for step_id in range(len(self.history)): - step_json = self.history.json(step_id) - json_output["log"].append(step_json) - return json_output - def create_agent(config: Dict[str, Any], **kwargs) -> BaseAgent: """Create an agent from the config dictionary.""" diff --git a/debug_gym/agents/froggy_agent.py b/debug_gym/agents/froggy_agent.py index 3e531524..69d81f5f 100644 --- a/debug_gym/agents/froggy_agent.py +++ b/debug_gym/agents/froggy_agent.py @@ -12,13 +12,13 @@ @dataclass class FroggyAgentArgs(AgentArgs): show_current_breakpoints: bool = False + system_prompt: str = "{{ agent._default_system_prompt(info) }}" @register_agent class FroggyAgent(BaseAgent): name: str = "froggy_agent" args_class = FroggyAgentArgs - system_prompt: str = "{{ agent._default_system_prompt(info) }}" def shortcut_features(self): features = [] diff --git a/debug_gym/agents/simple_agent.py b/debug_gym/agents/simple_agent.py new file mode 100644 index 00000000..1f92ccc4 --- /dev/null +++ b/debug_gym/agents/simple_agent.py @@ -0,0 +1,152 @@ +import re +from dataclasses import dataclass +from typing import List, Tuple + +from debug_gym.agents.base_agent import ( + AgentArgs, + BaseAgent, + LLMResponse, + register_agent, +) +from debug_gym.gym.envs.env import EnvInfo, RepoEnv +from debug_gym.gym.tools.tool import ToolCall +from debug_gym.llms.base import LLM + + +@dataclass +class SimpleAgentArgs(AgentArgs): + system_prompt: str = """You are a helpful assistant that can interact with a computer to solve tasks. + +* If user provides a path, you should NOT assume it's relative to the current working directory. Instead, you should explore the file system to find the file before working on it. + + +You have access to the following functions: + +---- BEGIN FUNCTION #1: bash ---- +Description: Execute a bash command in the terminal. + +Parameters: +(1) command (string, required): The bash command to execute. Can be empty to view additional logs when previous exit code is `-1`. Can be `ctrl+c` to interrupt the currently running process. +---- END FUNCTION #1 ---- + +---- BEGIN FUNCTION #2: submit ---- +Description: Finish the interaction when the task is complete OR if the assistant cannot proceed further with the task. +No parameters are required for this function. +---- END FUNCTION #2 ---- + +If you choose to call a function ONLY reply in the following format with NO suffix: + +Provide any reasoning for the function call here. + +value_1 + +This is the value for the second parameter +that can span +multiple lines + + + + +Reminder: +- Function calls MUST follow the specified format, start with +- Required parameters MUST be specified +- Only call one function at a time +- Always provide reasoning for your function call in natural language BEFORE the function call (not after) + +""" + instance_prompt: str = """ +I have uploaded a python code repository in the /testbed directory. + +Now consider the following instructions: + +\n\n{info.instructions}\n\n + +Can you help me solve the issue? +""" + + +@register_agent +class SimpleAgent(BaseAgent): + name: str = "simple_agent" + + def parse_tool_call(self, tool_call: str) -> List[ToolCall]: + """ + Parses a string of the form: + + + VALUE + ... + + + and returns a ToolCall object. 
+
+        For example:
+        <function=file_editor>
+        <parameter=command>view</parameter>
+        <parameter=path>./sympy/tensor/array/dense_ndim_array.py</parameter>
+        <parameter=concise>True</parameter>
+        </function>
+        """
+        tool_calls = []
+        func_pattern = r"<function=([^>]+)>(.*?)</function>"
+        for func_match in re.finditer(func_pattern, tool_call, re.DOTALL):
+            function_name = func_match.group(1)
+            function_content = func_match.group(2)
+
+            pattern = r"<parameter=([^>]+)>(.*?)</parameter>"
+            param_matches = re.findall(pattern, function_content, flags=re.DOTALL)
+
+            params = {}
+            for param_key, param_value in param_matches:
+                param_key = param_key.strip()
+                param_value = param_value.strip()
+                params[param_key] = param_value
+
+            tool_calls.append(ToolCall(id="None", name=function_name, arguments=params))
+        return tool_calls
+
+    def parse_response(self, response_text: str) -> Tuple[str, List[ToolCall]]:
+        """
+        Extracts:
+        - thought: everything before the first <function=...> block
+        - action: the entire first <function=...></function> block
+        Returns (thought, action).
+        """
+        # Regex to match (non-greedily) from `<function=` to `</function>`
+        pattern = re.compile(r"(?s)(<function=.*?</function>)")
+        match = pattern.search(response_text)
+
+        if match:
+            action = match.group(1)  # The entire block
+            thought = response_text[: match.start()]  # Everything before the block
+        else:
+            # If no match, treat entire text as "thought"
+            thought = response_text
+            action = ""
+
+        # Strip leading/trailing whitespace
+        thought = thought.strip()
+        action = action.strip()
+
+        tool_calls = self.parse_tool_call(action)
+        return thought, tool_calls
+
+    def step(self, info: EnvInfo) -> LLMResponse | List[LLMResponse]:
+        """Execute a single agent step (LLM decision only).
+
+        Args:
+            info: Current environment info.
+
+        Returns:
+            LLMResponse with the agent's decision.
+        """
+        messages = self.build_prompt(info)
+        response = self.llm(messages, tools=None)
+        thought, tool_calls = self.parse_response(response.response)
+        if tool_calls and len(tool_calls) > 1:
+            self.logger.info(
+                f"Multiple tool calls detected ({len(tool_calls)}), using the first one."
+            )
+        response.response = thought
+        response.tool = tool_calls[0] if tool_calls else None
+        return response
diff --git a/debug_gym/agents/solution_agent.py b/debug_gym/agents/solution_agent.py
index 66f1d11e..87f97317 100644
--- a/debug_gym/agents/solution_agent.py
+++ b/debug_gym/agents/solution_agent.py
@@ -1,88 +1,66 @@
+from typing import Any, Dict
+
 from debug_gym.agents.base_agent import BaseAgent, register_agent
+from debug_gym.gym.envs.env import EnvInfo, RepoEnv
 from debug_gym.gym.tools.tool import ToolCall
+from debug_gym.llms.base import LLM, LLMResponse
 
 
 @register_agent
 class AgentSolution(BaseAgent):
+    """Agent that applies the gold patch and submits - used for testing environments."""
+
     name: str = "solution_agent"
 
-    def _report_progress(self, task_name, info, status):
-        self.logger.report_progress(
-            problem_id=task_name,
-            step=1,
-            total_steps=1,
-            score=getattr(info, "score", 0),
-            max_score=getattr(info, "max_score", 0),
-            status=status,
-        )
+    def __init__(
+        self,
+        llm: LLM | None = None,
+        **kwargs,
+    ):
+        super().__init__(llm=llm, **kwargs)
 
     def _env_implements_apply_gold_patch(self):
         """Fail early if the environment does not implement apply_gold_patch."""
         return hasattr(self.env, "apply_gold_patch")
 
-    def run(self, env, llm=None, debug=False):
-        self.env = env
-        info = None
-        try:
-            if not self._env_implements_apply_gold_patch():
-                raise NotImplementedError(
-                    f"The environment {type(self.env)} is not compatible with SolutionAgent."
-                    " Check the README.md to see which environments are compatible."
- ) - - info = self.env.reset() - - if info.resolved is True: - self._report_progress(env.task_name, info, "resolved") - return True - - self.logger.info(f"Score: {info.score}/{info.max_score or '-'}") - - if env.has_tool("pdb"): - # Make a simple pdb call to make sure it is working. - action = ToolCall( - name="pdb", id="pdb", arguments={"command": "help help"} - ) - pdb_help_info = self.env.step(action, None, None) - assert "h(elp)" in pdb_help_info.step_observation.observation, ( - "PDB command did not return expected help message.\n" - f"{pdb_help_info.step_observation.observation}" - ) - - # Send a pdb continue command, and check the output matches the one from env.reset. - action = ToolCall( - name="pdb", id="pdb", arguments={"command": "continue"} - ) - pdb_continue_info = self.env.step(action, None, None) - - pdb_observation = pdb_continue_info.step_observation.observation - expected_messages = [ - "Reached the end of the program. Restarting the debugging session.", - "Uncaught exception. Entering post mortem debugging", - ] - reset_observation = info.step_observation.observation - if reset_observation.splitlines(): - expected_messages.append(reset_observation.splitlines()[-1]) - - assert any( - msg in pdb_observation for msg in expected_messages - ), f"PDB command did not return expected continue message.\n{pdb_observation}" - - self.env.apply_gold_patch() - - if debug: - breakpoint() - - action = ToolCall(name="submit", id="submit", arguments={}) - info = self.env.step(action, None, None) + def _run_pdb_sanity_checks(self, info: EnvInfo): + """Run PDB sanity checks if PDB tool is available.""" + if not self.env.has_tool("pdb"): + return + + # Make a simple pdb call to make sure it is working. + action = ToolCall(name="pdb", id="pdb", arguments={"command": "help help"}) + pdb_help_info = self.env.step(action, None, None) + assert "h(elp)" in pdb_help_info.step_observation.observation, ( + "PDB command did not return expected help message.\n" + f"{pdb_help_info.step_observation.observation}" + ) - self.logger.info(f"Score: {info.score}/{info.max_score or '-'}") - assert info.resolved, ( - "The task is not done after applying the gold patch.\n" - f"{info.step_observation.observation}" - ) - self._report_progress(env.task_name, info, "resolved") - except Exception: - self._report_progress(env.task_name, info, "error") - raise - return info.resolved + # Send a pdb continue command, and check the output matches the one from env.reset. + action = ToolCall(name="pdb", id="pdb", arguments={"command": "continue"}) + pdb_continue_info = self.env.step(action, None, None) + + pdb_observation = pdb_continue_info.step_observation.observation + expected_messages = [ + "Reached the end of the program. Restarting the debugging session.", + "Uncaught exception. 
Entering post mortem debugging", + ] + reset_observation = info.step_observation.observation + if reset_observation.splitlines(): + expected_messages.append(reset_observation.splitlines()[-1]) + + assert any( + msg in pdb_observation for msg in expected_messages + ), f"PDB command did not return expected continue message.\n{pdb_observation}" + + def step(self, info: EnvInfo) -> EnvInfo: + tool_call = ToolCall(name="submit", id="submit", arguments={}) + return LLMResponse([], tool=tool_call) + + def execute_action(self, llm_response, **kwargs): + self.env.apply_gold_patch() + info = self.env.step(llm_response.tool, None, None) + return info + + def init(self, info: EnvInfo) -> None: + self._run_pdb_sanity_checks(info) diff --git a/debug_gym/agents/utils.py b/debug_gym/agents/utils.py index 9dc54ba8..d95ee7e0 100644 --- a/debug_gym/agents/utils.py +++ b/debug_gym/agents/utils.py @@ -139,7 +139,7 @@ def save_patch(env, problem_path: Path, logger: DebugGymLogger): def save_trajectory(agent, problem_path: Path, logger: DebugGymLogger): """Persist the agent trajectory to disk.""" problem_path.mkdir(parents=True, exist_ok=True) - trajectory = agent._build_trajectory() + trajectory = agent.build_trajectory() json_file = problem_path / "trajectory.json" with open(json_file, "w") as f: json.dump(trajectory, f, indent=4) diff --git a/debug_gym/gym/terminals/kubernetes.py b/debug_gym/gym/terminals/kubernetes.py index 0fd9e333..8c9a51ac 100644 --- a/debug_gym/gym/terminals/kubernetes.py +++ b/debug_gym/gym/terminals/kubernetes.py @@ -706,11 +706,9 @@ def copy_content(self, src: str | Path, target: str | Path | None = None) -> Non cmd = ["kubectl"] if self.kube_config: cmd.extend(["--kubeconfig", self.kube_config]) - # restore previous behavior if os.path.isdir(src): src = f"{src}/." 
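The trailing `/.` matters when the source is a directory: `kubectl cp` (like `docker cp`) then copies the directory's contents into the target instead of nesting the directory itself. A small sketch of the command this builds, with placeholder pod and path names:

```python
# Illustration only: pod name and paths are placeholders, not values from this patch.
import os

src = "/tmp/workspace"        # hypothetical local directory
target = "demo-pod:/testbed"  # hypothetical <pod>:<path> destination

cmd = ["kubectl"]
if os.path.isdir(src):
    src = f"{src}/."          # copy the directory's contents, not the directory itself
cmd.extend(["cp", src, target])
# for a directory source: ["kubectl", "cp", "/tmp/workspace/.", "demo-pod:/testbed"]
```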
- cmd.extend( [ "cp", diff --git a/debug_gym/llms/azure_openai.py b/debug_gym/llms/azure_openai.py index ad318b5c..ef32e802 100644 --- a/debug_gym/llms/azure_openai.py +++ b/debug_gym/llms/azure_openai.py @@ -22,11 +22,11 @@ class AzureOpenAILLM(OpenAILLM): def __init__( self, model_name, + llm_config, logger=None, - llm_config=None, runtime_generate_kwargs=None, ): - super().__init__(model_name, logger, llm_config, runtime_generate_kwargs) + super().__init__(model_name, llm_config, logger, runtime_generate_kwargs) self._client = None self._client_created_at = 0 diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index 6a7a891d..40254d8e 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -196,13 +196,13 @@ class LLM(ABC): def __init__( self, model_name: str, + llm_config: LLMConfig, logger: DebugGymLogger | None = None, - llm_config: LLMConfig | None = None, runtime_generate_kwargs: dict | None = None, ): self.model_name = model_name self.logger = logger or DebugGymLogger("debug-gym") - self.config = llm_config or LLMConfigRegistry.from_file()[model_name] + self.config = llm_config self.tokenizer_name = self.config.tokenizer self.context_length = self.config.context_limit * 1000 self.apply_chat_template = self.config.apply_chat_template @@ -291,8 +291,8 @@ def instantiate( llm = klass( name, - logger=logger, llm_config=llm_config, + logger=logger, runtime_generate_kwargs=runtime_generate_kwargs, ) return llm diff --git a/debug_gym/llms/copilot.py b/debug_gym/llms/copilot.py index f8fc21fa..078e1fd1 100644 --- a/debug_gym/llms/copilot.py +++ b/debug_gym/llms/copilot.py @@ -29,11 +29,16 @@ class CopilotLLM(OpenAILLM): def __init__( self, model_name, + llm_config, logger=None, - llm_config=None, runtime_generate_kwargs=None, ): - super().__init__(model_name, logger, llm_config, runtime_generate_kwargs) + super().__init__( + model_name, + llm_config=llm_config, + logger=logger, + runtime_generate_kwargs=runtime_generate_kwargs, + ) self._client = None self._token_cache = None self._token_expires_at = 0 diff --git a/debug_gym/llms/openai.py b/debug_gym/llms/openai.py index 77102bbd..c8c1ed2b 100644 --- a/debug_gym/llms/openai.py +++ b/debug_gym/llms/openai.py @@ -268,13 +268,20 @@ def generate(self, messages, tools, **kwargs) -> LLMResponse: self.need_to_be_retried, ) try: - response = api_call( - model=self.config.model, - messages=messages, - tools=self.define_tools(tools), - tool_choice="auto", - **kwargs, - ) + if tools: + response = api_call( + model=self.config.model, + messages=messages, + tools=self.define_tools(tools), + tool_choice="auto", + **kwargs, + ) + else: + response = api_call( + model=self.config.model, + messages=messages, + **kwargs, + ) except openai.BadRequestError as e: # Handle specific error for context length exceeded, otherwise just propagate the error if self.is_context_length_error(e): diff --git a/scripts/run.py b/scripts/run.py index a6fc3da9..8c255e07 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -87,10 +87,10 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): env = create_env(config, task_data, task_logger) llm = LLM.instantiate(**config.get("llm", {}), logger=task_logger) - agent = create_agent(config.get("agent", {}), logger=task_logger) + agent = create_agent(config.get("agent", {}), llm=llm, logger=task_logger) try: - success = agent.run(env, llm, debug=args.debug) + success = agent.run(env, debug=args.debug) except KeyboardInterrupt: task_logger.error("Agent run was interrupted by user.") 
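Because `llm_config` is now a required positional argument (ahead of `logger`), callers can no longer rely on the implicit registry lookup that `LLM.__init__` used to perform. A sketch of the new contract, assuming `LLMConfigRegistry` and `OpenAILLM` are importable as the surrounding files suggest; the model name is illustrative:

```python
from debug_gym.llms.base import LLMConfigRegistry  # assumed import location
from debug_gym.llms.openai import OpenAILLM        # assumed import location

# The explicit config lookup now happens at the call site (or via LLM.instantiate).
llm_config = LLMConfigRegistry.from_file()["my-model"]  # "my-model" is illustrative
llm = OpenAILLM("my-model", llm_config)                 # llm_config before logger

# With SimpleAgent calling self.llm(messages, tools=None), OpenAILLM.generate()
# now omits the tools/tool_choice fields from the request when no tools are given.
```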
task_logger.report_progress( diff --git a/tests/agents/conftest.py b/tests/agents/conftest.py index f3881702..8b4df548 100644 --- a/tests/agents/conftest.py +++ b/tests/agents/conftest.py @@ -52,8 +52,7 @@ def _agent_setup(agent_class, *, config_override=None): llm.context_length = 4096 llm.count_tokens = _length llm.define_tools = lambda x: x - agent = agent_class(config_dict) - agent.llm = llm + agent = agent_class(llm=llm, agent_args=config_dict) agent.env = env yield agent, env, llm diff --git a/tests/agents/test_froggy_agents.py b/tests/agents/test_froggy_agents.py index a937053b..3222a803 100644 --- a/tests/agents/test_froggy_agents.py +++ b/tests/agents/test_froggy_agents.py @@ -149,7 +149,7 @@ def test_run(agent_setup, build_env_info): tool=ToolCall(id="tool_id", name="tool_name", arguments={}), token_usage=TokenUsage(2, 4), ) - result = agent.run(env, llm, debug=False) + result = agent.run(env, debug=False) assert result @@ -243,7 +243,7 @@ def test_run_early_completion(agent_setup, build_env_info): step_observation="Test last run obs", ) - result = agent.run(env, llm) + result = agent.run(env) assert result["success"] is True env.step.assert_not_called() # Should not step if already done @@ -282,7 +282,7 @@ def test_run_stops_at_max_steps(agent_setup, build_env_info): response_token_count=4, ) - result = agent.run(env, llm) + result = agent.run(env) assert result["success"] is False assert env.step.call_count == 1 @@ -305,7 +305,7 @@ def test_run_exception_handling(agent_setup, build_env_info): llm.side_effect = RuntimeError("Test error") with pytest.raises(RuntimeError, match="Test error"): - agent.run(env, llm) + agent.run(env) def test_save_patch(agent_setup, tmp_path): @@ -350,7 +350,7 @@ def json(self, step_id): {"name": tool.name, "args": tool.arguments} for tool in tools ] - trajectory = agent._build_trajectory() + trajectory = agent.build_trajectory() assert trajectory["problem"] == env.task_name assert trajectory["uuid"] == "test-uuid-123" assert len(trajectory["log"]) == 2 diff --git a/tests/agents/test_simple_agent.py b/tests/agents/test_simple_agent.py new file mode 100644 index 00000000..4793f0ce --- /dev/null +++ b/tests/agents/test_simple_agent.py @@ -0,0 +1,49 @@ +from unittest.mock import Mock + +import pytest + +from debug_gym.agents.base_agent import AgentArgs +from debug_gym.agents.simple_agent import SimpleAgent + + +@pytest.fixture +def agent(): + agent = SimpleAgent(agent_args=AgentArgs(max_steps=10)) + agent.logger = Mock() + return agent + + +def test_parse_with_parameters(agent): + """Covers main parsing logic and multiline parameters""" + response = """ + +1 + +def hello(): + pass + + +""" + tool_calls = agent.parse_tool_call(response) + assert len(tool_calls) == 1 + assert tool_calls[0].name == "test" + assert tool_calls[0].arguments["x"] == "1" + assert "def hello():" in tool_calls[0].arguments["code"] + + +def test_parse_multiple_and_empty(agent): + """Covers multiple functions and parameter scoping""" + response = ( + "1" + ) + tool_calls = agent.parse_tool_call(response) + assert len(tool_calls) == 2 + assert tool_calls[0].arguments == {"x": "1"} + assert tool_calls[1].arguments == {} + + +def test_parse_fallback_and_exception(agent): + """Covers no-match fallback and exception handling""" + # No match fallback + tool_calls = agent.parse_tool_call("text") + assert not tool_calls diff --git a/tests/conftest.py b/tests/conftest.py index bb9e027c..28f57d1f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,23 @@ def emit(self, 
record): @pytest.fixture def llm_class_mock(): class LLMMock(LLM): + def __init__( + self, + model_name: str, + llm_config=None, + logger=None, + runtime_generate_kwargs=None, + ): + # If llm_config is not provided, fetch it from the registry + if llm_config is None: + llm_config = LLMConfigRegistry.from_file()[model_name] + super().__init__( + model_name, + llm_config, + logger=logger, + runtime_generate_kwargs=runtime_generate_kwargs, + ) + def generate(self, messages, tools, **kwargs): self.called_messages = messages self.called_tools = tools diff --git a/tests/gym/envs/test_r2egym.py b/tests/gym/envs/test_r2egym.py index 909a3a83..d2ae3e2a 100644 --- a/tests/gym/envs/test_r2egym.py +++ b/tests/gym/envs/test_r2egym.py @@ -245,9 +245,8 @@ def test_running_solution_agent(get_r2egym_env, tmp_path): for tool_name in ["pdb", "eval", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset() - success = agent.run(env) - assert success + result = agent.run(env) + assert result["success"] @pytest.if_docker_running diff --git a/tests/gym/envs/test_swe_bench.py b/tests/gym/envs/test_swe_bench.py index 07632d33..9f133fa1 100644 --- a/tests/gym/envs/test_swe_bench.py +++ b/tests/gym/envs/test_swe_bench.py @@ -247,9 +247,8 @@ def test_running_solution_agent(get_swe_bench_env, tmp_path): for tool_name in ["pdb", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset() - success = agent.run(env) - assert success + result = agent.run(env) + assert result["success"] @pytest.if_docker_running @@ -288,6 +287,5 @@ def test_running_solution_agent_in_debug_mode(get_swe_bench_debug_env, tmp_path) for tool_name in ["pdb", "eval", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset() - success = agent.run(env) - assert success + result = agent.run(env) + assert result["success"] diff --git a/tests/gym/envs/test_swe_smith.py b/tests/gym/envs/test_swe_smith.py index 2c6f566d..5a33a397 100644 --- a/tests/gym/envs/test_swe_smith.py +++ b/tests/gym/envs/test_swe_smith.py @@ -246,14 +246,12 @@ def test_running_solution_agent(get_swe_smith_env, tmp_path): "output_path": str(tmp_path), "random_seed": 0, "max_steps": 1, - "env": env, } for tool_name in ["pdb", "eval", "submit"]: env.add_tool(Toolbox.get_tool(tool_name)) agent = AgentSolution(agent_args=config, llm=None, logger=env.logger) - env.reset() - success = agent.run(env) - assert success + result = agent.run(env) + assert result["success"] @pytest.if_docker_running
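For reference, a small usage sketch of the parsing path these tests exercise; the completion text is made up, and it assumes the `<function=...>` / `<parameter=...>` tag format targeted by `parse_tool_call`'s regexes:

```python
from debug_gym.agents.base_agent import AgentArgs
from debug_gym.agents.simple_agent import SimpleAgent

agent = SimpleAgent(agent_args=AgentArgs(max_steps=10))

# Hypothetical raw completion in the format described by the system prompt.
completion = (
    "I will inspect the repository first.\n"
    "<function=bash>\n"
    "<parameter=command>ls /testbed</parameter>\n"
    "</function>"
)

thought, tool_calls = agent.parse_response(completion)
assert thought == "I will inspect the repository first."
assert tool_calls[0].name == "bash"
assert tool_calls[0].arguments == {"command": "ls /testbed"}
```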