Skip to content

Commit f6ddcad

Browse files
committed
WIP
1 parent 5353750 commit f6ddcad

File tree

4 files changed

+177
-81
lines changed

4 files changed

+177
-81
lines changed

debug_gym/agents/guided_agent.py

Lines changed: 130 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import logging
22

3-
from termcolor import colored
4-
53
from debug_gym.agents.base_agent import register_agent
64
from debug_gym.agents.rewrite_agent import RewriteAgent
5+
from debug_gym.gym.entities import Event
6+
from debug_gym.gym.tools.tool import ToolCall
7+
from debug_gym.gym.tools.toolbox import Toolbox
78
from debug_gym.llms.base import LLM
89
from debug_gym.logger import DebugGymLogger
910

@@ -12,80 +13,147 @@
1213
class GuidedRewriteAgent(RewriteAgent):
1314
name: str = "guided_agent"
1415

15-
def try_rewrite(self, task_name):
16-
# make a copy of the env for the llm
17-
from ipdb import set_trace
16+
def __init__(self, *args, **kwargs):
17+
super().__init__(*args, **kwargs)
1818

19-
set_trace()
20-
cloned_env = self.env.clone()
19+
# Create a dedicated env for the guided rewrite agent.
20+
self.llm.logger = DebugGymLogger(
21+
name="LLM",
22+
level=logging.DEBUG,
23+
log_dir=self.logger.log_file.parent,
24+
icon="🤖",
25+
)
2126

22-
# Only keep the rewrite tool in the cloned env
23-
for tool in cloned_env.tools:
24-
if tool.name != "rewrite":
25-
cloned_env.remove_tool(tool.name)
27+
# Create a human interface for the guided agent.
28+
self.logger.level = logging.DEBUG
29+
self.logger.icon = "👤"
30+
self.human = LLM.instantiate(llm_name="human", logger=self.logger)
2631

27-
# Reset the cloned environment and replay the history.
28-
info = cloned_env.reset(options={"task_name": task_name})
29-
# replay the history up to the current step
30-
for step in self.history.get_all():
31-
assert not step.done
32-
info = cloned_env.step(step.action)
32+
def try_rewrite_and_rollback(self, last_info):
33+
prompt = self.build_prompt(last_info)
3334

34-
prompt = self.build_prompt(info)
35-
response = self.llm(prompt, info.tools)
36-
info = cloned_env.step(response.response)
35+
# Git commit the current state before trying to rewrite.
36+
self.env.terminal.run("git add . && git commit -m 'Before rewrite attempt'")
3737

38-
return info.done
38+
# Remove all tools except the rewrite tool.
39+
tools = [tool for tool in last_info.tools if tool.name == "rewrite"]
40+
response = self.llm(prompt, tools)
41+
self.llm.logger.info(f"LLM response: {response.response}")
42+
self.llm.logger.info(f"LLM tool: {response.tool}")
3943

40-
def run(self, task_name=None, debug=False):
41-
self.logger.level = logging.DEBUG
42-
self.llm.logger = DebugGymLogger(
43-
name="LLM", level=logging.ERROR, log_dir=self.logger.log_file.parent
44-
)
45-
self.human = LLM.instantiate(llm_name="human", logger=self.logger)
44+
# Temporarily disable the REWRITE_SUCCESS event.
45+
self.env.event_hooks.mute(Event.REWRITE_SUCCESS)
46+
info_after_rewrite = self.env.step(response.tool)
47+
info = self.env.step(ToolCall(id="eval", name="eval", arguments={}))
48+
self.env.event_hooks.unmute(Event.REWRITE_SUCCESS)
4649

47-
self.history.reset()
48-
info = self.env.reset(options={"task_name": task_name})
49-
# initial state does not have prompt and response
50-
self.history.step(info, None)
50+
self.llm.logger.info(f"LLM observation: {info.eval_observation.observation}.")
5151

52-
if info.done is True:
53-
# msg = "Environment started with entrypoint passing without errors."
54-
return True
52+
# Rollback any changes made by the LLM.
53+
self.env.terminal.run("git reset --hard HEAD")
5554

56-
highscore = info.score
55+
return info.done
5756

