diff --git a/configs/simple.yaml b/configs/simple.yaml
new file mode 100644
index 00000000..a4896f8c
--- /dev/null
+++ b/configs/simple.yaml
@@ -0,0 +1,68 @@
+# Configuration for standalone FreeEnv + FreeAgent runs.
+task_name: free-session
+output_path: exps/free_env
+
+llm:
+ name: frogboss
+
+# Tools to load into the environment toolbox.
+tools:
+ - bash
+ - submit:
 eval_on_submit: False # Only terminate after submission; no auto-eval.
+
+task_data:
+ env_type: FreeEnv
+ image: ubuntu:22.04
+ local_path: /home/macote/src/debug-gym/data/mini_nightmare/pandas_dataframe
+ workspace_dir: /testbed
+
+terminal:
+ type: docker
+
+agent:
+ type: simple_agent
+ max_steps: 20
+ system_prompt: |-
+ You are a helpful assistant that can interact with a computer to solve tasks.
+
+ * If user provides a path, you should NOT assume it's relative to the current working directory. Instead, you should explore the file system to find the file before working on it.
+
+
+ You have access to the following functions:
+
+ ---- BEGIN FUNCTION #1: bash ----
+ Description: Execute a bash command in the terminal.
+
+ Parameters:
+ (1) command (string, required): The bash command to execute. Can be empty to view additional logs when previous exit code is `-1`. Can be `ctrl+c` to interrupt the currently running process.
+ ---- END FUNCTION #1 ----
+
+ ---- BEGIN FUNCTION #2: submit ----
+ Description: Finish the interaction when the task is complete OR if the assistant cannot proceed further with the task.
+ No parameters are required for this function.
+ ---- END FUNCTION #2 ----
+
+ If you choose to call a function ONLY reply in the following format with NO suffix:
+
+ Provide any reasoning for the function call here.
+ <function=example_function_name>
+ <parameter=example_parameter_1>value_1</parameter>
+ <parameter=example_parameter_2>
+ This is the value for the second parameter
+ that can span
+ multiple lines
+ </parameter>
+ </function>
+
+
+ Reminder:
+ - Function calls MUST follow the specified format, start with <function= and end with </function>
+ - Required parameters MUST be specified
+ - Only call one function at a time
+ - Always provide reasoning for your function call in natural language BEFORE the function call (not after)
+
+
+ instance_prompt: >-
+ Look at the codebase and check that everything is working properly.
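
A config like the one above is consumed through the updated scripts/run.py wiring further down in this diff: the LLM is now instantiated from the llm block and handed to the agent at construction time rather than at run() time. A minimal sketch of that wiring follows; create_env's module is not shown in this diff, so its import and exact signature are assumed here.

    # Minimal sketch; mirrors the updated scripts/run.py wiring.
    import yaml

    from debug_gym.agents.base_agent import create_agent
    from debug_gym.llms.base import LLM
    from debug_gym.logger import DebugGymLogger

    with open("configs/simple.yaml") as f:
        config = yaml.safe_load(f)

    logger = DebugGymLogger("debug-gym")
    llm = LLM.instantiate(**config.get("llm", {}), logger=logger)
    agent = create_agent(config.get("agent", {}), llm=llm, logger=logger)

    # env = create_env(config, config["task_data"], logger)  # assumed helper, as called in scripts/run.py
    # trajectory = agent.run(env, debug=False)
    # print(trajectory["success"])
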
diff --git a/debug_gym/agents/__init__.py b/debug_gym/agents/__init__.py
index caf48ec6..8eafff17 100644
--- a/debug_gym/agents/__init__.py
+++ b/debug_gym/agents/__init__.py
@@ -1,5 +1,6 @@
from debug_gym.agents.base_agent import BaseAgent, register_agent
from debug_gym.agents.froggy_agent import FroggyAgent
+from debug_gym.agents.simple_agent import SimpleAgent
from debug_gym.agents.solution_agent import AgentSolution
__all__ = [
@@ -7,4 +8,5 @@
"register_agent",
"FroggyAgent",
"AgentSolution",
+ "SimpleAgent",
]
diff --git a/debug_gym/agents/base_agent.py b/debug_gym/agents/base_agent.py
index 4af8b0b0..577894cf 100644
--- a/debug_gym/agents/base_agent.py
+++ b/debug_gym/agents/base_agent.py
@@ -2,14 +2,14 @@
import os
import uuid
from dataclasses import MISSING, asdict, dataclass, field, fields
-from typing import Any, Dict
+from typing import Any, Dict, List
from jinja2 import Environment, Template
from debug_gym.agents.history_tracker import HistoryTracker
from debug_gym.gym.envs.env import EnvInfo, RepoEnv
from debug_gym.gym.utils import filter_non_utf8
-from debug_gym.llms.base import LLM
+from debug_gym.llms.base import LLM, LLMResponse
from debug_gym.llms.utils import trim
from debug_gym.logger import DebugGymLogger
@@ -27,8 +27,8 @@ def register_agent(cls):
@dataclass
class AgentArgs:
- system_prompt: str | None = None
- instance_prompt: str | None = None
+ system_prompt: str = ""
+ instance_prompt: str = "Instructions: {{ info.instructions }}"
max_steps: int = 100
max_history_token_cutoff: int = -1
max_history_steps_cutoff: int = -1
@@ -83,8 +83,6 @@ def to_dict(self) -> Dict[str, Any]:
class BaseAgent:
name: str = None
args_class = AgentArgs
- system_prompt: str = ""
- instance_prompt: str = "Instructions: {{ info.instructions }}"
def __init__(
self,
@@ -95,14 +93,10 @@ def __init__(
self.args = self.args_class.make(agent_args or {})
self.history = HistoryTracker()
self.logger = logger or DebugGymLogger("debug-gym")
- self.llm = None
+ self.llm = llm
self.env = None
-
- # Override prompts if provided in args
- if self.args.system_prompt is not None:
- self.system_prompt = str(self.args.system_prompt)
- if self.args.instance_prompt is not None:
- self.instance_prompt = str(self.args.instance_prompt)
+ self.system_prompt = str(self.args.system_prompt)
+ self.instance_prompt = str(self.args.instance_prompt)
@staticmethod
def to_pretty_json(value):
@@ -238,17 +232,89 @@ def should_stop(self, step: int, info: EnvInfo):
reason = "max_steps reached"
return should_stop, reason
- def run(self, env: RepoEnv, llm: LLM, debug=False):
- self.env = env
- self.llm = llm
+ def init(self, info: EnvInfo) -> None:
+ """Initialize the agent with environment
+
+ Args:
+ info: The environment info to interact with.
+ """
+ self.history.init(
+ self.build_system_prompt(info), self.build_instance_prompt(info), info
+ )
+
+ self.logger.info(
+ "Available tools (in LLM's tool calling format):\n"
+ f"{json.dumps(self.llm.define_tools(info.tools), indent=4)}\n"
+ )
+
+ def step(self, info: EnvInfo) -> LLMResponse | List[LLMResponse]:
+ """Execute a single agent step (LLM decision only).
+
+ Args:
+ info: Current environment info.
+
+ Returns:
+ LLMResponse with the agent's decision.
+ """
+ messages = self.build_prompt(info)
+ return self.llm(messages, info.tools)
+
+ def execute_action(self, llm_response: LLMResponse | List[LLMResponse]) -> EnvInfo:
+ next_info = self.env.step(
+ llm_response.tool,
+ llm_response.response,
+ llm_response.reasoning_response,
+ )
+ self.history.step(next_info, llm_response)
+ return next_info
+
+ def build_trajectory(self) -> Dict[str, Any]:
+ """Return the trajectory as a JSON-serializable dict without writing it."""
+ tools = [f"{tool.name}({tool.arguments})" for tool in self.env.tools]
+ json_output = {
+ "problem": self.env.task_name,
+ "config": self.args.to_dict(),
+ "tools": self.llm.define_tools(self.env.tools) if self.llm else tools,
+ "uuid": self.args.uuid,
+ "success": self.env.resolved,
+ "log": [],
+ "agent_type": self.__class__.__name__,
+ "logger": str(self.logger.log_file),
+ }
+ for step_id in range(len(self.history)):
+ step_json = self.history.json(step_id)
+ json_output["log"].append(step_json)
+ return json_output
+
+ def run(
+ self,
+ env: RepoEnv,
+ debug: bool = False,
+ reset_env: bool = True,
+ ) -> Dict[str, Any]:
+ """Run the agent loop until termination or max steps.
+
+ Args:
+ env: The environment to interact with.
+ debug: Whether to drop into debugger after each LLM call.
+ reset_env: Whether to reset the environment (default True).
+
+ Returns:
+ The trajectory as a JSON-serializable dict.
+ """
info = None
step = 0
+ # assign the env
+ self.env = env
+
try:
- info = self.env.reset()
- self.history.init(
- self.build_system_prompt(info), self.build_instance_prompt(info), info
- )
+ if reset_env:
+ info = env.reset()
+ else:
+ info = env.info
+
+ self.init(info)
if info.resolved:
self.logger.report_progress(
@@ -259,12 +325,7 @@ def run(self, env: RepoEnv, llm: LLM, debug=False):
max_score=info.max_score,
status="resolved",
)
- return self._build_trajectory()
-
- self.logger.info(
- "Available tools (in LLM's tool calling format):\n"
- f"{json.dumps(self.llm.define_tools(info.tools), indent=4)}\n"
- )
+ return self.build_trajectory()
highscore = info.score
should_stop = False
@@ -273,18 +334,12 @@ def run(self, env: RepoEnv, llm: LLM, debug=False):
while not should_stop:
self.logger.info(f"\n{'='*20} STEP {step} {'='*20}\n")
- messages = self.build_prompt(info)
- llm_response = self.llm(messages, info.tools)
+ agent_response = self.step(info)
+ info = self.execute_action(agent_response)
if debug:
breakpoint()
- info = self.env.step(
- llm_response.tool,
- llm_response.response,
- llm_response.reasoning_response,
- )
- self.history.step(info, llm_response)
should_stop, reason = self.should_stop(step + 1, info)
status = (
"resolved"
@@ -299,7 +354,6 @@ def run(self, env: RepoEnv, llm: LLM, debug=False):
self.logger.info(msg)
step += 1
- # keep progress bar running until max_steps is reached
self.logger.report_progress(
problem_id=env.task_name,
step=step,
@@ -308,9 +362,8 @@ def run(self, env: RepoEnv, llm: LLM, debug=False):
max_score=info.max_score,
status=status,
)
- return self._build_trajectory()
+ return self.build_trajectory()
except Exception as e:
- # report any error that happens during the run
self.logger.report_progress(
problem_id=env.task_name,
step=step,
@@ -321,24 +374,6 @@ def run(self, env: RepoEnv, llm: LLM, debug=False):
)
raise e
- def _build_trajectory(self) -> Dict[str, Any]:
- """Return the trajectory as a JSON-serializable dict without writing it."""
- tools = [f"{tool.name}({tool.arguments})" for tool in self.env.tools]
- json_output = {
- "problem": self.env.task_name,
- "config": self.args.to_dict(),
- "tools": self.llm.define_tools(self.env.tools) if self.llm else tools,
- "uuid": self.args.uuid,
- "success": self.env.resolved,
- "log": [],
- "agent_type": self.__class__.__name__,
- "logger": str(self.logger.log_file),
- }
- for step_id in range(len(self.history)):
- step_json = self.history.json(step_id)
- json_output["log"].append(step_json)
- return json_output
-
def create_agent(config: Dict[str, Any], **kwargs) -> BaseAgent:
"""Create an agent from the config dictionary."""
diff --git a/debug_gym/agents/froggy_agent.py b/debug_gym/agents/froggy_agent.py
index 3e531524..69d81f5f 100644
--- a/debug_gym/agents/froggy_agent.py
+++ b/debug_gym/agents/froggy_agent.py
@@ -12,13 +12,13 @@
@dataclass
class FroggyAgentArgs(AgentArgs):
show_current_breakpoints: bool = False
+ system_prompt: str = "{{ agent._default_system_prompt(info) }}"
@register_agent
class FroggyAgent(BaseAgent):
name: str = "froggy_agent"
args_class = FroggyAgentArgs
- system_prompt: str = "{{ agent._default_system_prompt(info) }}"
def shortcut_features(self):
features = []
diff --git a/debug_gym/agents/simple_agent.py b/debug_gym/agents/simple_agent.py
new file mode 100644
index 00000000..1f92ccc4
--- /dev/null
+++ b/debug_gym/agents/simple_agent.py
@@ -0,0 +1,152 @@
+import re
+from dataclasses import dataclass
+from typing import List, Tuple
+
+from debug_gym.agents.base_agent import (
+ AgentArgs,
+ BaseAgent,
+ LLMResponse,
+ register_agent,
+)
+from debug_gym.gym.envs.env import EnvInfo, RepoEnv
+from debug_gym.gym.tools.tool import ToolCall
+from debug_gym.llms.base import LLM
+
+
+@dataclass
+class SimpleAgentArgs(AgentArgs):
+ system_prompt: str = """You are a helpful assistant that can interact with a computer to solve tasks.
+
+* If user provides a path, you should NOT assume it's relative to the current working directory. Instead, you should explore the file system to find the file before working on it.
+
+
+You have access to the following functions:
+
+---- BEGIN FUNCTION #1: bash ----
+Description: Execute a bash command in the terminal.
+
+Parameters:
+(1) command (string, required): The bash command to execute. Can be empty to view additional logs when previous exit code is `-1`. Can be `ctrl+c` to interrupt the currently running process.
+---- END FUNCTION #1 ----
+
+---- BEGIN FUNCTION #2: submit ----
+Description: Finish the interaction when the task is complete OR if the assistant cannot proceed further with the task.
+No parameters are required for this function.
+---- END FUNCTION #2 ----
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+Provide any reasoning for the function call here.
+<function=example_function_name>
+<parameter=example_parameter_1>value_1</parameter>
+<parameter=example_parameter_2>
+This is the value for the second parameter
+that can span
+multiple lines
+</parameter>
+</function>
+
+
+Reminder:
+- Function calls MUST follow the specified format, start with <function= and end with </function>
+- Required parameters MUST be specified
+- Only call one function at a time
+- Always provide reasoning for your function call in natural language BEFORE the function call (not after)
+
+"""
+ instance_prompt: str = """
+I have uploaded a python code repository in the /testbed directory.
+
+Now consider the following instructions:
+
+\n\n{{ info.instructions }}\n\n
+
+Can you help me solve the issue?
+"""
+
+
+@register_agent
+class SimpleAgent(BaseAgent):
+ name: str = "simple_agent"
+
+ def parse_tool_call(self, tool_call: str) -> List[ToolCall]:
+ """
+ Parses a string of the form:
+
+ <function=FUNCTION_NAME>
+ <parameter=PARAM_NAME>VALUE</parameter>
+ ...
+ </function>
+
+ and returns a list of ToolCall objects.
+
+ For example:
+ <function=file_editor>
+ <parameter=command>view</parameter>
+ <parameter=path>./sympy/tensor/array/dense_ndim_array.py</parameter>
+ <parameter=concise>True</parameter>
+ </function>
+ """
+ tool_calls = []
+ func_pattern = r"]+)>(.*?)"
+ for func_match in re.finditer(func_pattern, tool_call, re.DOTALL):
+ function_name = func_match.group(1)
+ function_content = func_match.group(2)
+
+ pattern = r"]+)>(.*?)"
+ param_matches = re.findall(pattern, function_content, flags=re.DOTALL)
+
+ params = {}
+ for param_key, param_value in param_matches:
+ param_key = param_key.strip()
+ param_value = param_value.strip()
+ params[param_key] = param_value
+
+ tool_calls.append(ToolCall(id="None", name=function_name, arguments=params))
+ return tool_calls
+
+ def parse_response(self, response_text: str) -> Tuple[str, List[ToolCall]]:
+ """
+ Extracts:
+ - thought: everything before the first <function=...> block
+ - action: the entire first <function=...></function> block
+ Returns (thought, tool_calls).
+ """
+ # Regex to match (non-greedily) from `<function=` to `</function>`
+ pattern = re.compile(r"(?s)(<function=.*?</function>)")
+ match = pattern.search(response_text)
+
+ if match:
+ action = match.group(1) # The entire block
+ thought = response_text[: match.start()] # Everything before the block
+ else:
+ # If no match, treat entire text as "thought"
+ thought = response_text
+ action = ""
+
+ # Strip leading/trailing whitespace
+ thought = thought.strip()
+ action = action.strip()
+
+ tool_calls = self.parse_tool_call(action)
+ return thought, tool_calls
+
+ def step(self, info: EnvInfo) -> LLMResponse | List[LLMResponse]:
+ """Execute a single agent step (LLM decision only).
+
+ Args:
+ info: Current environment info.
+
+ Returns:
+ LLMResponse with the agent's decision.
+ """
+ messages = self.build_prompt(info)
+ response = self.llm(messages, tools=None)
+ thought, tool_calls = self.parse_response(response.response)
+ if tool_calls and len(tool_calls) > 1:
+ self.logger.info(
+ f"Multiple tool calls detected ({len(tool_calls)}), using the first one."
+ )
+ response.response = thought
+ response.tool = tool_calls[0] if tool_calls else None
+ return response
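
SimpleAgent bypasses native tool calling: it requests a plain completion and recovers the tool call from the `<function=...>` markup described in its system prompt. A small usage sketch of the parsing path (the constructor call mirrors tests/agents/test_simple_agent.py; the sample reply is illustrative):

    from debug_gym.agents.base_agent import AgentArgs
    from debug_gym.agents.simple_agent import SimpleAgent

    agent = SimpleAgent(agent_args=AgentArgs(max_steps=10))

    reply = """Let me check the repository layout first.
    <function=bash>
    <parameter=command>ls /testbed</parameter>
    </function>"""

    thought, tool_calls = agent.parse_response(reply)
    # thought     -> "Let me check the repository layout first."
    # tool_calls  -> [ToolCall(id="None", name="bash", arguments={"command": "ls /testbed"})]
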
diff --git a/debug_gym/agents/solution_agent.py b/debug_gym/agents/solution_agent.py
index 66f1d11e..87f97317 100644
--- a/debug_gym/agents/solution_agent.py
+++ b/debug_gym/agents/solution_agent.py
@@ -1,88 +1,66 @@
+from typing import Any, Dict
+
from debug_gym.agents.base_agent import BaseAgent, register_agent
+from debug_gym.gym.envs.env import EnvInfo, RepoEnv
from debug_gym.gym.tools.tool import ToolCall
+from debug_gym.llms.base import LLM, LLMResponse
@register_agent
class AgentSolution(BaseAgent):
+ """Agent that applies the gold patch and submits - used for testing environments."""
+
name: str = "solution_agent"
- def _report_progress(self, task_name, info, status):
- self.logger.report_progress(
- problem_id=task_name,
- step=1,
- total_steps=1,
- score=getattr(info, "score", 0),
- max_score=getattr(info, "max_score", 0),
- status=status,
- )
+ def __init__(
+ self,
+ llm: LLM | None = None,
+ **kwargs,
+ ):
+ super().__init__(llm=llm, **kwargs)
def _env_implements_apply_gold_patch(self):
"""Fail early if the environment does not implement apply_gold_patch."""
return hasattr(self.env, "apply_gold_patch")
- def run(self, env, llm=None, debug=False):
- self.env = env
- info = None
- try:
- if not self._env_implements_apply_gold_patch():
- raise NotImplementedError(
- f"The environment {type(self.env)} is not compatible with SolutionAgent."
- " Check the README.md to see which environments are compatible."
- )
-
- info = self.env.reset()
-
- if info.resolved is True:
- self._report_progress(env.task_name, info, "resolved")
- return True
-
- self.logger.info(f"Score: {info.score}/{info.max_score or '-'}")
-
- if env.has_tool("pdb"):
- # Make a simple pdb call to make sure it is working.
- action = ToolCall(
- name="pdb", id="pdb", arguments={"command": "help help"}
- )
- pdb_help_info = self.env.step(action, None, None)
- assert "h(elp)" in pdb_help_info.step_observation.observation, (
- "PDB command did not return expected help message.\n"
- f"{pdb_help_info.step_observation.observation}"
- )
-
- # Send a pdb continue command, and check the output matches the one from env.reset.
- action = ToolCall(
- name="pdb", id="pdb", arguments={"command": "continue"}
- )
- pdb_continue_info = self.env.step(action, None, None)
-
- pdb_observation = pdb_continue_info.step_observation.observation
- expected_messages = [
- "Reached the end of the program. Restarting the debugging session.",
- "Uncaught exception. Entering post mortem debugging",
- ]
- reset_observation = info.step_observation.observation
- if reset_observation.splitlines():
- expected_messages.append(reset_observation.splitlines()[-1])
-
- assert any(
- msg in pdb_observation for msg in expected_messages
- ), f"PDB command did not return expected continue message.\n{pdb_observation}"
-
- self.env.apply_gold_patch()
-
- if debug:
- breakpoint()
-
- action = ToolCall(name="submit", id="submit", arguments={})
- info = self.env.step(action, None, None)
+ def _run_pdb_sanity_checks(self, info: EnvInfo):
+ """Run PDB sanity checks if PDB tool is available."""
+ if not self.env.has_tool("pdb"):
+ return
+
+ # Make a simple pdb call to make sure it is working.
+ action = ToolCall(name="pdb", id="pdb", arguments={"command": "help help"})
+ pdb_help_info = self.env.step(action, None, None)
+ assert "h(elp)" in pdb_help_info.step_observation.observation, (
+ "PDB command did not return expected help message.\n"
+ f"{pdb_help_info.step_observation.observation}"
+ )
- self.logger.info(f"Score: {info.score}/{info.max_score or '-'}")
- assert info.resolved, (
- "The task is not done after applying the gold patch.\n"
- f"{info.step_observation.observation}"
- )
- self._report_progress(env.task_name, info, "resolved")
- except Exception:
- self._report_progress(env.task_name, info, "error")
- raise
- return info.resolved
+ # Send a pdb continue command, and check the output matches the one from env.reset.
+ action = ToolCall(name="pdb", id="pdb", arguments={"command": "continue"})
+ pdb_continue_info = self.env.step(action, None, None)
+
+ pdb_observation = pdb_continue_info.step_observation.observation
+ expected_messages = [
+ "Reached the end of the program. Restarting the debugging session.",
+ "Uncaught exception. Entering post mortem debugging",
+ ]
+ reset_observation = info.step_observation.observation
+ if reset_observation.splitlines():
+ expected_messages.append(reset_observation.splitlines()[-1])
+
+ assert any(
+ msg in pdb_observation for msg in expected_messages
+ ), f"PDB command did not return expected continue message.\n{pdb_observation}"
+
+ def step(self, info: EnvInfo) -> LLMResponse:
+ tool_call = ToolCall(name="submit", id="submit", arguments={})
+ return LLMResponse([], tool=tool_call)
+
+ def execute_action(self, llm_response, **kwargs):
+ self.env.apply_gold_patch()
+ info = self.env.step(llm_response.tool, None, None)
+ return info
+
+ def init(self, info: EnvInfo) -> None:
+ self._run_pdb_sanity_checks(info)
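
With this rework, AgentSolution reuses BaseAgent.run(): init() performs the PDB sanity checks, step() always returns a submit call, and execute_action() applies the gold patch before stepping the environment. The env tests updated below exercise it like this (config and env as set up in those tests):

    # Usage as in tests/gym/envs/test_swe_bench.py; run() now returns the trajectory dict.
    agent = AgentSolution(agent_args=config, llm=None, logger=env.logger)
    result = agent.run(env)        # reset, PDB checks, gold patch, submit
    assert result["success"]
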
diff --git a/debug_gym/agents/utils.py b/debug_gym/agents/utils.py
index 9dc54ba8..d95ee7e0 100644
--- a/debug_gym/agents/utils.py
+++ b/debug_gym/agents/utils.py
@@ -139,7 +139,7 @@ def save_patch(env, problem_path: Path, logger: DebugGymLogger):
def save_trajectory(agent, problem_path: Path, logger: DebugGymLogger):
"""Persist the agent trajectory to disk."""
problem_path.mkdir(parents=True, exist_ok=True)
- trajectory = agent._build_trajectory()
+ trajectory = agent.build_trajectory()
json_file = problem_path / "trajectory.json"
with open(json_file, "w") as f:
json.dump(trajectory, f, indent=4)
diff --git a/debug_gym/gym/terminals/kubernetes.py b/debug_gym/gym/terminals/kubernetes.py
index 0fd9e333..8c9a51ac 100644
--- a/debug_gym/gym/terminals/kubernetes.py
+++ b/debug_gym/gym/terminals/kubernetes.py
@@ -706,11 +706,9 @@ def copy_content(self, src: str | Path, target: str | Path | None = None) -> Non
cmd = ["kubectl"]
if self.kube_config:
cmd.extend(["--kubeconfig", self.kube_config])
-
# restore previous behavior
if os.path.isdir(src):
src = f"{src}/."
-
cmd.extend(
[
"cp",
diff --git a/debug_gym/llms/azure_openai.py b/debug_gym/llms/azure_openai.py
index ad318b5c..ef32e802 100644
--- a/debug_gym/llms/azure_openai.py
+++ b/debug_gym/llms/azure_openai.py
@@ -22,11 +22,11 @@ class AzureOpenAILLM(OpenAILLM):
def __init__(
self,
model_name,
+ llm_config,
logger=None,
- llm_config=None,
runtime_generate_kwargs=None,
):
- super().__init__(model_name, logger, llm_config, runtime_generate_kwargs)
+ super().__init__(model_name, llm_config, logger, runtime_generate_kwargs)
self._client = None
self._client_created_at = 0
diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py
index 6a7a891d..40254d8e 100644
--- a/debug_gym/llms/base.py
+++ b/debug_gym/llms/base.py
@@ -196,13 +196,13 @@ class LLM(ABC):
def __init__(
self,
model_name: str,
+ llm_config: LLMConfig,
logger: DebugGymLogger | None = None,
- llm_config: LLMConfig | None = None,
runtime_generate_kwargs: dict | None = None,
):
self.model_name = model_name
self.logger = logger or DebugGymLogger("debug-gym")
- self.config = llm_config or LLMConfigRegistry.from_file()[model_name]
+ self.config = llm_config
self.tokenizer_name = self.config.tokenizer
self.context_length = self.config.context_limit * 1000
self.apply_chat_template = self.config.apply_chat_template
@@ -291,8 +291,8 @@ def instantiate(
llm = klass(
name,
- logger=logger,
llm_config=llm_config,
+ logger=logger,
runtime_generate_kwargs=runtime_generate_kwargs,
)
return llm
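
llm_config is now a required positional argument for LLM and its subclasses; the implicit registry lookup that previously happened in the base constructor is gone. Callers either go through LLM.instantiate or resolve the config themselves. A sketch, with import paths assumed from this diff and tests/conftest.py:

    # Explicit construction after this change; the registry lookup mirrors the
    # removed fallback and the LLMMock in tests/conftest.py.
    from debug_gym.llms.base import LLMConfigRegistry
    from debug_gym.llms.openai import OpenAILLM

    llm_config = LLMConfigRegistry.from_file()["frogboss"]  # model name from configs/simple.yaml
    llm = OpenAILLM("frogboss", llm_config)                 # llm_config now comes second, before logger
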
diff --git a/debug_gym/llms/copilot.py b/debug_gym/llms/copilot.py
index f8fc21fa..078e1fd1 100644
--- a/debug_gym/llms/copilot.py
+++ b/debug_gym/llms/copilot.py
@@ -29,11 +29,16 @@ class CopilotLLM(OpenAILLM):
def __init__(
self,
model_name,
+ llm_config,
logger=None,
- llm_config=None,
runtime_generate_kwargs=None,
):
- super().__init__(model_name, logger, llm_config, runtime_generate_kwargs)
+ super().__init__(
+ model_name,
+ llm_config=llm_config,
+ logger=logger,
+ runtime_generate_kwargs=runtime_generate_kwargs,
+ )
self._client = None
self._token_cache = None
self._token_expires_at = 0
diff --git a/debug_gym/llms/openai.py b/debug_gym/llms/openai.py
index 77102bbd..c8c1ed2b 100644
--- a/debug_gym/llms/openai.py
+++ b/debug_gym/llms/openai.py
@@ -268,13 +268,20 @@ def generate(self, messages, tools, **kwargs) -> LLMResponse:
self.need_to_be_retried,
)
try:
- response = api_call(
- model=self.config.model,
- messages=messages,
- tools=self.define_tools(tools),
- tool_choice="auto",
- **kwargs,
- )
+ if tools:
+ response = api_call(
+ model=self.config.model,
+ messages=messages,
+ tools=self.define_tools(tools),
+ tool_choice="auto",
+ **kwargs,
+ )
+ else:
+ response = api_call(
+ model=self.config.model,
+ messages=messages,
+ **kwargs,
+ )
except openai.BadRequestError as e:
# Handle specific error for context length exceeded, otherwise just propagate the error
if self.is_context_length_error(e):
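
Making tools optional here is what lets SimpleAgent request a plain completion: when it calls the LLM with tools=None, generate() omits the tools and tool_choice arguments entirely. A sketch of that call path, assuming an already-constructed OpenAILLM-compatible instance:

    # With tools=None the request carries no tools/tool_choice fields
    # and the reply is free-form text.
    messages = [{"role": "user", "content": "List the files in /testbed."}]
    response = llm(messages, tools=None)
    print(response.response)   # later parsed by SimpleAgent.parse_response()
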
diff --git a/scripts/run.py b/scripts/run.py
index a6fc3da9..8c255e07 100644
--- a/scripts/run.py
+++ b/scripts/run.py
@@ -87,10 +87,10 @@ def run_agent(args, task_name: str, task_data: dict, config: dict):
env = create_env(config, task_data, task_logger)
llm = LLM.instantiate(**config.get("llm", {}), logger=task_logger)
- agent = create_agent(config.get("agent", {}), logger=task_logger)
+ agent = create_agent(config.get("agent", {}), llm=llm, logger=task_logger)
try:
- success = agent.run(env, llm, debug=args.debug)
+ success = agent.run(env, debug=args.debug)
except KeyboardInterrupt:
task_logger.error("Agent run was interrupted by user.")
task_logger.report_progress(
diff --git a/tests/agents/conftest.py b/tests/agents/conftest.py
index f3881702..8b4df548 100644
--- a/tests/agents/conftest.py
+++ b/tests/agents/conftest.py
@@ -52,8 +52,7 @@ def _agent_setup(agent_class, *, config_override=None):
llm.context_length = 4096
llm.count_tokens = _length
llm.define_tools = lambda x: x
- agent = agent_class(config_dict)
- agent.llm = llm
+ agent = agent_class(llm=llm, agent_args=config_dict)
agent.env = env
yield agent, env, llm
diff --git a/tests/agents/test_froggy_agents.py b/tests/agents/test_froggy_agents.py
index a937053b..3222a803 100644
--- a/tests/agents/test_froggy_agents.py
+++ b/tests/agents/test_froggy_agents.py
@@ -149,7 +149,7 @@ def test_run(agent_setup, build_env_info):
tool=ToolCall(id="tool_id", name="tool_name", arguments={}),
token_usage=TokenUsage(2, 4),
)
- result = agent.run(env, llm, debug=False)
+ result = agent.run(env, debug=False)
assert result
@@ -243,7 +243,7 @@ def test_run_early_completion(agent_setup, build_env_info):
step_observation="Test last run obs",
)
- result = agent.run(env, llm)
+ result = agent.run(env)
assert result["success"] is True
env.step.assert_not_called() # Should not step if already done
@@ -282,7 +282,7 @@ def test_run_stops_at_max_steps(agent_setup, build_env_info):
response_token_count=4,
)
- result = agent.run(env, llm)
+ result = agent.run(env)
assert result["success"] is False
assert env.step.call_count == 1
@@ -305,7 +305,7 @@ def test_run_exception_handling(agent_setup, build_env_info):
llm.side_effect = RuntimeError("Test error")
with pytest.raises(RuntimeError, match="Test error"):
- agent.run(env, llm)
+ agent.run(env)
def test_save_patch(agent_setup, tmp_path):
@@ -350,7 +350,7 @@ def json(self, step_id):
{"name": tool.name, "args": tool.arguments} for tool in tools
]
- trajectory = agent._build_trajectory()
+ trajectory = agent.build_trajectory()
assert trajectory["problem"] == env.task_name
assert trajectory["uuid"] == "test-uuid-123"
assert len(trajectory["log"]) == 2
diff --git a/tests/agents/test_simple_agent.py b/tests/agents/test_simple_agent.py
new file mode 100644
index 00000000..4793f0ce
--- /dev/null
+++ b/tests/agents/test_simple_agent.py
@@ -0,0 +1,49 @@
+from unittest.mock import Mock
+
+import pytest
+
+from debug_gym.agents.base_agent import AgentArgs
+from debug_gym.agents.simple_agent import SimpleAgent
+
+
+@pytest.fixture
+def agent():
+ agent = SimpleAgent(agent_args=AgentArgs(max_steps=10))
+ agent.logger = Mock()
+ return agent
+
+
+def test_parse_with_parameters(agent):
+ """Covers main parsing logic and multiline parameters"""
+ response = """
+<function=test>
+<parameter=x>1</parameter>
+<parameter=code>
+def hello():
+ pass
+</parameter>
+</function>
+"""
+ tool_calls = agent.parse_tool_call(response)
+ assert len(tool_calls) == 1
+ assert tool_calls[0].name == "test"
+ assert tool_calls[0].arguments["x"] == "1"
+ assert "def hello():" in tool_calls[0].arguments["code"]
+
+
+def test_parse_multiple_and_empty(agent):
+ """Covers multiple functions and parameter scoping"""
+ response = (
+ "1"
+ )
+ tool_calls = agent.parse_tool_call(response)
+ assert len(tool_calls) == 2
+ assert tool_calls[0].arguments == {"x": "1"}
+ assert tool_calls[1].arguments == {}
+
+
+def test_parse_fallback_and_exception(agent):
+ """Covers no-match fallback and exception handling"""
+ # No match fallback
+ tool_calls = agent.parse_tool_call("text")
+ assert not tool_calls
diff --git a/tests/conftest.py b/tests/conftest.py
index bb9e027c..28f57d1f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,6 +30,23 @@ def emit(self, record):
@pytest.fixture
def llm_class_mock():
class LLMMock(LLM):
+ def __init__(
+ self,
+ model_name: str,
+ llm_config=None,
+ logger=None,
+ runtime_generate_kwargs=None,
+ ):
+ # If llm_config is not provided, fetch it from the registry
+ if llm_config is None:
+ llm_config = LLMConfigRegistry.from_file()[model_name]
+ super().__init__(
+ model_name,
+ llm_config,
+ logger=logger,
+ runtime_generate_kwargs=runtime_generate_kwargs,
+ )
+
def generate(self, messages, tools, **kwargs):
self.called_messages = messages
self.called_tools = tools
diff --git a/tests/gym/envs/test_r2egym.py b/tests/gym/envs/test_r2egym.py
index 909a3a83..d2ae3e2a 100644
--- a/tests/gym/envs/test_r2egym.py
+++ b/tests/gym/envs/test_r2egym.py
@@ -245,9 +245,8 @@ def test_running_solution_agent(get_r2egym_env, tmp_path):
for tool_name in ["pdb", "eval", "submit"]:
env.add_tool(Toolbox.get_tool(tool_name))
agent = AgentSolution(agent_args=config, llm=None, logger=env.logger)
- env.reset()
- success = agent.run(env)
- assert success
+ result = agent.run(env)
+ assert result["success"]
@pytest.if_docker_running
diff --git a/tests/gym/envs/test_swe_bench.py b/tests/gym/envs/test_swe_bench.py
index 07632d33..9f133fa1 100644
--- a/tests/gym/envs/test_swe_bench.py
+++ b/tests/gym/envs/test_swe_bench.py
@@ -247,9 +247,8 @@ def test_running_solution_agent(get_swe_bench_env, tmp_path):
for tool_name in ["pdb", "submit"]:
env.add_tool(Toolbox.get_tool(tool_name))
agent = AgentSolution(agent_args=config, llm=None, logger=env.logger)
- env.reset()
- success = agent.run(env)
- assert success
+ result = agent.run(env)
+ assert result["success"]
@pytest.if_docker_running
@@ -288,6 +287,5 @@ def test_running_solution_agent_in_debug_mode(get_swe_bench_debug_env, tmp_path)
for tool_name in ["pdb", "eval", "submit"]:
env.add_tool(Toolbox.get_tool(tool_name))
agent = AgentSolution(agent_args=config, llm=None, logger=env.logger)
- env.reset()
- success = agent.run(env)
- assert success
+ result = agent.run(env)
+ assert result["success"]
diff --git a/tests/gym/envs/test_swe_smith.py b/tests/gym/envs/test_swe_smith.py
index 2c6f566d..5a33a397 100644
--- a/tests/gym/envs/test_swe_smith.py
+++ b/tests/gym/envs/test_swe_smith.py
@@ -246,14 +246,12 @@ def test_running_solution_agent(get_swe_smith_env, tmp_path):
"output_path": str(tmp_path),
"random_seed": 0,
"max_steps": 1,
- "env": env,
}
for tool_name in ["pdb", "eval", "submit"]:
env.add_tool(Toolbox.get_tool(tool_name))
agent = AgentSolution(agent_args=config, llm=None, logger=env.logger)
- env.reset()
- success = agent.run(env)
- assert success
+ result = agent.run(env)
+ assert result["success"]
@pytest.if_docker_running