1212class GuidedRewriteAgent (RewriteAgent ):
1313 name : str = "guided_agent"
1414
15+ def __init__ (self , * args , ** kwargs ):
16+ super ().__init__ (* args , ** kwargs )
17+ self .logger .set_no_live ()
18+
1519 def try_rewrite (self , task_name ):
1620 # make a copy of the env for the llm
1721 from ipdb import set_trace
@@ -38,54 +42,106 @@ def try_rewrite(self, task_name):
3842 return info .done
3943
4044 def run (self , task_name = None , debug = False ):
41- self .logger .level = logging .DEBUG
42- self .llm .logger = DebugGymLogger (
43- name = "LLM" , level = logging .ERROR , log_dir = self .logger .log_file .parent
44- )
45- self .human = LLM .instantiate (llm_name = "human" , logger = self .logger )
46-
47- self .history .reset ()
48- info = self .env .reset (options = {"task_name" : task_name })
49- # initial state does not have prompt and response
50- self .history .step (info , None )
51-
52- if info .done is True :
53- # msg = "Environment started with entrypoint passing without errors."
54- return True
55-
56- highscore = info .score
57-
58- for step in self .logger .tqdm (range (self .config ["max_steps" ])):
59- highscore = max (highscore , info .score )
60- self .logger .info (
61- f"Score: { info .score } /{ info .max_score } ({ info .score / info .max_score :.1%} ) [Best: { highscore } ]"
45+ step = 0
46+ max_steps = self .config ["max_steps" ]
47+ try :
48+ self .logger .level = logging .DEBUG
49+ self .logger .icon = "👤"
50+ self .llm .logger = DebugGymLogger (
51+ name = "LLM" , level = logging .ERROR , log_dir = self .logger .log_file .parent
6252 )
53+ self .llm .logger .icon = "🤖"
54+ self .human = LLM .instantiate (llm_name = "human" , logger = self .logger )
55+
56+ self .history .reset ()
57+ info = self .env .reset (options = {"task_name" : task_name })
58+ # initial state does not have prompt and response
59+ self .history .step (info , None )
60+
61+ if info .done is True :
62+ # msg = "Environment started with entrypoint passing without errors."self.logger.report_progress(
63+ self .logger .report_progress (
64+ problem_id = task_name ,
65+ step = 1 ,
66+ total_steps = 1 ,
67+ score = info .score ,
68+ max_score = info .max_score ,
69+ status = "resolved" ,
70+ )
71+ return True
6372
64- llm_done = self .try_rewrite (task_name )
65- if llm_done :
66- msg = f"*** The rewrite-only agent with { self .llm .model_name } managed to solve the task with the current context. ***"
67- self .logger .info (colored (msg , "green" ))
68- break
69- else :
70- msg = f"*** The rewrite-only agent with { self .llm .model_name } failed to solve the task with the current context. ***"
71- self .logger .info (colored (msg , "red" ))
72-
73- # If the LLM did not manage to solve the task, we continue with the guided approach.
74- prompt = self .build_prompt (info )
75- human_response = self .human (prompt , info .tools )
76-
77- if debug :
78- breakpoint ()
79-
80- # step the environment with the human response
81- info = self .env .step (human_response .tool )
82- # log the human response
83- self .history .step (info , human_response )
73+ highscore = info .score
8474
85- if info .done :
75+ for step in range (max_steps ):
76+ self .logger .info (f"\n { '=' * 20 } STEP { step + 1 } { '=' * 20 } \n " )
77+ highscore = max (highscore , info .score )
8678 self .logger .info (
87- "You managed to provide the patch that solves the task before the LLM. Congrats! "
79+ f"Step: { step } | Score: { info . score } / { info . max_score } ( { info . score / info . max_score :.1% } ) [Best: { highscore } ] "
8880 )
89- break
9081
91- return info .done
82+ llm_done = self .try_rewrite (task_name )
83+ if llm_done :
84+ msg = f"[green]*** The rewrite-only agent with { self .llm .model_name } managed to solve the task with the current context. ***[/green]"
85+ self .logger .info (msg )
86+ break
87+ else :
88+ msg = f"[red]*** The rewrite-only agent with { self .llm .model_name } failed to solve the task with the current context. ***[/red]"
89+ self .logger .info (msg )
90+
91+ # If the LLM did not manage to solve the task, we continue with the guided approach.
92+ prompt = self .build_prompt (info )
93+ human_response = self .human (prompt , info .tools )
94+
95+ if debug :
96+ breakpoint ()
97+
98+ # step the environment with the human response
99+ info = self .env .step (human_response .tool )
100+ # log the human response
101+ self .history .step (info , human_response )
102+
103+ if info .done :
104+ self .logger .info (
105+ "You managed to provide the patch that solves the task before the LLM. Congrats!"
106+ )
107+ # early stop, set current step and total steps to be the same
108+ self .logger .report_progress (
109+ problem_id = task_name ,
110+ step = step + 1 ,
111+ total_steps = step + 1 ,
112+ score = info .score ,
113+ max_score = info .max_score ,
114+ status = "resolved" if info .done else "unresolved" ,
115+ )
116+ break
117+ # keep progress bar running until max_steps is reached
118+ self .logger .report_progress (
119+ problem_id = task_name ,
120+ step = step + 1 ,
121+ total_steps = max_steps + 1 ,
122+ score = info .score ,
123+ max_score = info .max_score ,
124+ status = "running" ,
125+ )
126+ # max_steps was reached, task was either resolved or unresolved
127+ self .logger .report_progress (
128+ problem_id = task_name ,
129+ step = step + 1 ,
130+ total_steps = step + 1 ,
131+ score = info .score ,
132+ max_score = info .max_score ,
133+ status = "resolved" if info .done else "unresolved" ,
134+ )
135+
136+ return info .done
137+ except Exception :
138+ # report any error that happens during the run
139+ self .logger .report_progress (
140+ problem_id = task_name ,
141+ step = step + 1 ,
142+ total_steps = step + 1 ,
143+ score = info .score if info else 0 ,
144+ max_score = info .max_score if info else 1 ,
145+ status = "error" ,
146+ )
147+ raise
0 commit comments