|
1 | 1 | import logging |
2 | 2 |
|
3 | | -from termcolor import colored |
4 | | - |
5 | 3 | from debug_gym.agents.base_agent import register_agent |
6 | 4 | from debug_gym.agents.rewrite_agent import RewriteAgent |
| 5 | +from debug_gym.gym.entities import Event |
| 6 | +from debug_gym.gym.tools.tool import ToolCall |
| 7 | +from debug_gym.gym.tools.toolbox import Toolbox |
7 | 8 | from debug_gym.llms.base import LLM |
8 | 9 | from debug_gym.logger import DebugGymLogger |
9 | 10 |
|
|
class GuidedRewriteAgent(RewriteAgent):
    """Rewrite agent driven by a human, with an LLM probing each step.

    At every step the LLM is first given the current context and only the
    ``rewrite`` tool. Its attempt is evaluated and then rolled back with git.
    If the attempt did not solve the task, the human picks the next action,
    which is the one actually applied to the environment and recorded in
    the history.
    """

    name: str = "guided_agent"

    def __init__(self, *args, **kwargs):
        """Set up a separate logger for the LLM and a human "LLM" interface.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)

        # Give the LLM its own logger (same log directory as the agent's
        # logger) so its output carries a different name and icon.
        self.llm.logger = DebugGymLogger(
            name="LLM",
            level=logging.DEBUG,
            log_dir=self.logger.log_file.parent,
            icon="🤖",
        )

        # Create a human interface for the guided agent.
        self.logger.level = logging.DEBUG
        self.logger.icon = "👤"
        self.human = LLM.instantiate(llm_name="human", logger=self.logger)

    def try_rewrite_and_rollback(self, last_info):
        """Let the LLM attempt a rewrite-only fix, evaluate it, then undo it.

        Args:
            last_info: Latest environment info; used to build the prompt and
                to look up the ``rewrite`` tool in ``last_info.tools``.

        Returns:
            The ``done`` flag of the info returned by the ``eval`` step that
            follows the LLM's rewrite.
        """
        prompt = self.build_prompt(last_info)

        # Git commit the current state before trying to rewrite, so the
        # rollback below has a commit to return to.
        self.env.terminal.run("git add . && git commit -m 'Before rewrite attempt'")

        # Remove all tools except the rewrite tool.
        tools = [tool for tool in last_info.tools if tool.name == "rewrite"]
        response = self.llm(prompt, tools)
        self.llm.logger.info(f"LLM response: {response.response}")
        self.llm.logger.info(f"LLM tool: {response.tool}")

        try:
            # Temporarily disable the REWRITE_SUCCESS event.
            self.env.event_hooks.mute(Event.REWRITE_SUCCESS)
            try:
                self.env.step(response.tool)
                info = self.env.step(ToolCall(id="eval", name="eval", arguments={}))
            finally:
                # Re-enable the event even when a step raised; otherwise it
                # would stay muted for the rest of the run.
                self.env.event_hooks.unmute(Event.REWRITE_SUCCESS)

            self.llm.logger.info(
                f"LLM observation: {info.eval_observation.observation}."
            )
        finally:
            # Rollback any changes made by the LLM, including when the
            # attempt failed with an exception.
            self.env.terminal.run("git reset --hard HEAD")

        return info.done

    def run(self, task_name=None, debug=False):
        """Run the guided loop for at most ``config["max_steps"]`` steps.

        Args:
            task_name: Passed to ``env.reset`` as ``options["task_name"]``
                and used as ``problem_id`` in progress reports.
            debug: When true, drop into ``breakpoint()`` before each human
                action is applied.

        Returns:
            ``True`` if the environment is already done right after reset.
            Otherwise the ``done`` flag of the most recent info produced by
            the reset or by a human step; the LLM's rolled-back attempt
            does not affect the returned value.

        Raises:
            Exception: Any error raised during the run is reported with
                status ``"error"`` and then re-raised unchanged.
        """
        step = 0
        # Bound before the ``try`` so the error handler can read it even when
        # the failure happens before the environment was reset.
        info = None
        max_steps = self.config["max_steps"]
        try:
            self.history.reset()
            info = self.env.reset(options={"task_name": task_name})
            # initial state does not have prompt and response
            self.history.step(info, None)

            # First make sure git is set up correctly.
            self.env.terminal.run(
                "git init && git config user.name 'debug-gym' && git config user.email '<>'"
            )

            if info.done is True:
                self.logger.report_progress(
                    problem_id=task_name,
                    step=1,
                    total_steps=1,
                    score=info.score,
                    max_score=info.max_score,
                    status="resolved",
                )
                return True

            highscore = info.score

            for step in range(max_steps):
                self.logger.info(f"\n{'='*20} STEP {step+1} {'='*20}\n")
                highscore = max(highscore, info.score)
                self.logger.info(
                    f"Step: {step} | Score: {info.score}/{info.max_score} ({info.score/info.max_score:.1%}) [Best: {highscore}]"
                )

                llm_done = self.try_rewrite_and_rollback(info)
                if llm_done:
                    msg = f"[green]*** The rewrite-only agent with {self.llm.model_name} managed to solve the task with the current context. ***[/green]"
                    self.llm.logger.error(msg)
                    break
                else:
                    msg = f"[red]*** The rewrite-only agent with {self.llm.model_name} failed to solve the task with the current context. ***[/red]"
                    self.llm.logger.error(msg)

                # If the LLM did not manage to solve the task, we continue with the guided approach.
                prompt = self.build_prompt(info)
                human_response = self.human(prompt, info.tools)
                # The LLM has necessarily failed at this point (success breaks
                # out above), so the verdict is logged again unconditionally.
                # NOTE(review): presumably repeated so it stays visible after
                # the human prompt - confirm the duplicate is intended.
                msg = f"[red]*** The rewrite-only agent with {self.llm.model_name} failed to solve the task with the current context. ***[/red]"
                self.llm.logger.error(msg)

                if debug:
                    breakpoint()

                # step the environment with the human response
                info = self.env.step(human_response.tool)
                # log the human response
                self.history.step(info, human_response)

                if info.done:
                    self.logger.info(
                        "You managed to provide the patch that solves the task before the LLM. Congrats!"
                    )
                    # early stop, set current step and total steps to be the same
                    self.logger.report_progress(
                        problem_id=task_name,
                        step=step + 1,
                        total_steps=step + 1,
                        score=info.score,
                        max_score=info.max_score,
                        status="resolved" if info.done else "unresolved",
                    )
                    break
                # keep progress bar running until max_steps is reached
                self.logger.report_progress(
                    problem_id=task_name,
                    step=step + 1,
                    total_steps=max_steps + 1,
                    score=info.score,
                    max_score=info.max_score,
                    status="running",
                )
            # the loop ended (break or max_steps reached): report the final
            # state, which is either resolved or unresolved
            self.logger.report_progress(
                problem_id=task_name,
                step=step + 1,
                total_steps=step + 1,
                score=info.score,
                max_score=info.max_score,
                status="resolved" if info.done else "unresolved",
            )

            return info.done
        except Exception:
            # report any error that happens during the run
            self.logger.report_progress(
                problem_id=task_name,
                step=step + 1,
                total_steps=step + 1,
                score=info.score if info is not None else 0,
                max_score=info.max_score if info is not None else 1,
                status="error",
            )
            raise
0 commit comments