58-
for step in self.logger.tqdm(range(self.config["max_steps"])):
59-
highscore = max(highscore, info.score)
60-
self.logger.info(
61-
f"Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]"
57+
def run(self, task_name=None, debug=False):
58+
step = 0
59+
max_steps = self.config["max_steps"]
60+
try:
61+
self.history.reset()
62+
info = self.env.reset(options={"task_name": task_name})
63+
# initial state does not have prompt and response
64+
self.history.step(info, None)
65+
66+
# First make sure git is setup correctly.
67+
self.env.terminal.run(
68+
"git init && git config user.name 'debug-gym' && git config user.email '<>'"
6269
)
6370

64-
llm_done = self.try_rewrite(task_name)
65-
if llm_done:
66-
msg = f"*** The rewrite-only agent with {self.llm.model_name} managed to solve the task with the current context. ***"
67-
self.logger.info(colored(msg, "green"))
68-
break
69-
else:
70-
msg = f"*** The rewrite-only agent with {self.llm.model_name} failed to solve the task with the current context. ***"
71-
self.logger.info(colored(msg, "red"))
72-
73-
# If the LLM did not manage to solve the task, we continue with the guided approach.
74-
prompt = self.build_prompt(info)
75-
human_response = self.human(prompt, info.tools)
76-
77-
if debug:
78-
breakpoint()
71+
if info.done is True:
72+
self.logger.report_progress(
73+
problem_id=task_name,
74+
step=1,
75+
total_steps=1,
76+
score=info.score,
77+
max_score=info.max_score,
78+
status="resolved",
79+
)
80+
return True
7981

80-
# step the environment with the human response
81-
info = self.env.step(human_response.tool)
82-
# log the human response
83-
self.history.step(info, human_response)
82+
highscore = info.score
8483

85-
if info.done:
84+
for step in range(max_steps):
85+
self.logger.info(f"\n{'='*20} STEP {step+1} {'='*20}\n")
86+
highscore = max(highscore, info.score)
8687
self.logger.info(
87-
"You managed to provide the patch that solves the task before the LLM. Congrats!"
88+
f"Step: {step} | Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]"
8889
)
89-
break
9090

91-
return info.done
91+
llm_done = self.try_rewrite_and_rollback(info)
92+
if llm_done:
93+
msg = f"[green]*** The rewrite-only agent with {self.llm.model_name} managed to solve the task with the current context. ***[/green]"
94+
self.llm.logger.error(msg)
95+
break
96+
else:
97+
msg = f"[red]*** The rewrite-only agent with {self.llm.model_name} failed to solve the task with the current context. ***[/red]"
98+
self.llm.logger.error(msg)
99+
100+
# If the LLM did not manage to solve the task, we continue with the guided approach.
101+
prompt = self.build_prompt(info)
102+
human_response = self.human(prompt, info.tools)
103+
if not llm_done:
104+
msg = f"[red]*** The rewrite-only agent with {self.llm.model_name} failed to solve the task with the current context. ***[/red]"
105+
self.llm.logger.error(msg)
106+
107+
if debug:
108+
breakpoint()
109+
110+
# step the environment with the human response
111+
info = self.env.step(human_response.tool)
112+
# log the human response
113+
self.history.step(info, human_response)
114+
115+
if info.done:
116+
self.logger.info(
117+
"You managed to provide the patch that solves the task before the LLM. Congrats!"
118+
)
119+
# early stop, set current step and total steps to be the same
120+
self.logger.report_progress(
121+
problem_id=task_name,
122+
step=step + 1,
123+
total_steps=step + 1,
124+
score=info.score,
125+
max_score=info.max_score,
126+
status="resolved" if info.done else "unresolved",
127+
)
128+
break
129+
# keep progress bar running until max_steps is reached
130+
self.logger.report_progress(
131+
problem_id=task_name,
132+
step=step + 1,
133+
total_steps=max_steps + 1,
134+
score=info.score,
135+
max_score=info.max_score,
136+
status="running",
137+
)
138+
# max_steps was reached, task was either resolved or unresolved
139+
self.logger.report_progress(
140+
problem_id=task_name,
141+
step=step + 1,
142+
total_steps=step + 1,
143+
score=info.score,
144+
max_score=info.max_score,
145+
status="resolved" if info.done else "unresolved",
146+
)
147+
148+
return info.done
149+
except Exception:
150+
# report any error that happens during the run
151+
self.logger.report_progress(
152+
problem_id=task_name,
153+
step=step + 1,
154+
total_steps=step + 1,
155+
score=info.score if info else 0,
156+
max_score=info.max_score if info else 1,
157+
status="error",
158+
)
159+
raise

debug_gym/gym/envs/env.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class EnvInfo:
3737
class EventHooks:
3838
def __init__(self):
3939
self.event_listeners = {event: [] for event in Event}
40+
self.event_listeners_muted = {event: [] for event in Event}
4041

4142
def subscribe(self, event: Event, tool: "Tool"):
4243
if event not in self.event_listeners:
@@ -50,6 +51,20 @@ def subscribe(self, event: Event, tool: "Tool"):
5051
def unsubscribe(self, event: Event, tool):
5152
self.event_listeners[event].remove(tool)
5253

54+
def mute(self, event: Event):
55+
"""Mute all tools for the given event."""
56+
if event not in self.event_listeners_muted:
57+
raise ValueError(f"Unknown event type: {event}")
58+
self.event_listeners_muted[event] = self.event_listeners[event][:]
59+
self.event_listeners[event] = []
60+
61+
def unmute(self, event: Event):
62+
"""Unmute all tools for the given event."""
63+
if event not in self.event_listeners_muted:
64+
raise ValueError(f"Unknown event type: {event}")
65+
self.event_listeners[event] = self.event_listeners_muted[event][:]
66+
self.event_listeners_muted[event] = []
67+
5368
def notify(
5469
self, environment, event: Event, source=None, **kwargs
5570
) -> list[Observation]:
@@ -555,23 +570,6 @@ def step(self, action: ToolCall, action_reasoning: str = "") -> EnvInfo:
555570

556571
return self.infos
557572

558-
def clone(self):
559-
# Create a new instance of RepoEnv
560-
new_env = RepoEnv(
561-
path=self.path,
562-
entrypoint=self.entrypoint,
563-
debug_entrypoint=self.debug_entrypoint,
564-
max_score=self.max_score,
565-
readonly_patterns=None,
566-
run_timeout=self.run_timeout,
567-
dir_tree_depth=self.dir_tree_depth,
568-
terminal=Terminal(),
569-
logger=self.logger,
570-
)
571-
for tool in self.tools:
572-
new_env.add_tool(tool)
573-
return new_env
574-
575573
def post_process_event(self, event: Event, source, kwargs, observations):
576574
"""Post-process the event after it has been handled by the tools."""
577575
if event in (Event.REWRITE_SUCCESS, Event.REWRITE_FAIL):

debug_gym/logger.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,20 @@ def _status_listener(self):
405405
self.logger.debug("Status listener thread exiting...")
406406

407407

408+
class IconFilter(logging.Filter):
409+
def __init__(self, *args, icon="🐸", **kwargs):
410+
super().__init__(*args, **kwargs)
411+
self.icon = icon
412+
413+
def filter(self, record):
414+
if not hasattr(record, "icon"):
415+
# If the record does not have an icon attribute, set it
416+
# This allows the icon to be used in log messages
417+
record.icon = self.icon
418+
419+
return True
420+
421+
408422
class DebugGymLogger(logging.Logger):
409423
"""A multiprocess friendly logger that integrates with Rich for progress reporting.
410424
Multiprocess workers can use this logger to log messages and report progress via
@@ -420,6 +434,7 @@ def __init__(
420434
log_dir: str | None = None,
421435
level: str | int = logging.INFO,
422436
mode: str = "a",
437+
icon: str = "🐸",
423438
):
424439
super().__init__(name)
425440
# If var env "DEBUG_GYM_DEBUG" is set, turn on debug mode
@@ -428,6 +443,8 @@ def __init__(
428443

429444
# Prevent the log messages from being propagated to the root logger
430445
self.propagate = False
446+
self.icon_filter = IconFilter(icon=icon)
447+
self.addFilter(self.icon_filter)
431448

432449
self.setLevel(level) # Set logger level, might be overridden by file handler
433450
self.log_file = None # File handler for logging to a file
@@ -443,6 +460,16 @@ def __init__(
443460
if log_dir:
444461
self._initialize_file_handler(name, log_dir, mode)
445462

463+
@property
464+
def icon(self):
465+
"""Get the icon used in log messages."""
466+
return self.icon_filter.icon
467+
468+
@icon.setter
469+
def icon(self, icon: str):
470+
"""Set the icon for the logger. This will update the icon used in log messages."""
471+
self.icon_filter.icon = icon
472+
446473
def _initialize_main_logger(self, level):
447474
self._live = Live(transient=True, refresh_per_second=2)
448475
rich_handler = RichHandler(
@@ -451,7 +478,9 @@ def _initialize_main_logger(self, level):
451478
rich_tracebacks=True,
452479
markup=True,
453480
)
454-
rich_handler.setFormatter(logging.Formatter("🐸 [%(name)-12s]: %(message)s"))
481+
rich_handler.setFormatter(
482+
logging.Formatter(r"%(icon)s \[%(name)-12s]: %(message)s")
483+
)
455484
rich_handler.setLevel(level)
456485
self.addHandler(rich_handler)
457486

@@ -481,6 +510,7 @@ def handle(self, record):
481510
record into the log queue for the main process to display
482511
logs through Rich."""
483512
if self._is_worker:
513+
# record.args.append(self.icon)
484514
self.LOG_QUEUE.put(record)
485515
super().handle(record)
486516

scripts/config_mini_nightmare.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ base:
2020
# session_commands define commands that are always executed before starting a shell session or running a single command in the terminal.
2121
# session_commands:["conda activate aider"],
2222
# setup_commands define commands that are executed only once when the terminal is created. This is only supported for Docker terminal.
23-
setup_commands: ["pip install pytest pandas"],
23+
setup_commands: ["apt update", "apt install -y git", "pip install pytest pandas"],
2424
}
2525

2626
# LLM configs

0 commit comments

Comments
 (0)