@@ -0,0 +1,6 @@ | |||||
# Byte-compiled / optimized / DLL files | |||||
__pycache__/ | |||||
*.py[cod] | |||||
*$py.class | |||||
video/ |
@@ -8,7 +8,11 @@ import os | |||||
from .settings import * | from .settings import * | ||||
env = gym.make(gym_name).env | env = gym.make(gym_name).env | ||||
if use_video: | |||||
env = gym.wrappers.Monitor(env, './video/',video_callable=lambda episode_id: True,force=True) | |||||
from .train import * | from .train import * | ||||
from .play import * | from .play import * | ||||
from .review import * | from .review import * | ||||
frames = [] |
@@ -26,6 +26,8 @@ def play(player_episodes=1): | |||||
done = True | done = True | ||||
ftext = "" | ftext = "" | ||||
if use_video: | |||||
env.render(mode="ansi") | |||||
if use_ansi: | if use_ansi: | ||||
ftext = env.render(mode="ansi") | ftext = env.render(mode="ansi") | ||||
else: | else: | ||||
@@ -6,27 +6,32 @@ def review(frames): | |||||
rew = 0 | rew = 0 | ||||
cnt = 1 | cnt = 1 | ||||
for i, frame in enumerate(frames): | |||||
print (u"{}[2J{}[;H".format(chr(27), chr(27))) | |||||
print(f"Session: {frame['session']}") | |||||
print(frame['frame']) | |||||
print(f"Timestep: {i + 1}") | |||||
print(f"State: {frame['state']}") | |||||
print(f"Action: {frame['action']}") | |||||
print(f"Reward: {frame['reward']}") | |||||
try: | |||||
for i, frame in enumerate(frames): | |||||
out = u"{}[2J{}[;H".format(chr(27), chr(27)) | |||||
out += f"Session: {frame['session']}\n" | |||||
out += frame['frame'] | |||||
out += f"Timestep: {i + 1}\n" | |||||
out += f"State: {frame['state']}\n" | |||||
out += f"Action: {frame['action']}\n" | |||||
out += f"Reward: {frame['reward']}\n" | |||||
sess = frame['session'] | |||||
if sess != prevSess: | |||||
if rew > 0: | |||||
sucs += "+" | |||||
sleep(1) | |||||
elif rew < 0: | |||||
sucs += "-" | |||||
else: | |||||
sucs += "." | |||||
prevSess = frame['session'] | |||||
cnt += 1 | |||||
rew = frame['reward'] | |||||
sess = frame['session'] | |||||
if sess != prevSess: | |||||
if rew > 0: | |||||
sucs += "+" | |||||
sleep(1) | |||||
elif rew < 0: | |||||
sucs += "-" | |||||
elif prevSess >= 0: | |||||
sucs += "." | |||||
prevSess = frame['session'] | |||||
cnt += 1 | |||||
rew = frame['reward'] | |||||
print(f"\nSuccesses: [{sucs}]") | |||||
sleep(.1) | |||||
out += f"\nSuccesses: [{sucs}]\n" | |||||
print(out) | |||||
sleep(.1) | |||||
except KeyboardInterrupt: | |||||
return() |
@@ -15,3 +15,4 @@ max_iterations = 1000 | |||||
# Render settings | # Render settings | ||||
use_ansi = True | use_ansi = True | ||||
use_video = False |
@@ -1,6 +1,9 @@ | |||||
from . import * | from . import * | ||||
def train(training_episodes=10000, resume=True): | def train(training_episodes=10000, resume=True): | ||||
global frames | |||||
frames = [] | |||||
if resume and os.path.exists(gym_name+".dat"): | if resume and os.path.exists(gym_name+".dat"): | ||||
q_table = pickle.load( open( gym_name+".dat", "rb" )) | q_table = pickle.load( open( gym_name+".dat", "rb" )) | ||||
else: | else: | ||||
@@ -9,10 +12,12 @@ def train(training_episodes=10000, resume=True): | |||||
episodes = training_episodes | episodes = training_episodes | ||||
percentage = episodes / 100 | percentage = episodes / 100 | ||||
frames = [] | |||||
suc_cnt = 0 | suc_cnt = 0 | ||||
try: | try: | ||||
print (u"{}[2J{}[;H".format(chr(27), chr(27))) | |||||
print("Training: 0%") | |||||
print("Successes so far: 0") | |||||
for i in range(1, episodes + 1): | for i in range(1, episodes + 1): | ||||
state = env.reset() | state = env.reset() | ||||
@@ -41,6 +46,8 @@ def train(training_episodes=10000, resume=True): | |||||
if i % percentage == 0: | if i % percentage == 0: | ||||
ftext = "" | ftext = "" | ||||
if use_video: | |||||
env.render(mode="ansi") | |||||
if use_ansi: | if use_ansi: | ||||
ftext = env.render(mode="ansi") | ftext = env.render(mode="ansi") | ||||
else: | else: | ||||
@@ -73,5 +80,3 @@ def train(training_episodes=10000, resume=True): | |||||
pickle.dump(q_table , open( gym_name+".dat", "wb" ) ) | pickle.dump(q_table , open( gym_name+".dat", "wb" ) ) | ||||
print(f"Q-table saved: {gym_name}_stopped.dat") | print(f"Q-table saved: {gym_name}_stopped.dat") | ||||
exit() | exit() | ||||
return frames |