My OpenAI playground
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.py 2.7KB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from . import *
  2. def train(training_episodes=10000, resume=True):
  3. global frames
  4. frames = []
  5. if resume and os.path.exists(gym_name+".dat"):
  6. q_table = pickle.load( open( gym_name+".dat", "rb" ))
  7. else:
  8. q_table = np.zeros([env.observation_space.n, env.action_space.n])
  9. episodes = training_episodes
  10. percentage = episodes / 100
  11. suc_cnt = 0
  12. try:
  13. print (u"{}[2J{}[;H".format(chr(27), chr(27)))
  14. print("Training: 0%")
  15. print("Successes so far: 0")
  16. for i in range(1, episodes + 1):
  17. state = env.reset()
  18. epochs, reward, = 0, 0
  19. done = False
  20. while not done:
  21. if random.uniform(0, 1) < epsilon:
  22. action = env.action_space.sample()
  23. else:
  24. action = np.argmax(q_table[state])
  25. next_state, reward, done, info = env.step(action)
  26. old_value = q_table[state, action]
  27. next_max = np.max(q_table[next_state])
  28. new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)
  29. q_table[state, action] = new_value
  30. if reward > 0:
  31. suc_cnt += 1
  32. state = next_state
  33. epochs += 1
  34. if i % percentage == 0:
  35. ftext = ""
  36. if use_video:
  37. env.render(mode="ansi")
  38. if use_ansi:
  39. ftext = env.render(mode="ansi")
  40. else:
  41. ftext = str(info)
  42. frames.append({
  43. 'frame': ftext,
  44. 'state': state,
  45. 'action': action,
  46. 'reward': reward,
  47. 'session': i
  48. }
  49. )
  50. if i % percentage == 0:
  51. print (u"{}[2J{}[;H".format(chr(27), chr(27)))
  52. print(f"Training: {i/percentage}%")
  53. print(f"Successes so far: {suc_cnt}")
  54. sleep(.1)
  55. print (u"{}[2J{}[;H".format(chr(27), chr(27)))
  56. print("Training: finished.\n")
  57. print(f"Successes totally: {suc_cnt}")
  58. pickle.dump(q_table , open( gym_name+".dat", "wb" ) )
  59. print(f"Q-table saved: {gym_name}.dat")
  60. except KeyboardInterrupt:
  61. print (u"{}[2J{}[;H".format(chr(27), chr(27)))
  62. print("Training: stopped.\n")
  63. print(f"Successes totally: {suc_cnt}")
  64. pickle.dump(q_table , open( gym_name+".dat", "wb" ) )
  65. print(f"Q-table saved: {gym_name}_stopped.dat")
  66. exit()