My OpenAI playground
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

77 lines
2.5KB

  1. from . import *
  2. def train(training_episodes=10000, resume=True):
  3. if resume and os.path.exists(gym_name+".dat"):
  4. q_table = pickle.load( open( gym_name+".dat", "rb" ))
  5. else:
  6. q_table = np.zeros([env.observation_space.n, env.action_space.n])
  7. episodes = training_episodes
  8. percentage = episodes / 100
  9. frames = []
  10. suc_cnt = 0
  11. try:
  12. for i in range(1, episodes + 1):
  13. state = env.reset()
  14. epochs, reward, = 0, 0
  15. done = False
  16. while not done:
  17. if random.uniform(0, 1) < epsilon:
  18. action = env.action_space.sample()
  19. else:
  20. action = np.argmax(q_table[state])
  21. next_state, reward, done, info = env.step(action)
  22. old_value = q_table[state, action]
  23. next_max = np.max(q_table[next_state])
  24. new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)
  25. q_table[state, action] = new_value
  26. if reward > 0:
  27. suc_cnt += 1
  28. state = next_state
  29. epochs += 1
  30. if i % percentage == 0:
  31. ftext = ""
  32. if use_ansi:
  33. ftext = env.render(mode="ansi")
  34. else:
  35. ftext = str(info)
  36. frames.append({
  37. 'frame': ftext,
  38. 'state': state,
  39. 'action': action,
  40. 'reward': reward,
  41. 'session': i
  42. }
  43. )
  44. if i % percentage == 0:
  45. print (u"{}[2J{}[;H".format(chr(27), chr(27)))
  46. print(f"Training: {i/percentage}%")
  47. print(f"Successes so far: {suc_cnt}")
  48. sleep(.1)
  49. print (u"{}[2J{}[;H".format(chr(27), chr(27)))
  50. print("Training: finished.\n")
  51. print(f"Successes totally: {suc_cnt}")
  52. pickle.dump(q_table , open( gym_name+".dat", "wb" ) )
  53. print(f"Q-table saved: {gym_name}.dat")
  54. except KeyboardInterrupt:
  55. print (u"{}[2J{}[;H".format(chr(27), chr(27)))
  56. print("Training: stopped.\n")
  57. print(f"Successes totally: {suc_cnt}")
  58. pickle.dump(q_table , open( gym_name+".dat", "wb" ) )
  59. print(f"Q-table saved: {gym_name}_stopped.dat")
  60. exit()
  61. return frames