diff --git a/naive_deep_q_learning/cartpole_naive_dqn.py b/naive_deep_q_learning/cartpole_naive_dqn.py index 21c33ae..e433783 100644 --- a/naive_deep_q_learning/cartpole_naive_dqn.py +++ b/naive_deep_q_learning/cartpole_naive_dqn.py @@ -83,11 +83,11 @@ def learn(self, state, action, reward, state_): for i in range(n_games): score = 0 done = False - obs = env.reset() + obs = env.reset()[0] while not done: action = agent.choose_action(obs) - obs_, reward, done, info = env.step(action) + obs_, reward, done, _, info = env.step(action) score += reward agent.learn(obs, action, reward, obs_) obs = obs_