|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- from . import *
-
def _clear_screen():
    """Clear the terminal and move the cursor to the top-left (ANSI escapes)."""
    print("\x1b[2J\x1b[;H")


def _save_q_table(q_table, path):
    """Pickle *q_table* to *path*, closing the file so the data is flushed."""
    with open(path, "wb") as fh:
        pickle.dump(q_table, fh)


def train(training_episodes=10000, resume=True):
    """Train a tabular Q-learning agent and persist the Q-table.

    Relies on module-level names presumably provided by ``from . import *``
    (``env``, ``gym_name``, ``alpha``, ``gamma``, ``epsilon``, ``use_video``,
    ``use_ansi``, ``sleep``) -- confirm against the package ``__init__``.

    Args:
        training_episodes: number of episodes to run.
        resume: when True and ``<gym_name>.dat`` exists, continue from the
            pickled Q-table in that file instead of starting from zeros.

    Side effects:
        Resets the module-global ``frames`` list and appends one dict per
        step of every reported episode. Prints progress to the terminal.
        Writes the Q-table to ``<gym_name>.dat`` on completion, and also on
        Ctrl-C, after which it calls ``sys.exit()``.
    """
    import sys

    global frames
    frames = []

    q_path = gym_name + ".dat"
    if resume and os.path.exists(q_path):
        # NOTE: unpickling can execute arbitrary code; only resume from files
        # this program wrote itself.
        with open(q_path, "rb") as fh:
            q_table = pickle.load(fh)
    else:
        q_table = np.zeros([env.observation_space.n, env.action_space.n])

    episodes = training_episodes
    # Integer reporting interval (~every 1% of episodes). A float interval
    # (episodes / 100) made `i % interval == 0` unreliable whenever episodes
    # was not a multiple of 100.
    report_every = max(1, episodes // 100)

    suc_cnt = 0

    try:
        _clear_screen()
        print("Training: 0%")
        print("Successes so far: 0")
        for i in range(1, episodes + 1):
            state = env.reset()
            done = False
            # Decide once per episode whether this one is recorded/reported.
            record = i % report_every == 0

            while not done:
                # Epsilon-greedy action selection.
                if random.uniform(0, 1) < epsilon:
                    action = env.action_space.sample()
                else:
                    action = np.argmax(q_table[state])

                next_state, reward, done, info = env.step(action)

                # Q-learning update:
                # Q(s,a) <- (1-alpha)*Q(s,a) + alpha*(r + gamma*max_a' Q(s',a'))
                old_value = q_table[state, action]
                next_max = np.max(q_table[next_state])
                q_table[state, action] = (1 - alpha) * old_value + alpha * (
                    reward + gamma * next_max
                )

                if reward > 0:
                    suc_cnt += 1

                state = next_state

                if record:
                    ftext = ""
                    if use_video:
                        # Render once and reuse the result, instead of
                        # rendering and discarding before rendering again.
                        rendered = env.render(mode="ansi")
                        ftext = rendered if use_ansi else str(info)

                    frames.append({
                        'frame': ftext,
                        'state': state,
                        'action': action,
                        'reward': reward,
                        'session': i,
                    })

            if record:
                _clear_screen()
                print(f"Training: {i * 100 / episodes}%")
                print(f"Successes so far: {suc_cnt}")
                sleep(.1)

        _clear_screen()
        print("Training: finished.\n")
        print(f"Successes totally: {suc_cnt}")
        _save_q_table(q_table, q_path)
        print(f"Q-table saved: {gym_name}.dat")

    except KeyboardInterrupt:
        _clear_screen()
        print("Training: stopped.\n")
        print(f"Successes totally: {suc_cnt}")
        # Save to the same file `resume` reads, so training can continue.
        _save_q_table(q_table, q_path)
        print(f"Q-table saved: {gym_name}.dat")
        sys.exit()
|