This commit is contained in:
2019-11-05 01:38:46 +03:00
parent 564370fa80
commit 8e38700009
7 changed files with 209 additions and 1 deletions

View File

@@ -1,3 +1,12 @@
# openai-tests
My OpenAI playground
# Q-Learning
```
import ql
frames = ql.train(10000) # run 10k training sessions
ql.review(frames) # review the training process
ql.play() # see the trained algorithm in action
```

BIN
Taxi-v3.dat Normal file

Binary file not shown.

14
ql/__init__.py Normal file
View File

@@ -0,0 +1,14 @@
# Package initialiser: gathers the shared imports, the settings and the Gym
# environment in one namespace. train.py, play.py and review.py each begin
# with ``from . import *`` and rely on these names being defined here.
import gym
from time import sleep
import numpy as np
import pickle
import random
import os
# Settings come first because gym_name is needed to create the environment.
from .settings import *
# One environment instance shared by the whole package. The ``.env`` attribute
# of the object returned by gym.make() is used
# (presumably to get the environment without its outer wrapper - TODO confirm).
env = gym.make(gym_name).env
# The submodules are imported only after ``env`` exists: their
# ``from . import *`` copies whatever is defined in this module at that moment.
from .train import *
from .play import *
from .review import *

59
ql/play.py Normal file
View File

@@ -0,0 +1,59 @@
from . import *
def _report_play_results(episode_count, total_epochs, total_rewards):
    """Print the average timesteps and rewards over ``episode_count`` episodes."""
    print(f"Results after {episode_count} episodes:")
    if episode_count <= 0:
        # No finished episode: an average would divide by zero.
        return
    print(f"Average timesteps per episode: {total_epochs / episode_count}")
    print(f"Average rewards per episode: {total_rewards / episode_count}")


def play(player_episodes=1):
    """Run the greedy policy from the saved Q-table and print each step.

    The Q-table is read from ``<gym_name>.dat``. For every step the screen is
    cleared and the progress percentage plus either the ANSI rendering of the
    environment (``use_ansi``) or ``str(info)`` is printed.

    Args:
        player_episodes: Number of episodes to play.

    Returns:
        None. Average timesteps and rewards per episode are printed.

    Raises:
        SystemExit: via ``exit()`` after printing the statistics of the
            episodes completed so far when interrupted with Ctrl-C.
    """
    clear_screen = u"{}[2J{}[;H".format(chr(27), chr(27))

    # NOTE: pickle can execute arbitrary code while loading; only load Q-table
    # files that were produced locally by train().
    with open(gym_name + ".dat", "rb") as q_file:
        q_table = pickle.load(q_file)

    episodes = player_episodes
    print("Evaluation: 0%")
    total_epochs, total_rewards = 0, 0
    completed = 0
    try:
        for ep in range(episodes):
            state = env.reset()
            epochs, episode_reward = 0, 0
            done = False
            while not done:
                action = np.argmax(q_table[state])
                state, reward, done, info = env.step(action)
                # Sum every step's reward so the reported average reflects
                # the whole episode, not only its final step.
                episode_reward += reward
                # Safety cut-off for policies that never reach a terminal
                # state (checked before the step counter is incremented).
                if epochs > max_iterations:
                    done = True
                ftext = env.render(mode="ansi") if use_ansi else str(info)
                print(clear_screen)
                print(f"Evaluation: {100 * ep / episodes}%")
                print(f"{ftext}")
                sleep(.1)
                epochs += 1
            total_epochs += epochs
            total_rewards += episode_reward
            completed += 1
            sleep(1)
    except KeyboardInterrupt:
        # Average over the episodes that actually finished.
        _report_play_results(completed, total_epochs, total_rewards)
        exit()
    print(clear_screen)
    print("Evaluation: finished.\n")
    _report_play_results(episodes, total_epochs, total_rewards)

32
ql/review.py Normal file
View File

@@ -0,0 +1,32 @@
from . import *
def review(frames):
    """Replay recorded training frames in the terminal.

    Each frame is shown on a cleared screen together with its session,
    timestep, state, action and reward. When the last frame of a session is
    shown, one marker is appended to the "Successes" line based on that
    frame's reward: ``+`` for positive, ``-`` for negative, ``.`` for zero.

    Args:
        frames: Iterable of dicts with the keys ``'frame'``, ``'state'``,
            ``'action'``, ``'reward'`` and ``'session'``, as returned by
            ``train()``.

    Returns:
        None. Everything is printed to stdout.
    """
    # Materialise so the next frame can be inspected to detect session ends.
    frames = list(frames)
    clear_screen = u"{}[2J{}[;H".format(chr(27), chr(27))
    last_index = len(frames) - 1
    sucs = ""
    for i, frame in enumerate(frames):
        sess = frame['session']
        print(clear_screen)
        print(f"Session: {sess}")
        print(frame['frame'])
        print(f"Timestep: {i + 1}")
        print(f"State: {frame['state']}")
        print(f"Action: {frame['action']}")
        print(f"Reward: {frame['reward']}")

        # A session ends on the very last frame or right before a frame that
        # belongs to a different session. Marking here (instead of on the
        # first frame of the next session) avoids a bogus marker before any
        # session has finished and ensures the final session is counted.
        session_ended = i == last_index or frames[i + 1]['session'] != sess
        succeeded = False
        if session_ended:
            rew = frame['reward']
            if rew > 0:
                sucs += "+"
                succeeded = True
            elif rew < 0:
                sucs += "-"
            else:
                sucs += "."
        print(f"\nSuccesses: [{sucs}]")
        if succeeded:
            # Linger on a successful final frame.
            sleep(1)
        sleep(.1)

17
ql/settings.py Normal file
View File

@@ -0,0 +1,17 @@
# OpenAI Gym settings
# Environment id passed to gym.make(); also the base name of the
# "<gym_name>.dat" file the Q-table is saved to and loaded from.
gym_name = "Taxi-v3"
# Q-Learning training settings
# Learning rate: weight of the new estimate in the Q-value update.
alpha = 0.1
# Discount factor applied to the best Q-value of the next state.
gamma = 0.8
# Probability of taking a random action instead of the greedy one in train().
epsilon = 0.1
# Q-learning player settings
# Step threshold after which play() forces the current episode to end.
max_iterations = 1000
# Render settings
# True: show env.render(mode="ansi"); False: show str(info) from env.step().
use_ansi = True

77
ql/train.py Normal file
View File

@@ -0,0 +1,77 @@
from . import *
def _save_q_table(q_table, path):
    """Pickle ``q_table`` to ``path``, closing the file afterwards."""
    with open(path, "wb") as q_file:
        pickle.dump(q_table, q_file)


def train(training_episodes=10000, resume=True):
    """Train a Q-table on the shared environment with epsilon-greedy Q-learning.

    The table is saved to ``<gym_name>.dat`` when training finishes or is
    interrupted. Roughly every 1% of the episodes, progress is printed and
    every step of that episode is recorded as a frame.

    Args:
        training_episodes: Number of episodes to run.
        resume: When true and ``<gym_name>.dat`` exists, continue from the
            saved table instead of starting from zeros.

    Returns:
        A list of dicts with the keys ``'frame'`` (ANSI rendering or
        ``str(info)``), ``'state'`` (state after the step), ``'action'``,
        ``'reward'`` and ``'session'`` (1-based episode number).

    Raises:
        SystemExit: via ``exit()`` after saving the table when interrupted
            with Ctrl-C; the recorded frames are not returned in that case.
    """
    table_path = gym_name + ".dat"
    clear_screen = u"{}[2J{}[;H".format(chr(27), chr(27))

    if resume and os.path.exists(table_path):
        # NOTE: pickle can execute arbitrary code while loading; only resume
        # from files that were produced locally by this function.
        with open(table_path, "rb") as q_file:
            q_table = pickle.load(q_file)
    else:
        q_table = np.zeros([env.observation_space.n, env.action_space.n])

    episodes = training_episodes
    # Integer reporting interval: a float interval (episodes / 100) makes the
    # modulo test unreliable for counts that are not multiples of 100.
    report_every = max(1, episodes // 100)
    frames = []
    suc_cnt = 0
    try:
        for i in range(1, episodes + 1):
            record = i % report_every == 0
            state = env.reset()
            done = False
            while not done:
                # Epsilon-greedy: explore with probability epsilon.
                if random.uniform(0, 1) < epsilon:
                    action = env.action_space.sample()
                else:
                    action = np.argmax(q_table[state])
                next_state, reward, done, info = env.step(action)

                old_value = q_table[state, action]
                next_max = np.max(q_table[next_state])
                new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)
                q_table[state, action] = new_value

                if reward > 0:
                    suc_cnt += 1
                state = next_state

                if record:
                    ftext = env.render(mode="ansi") if use_ansi else str(info)
                    frames.append({
                        'frame': ftext,
                        'state': state,
                        'action': action,
                        'reward': reward,
                        'session': i,
                    })
            if record:
                print(clear_screen)
                print(f"Training: {100 * i / episodes:.1f}%")
                print(f"Successes so far: {suc_cnt}")
                sleep(.1)
        print(clear_screen)
        print("Training: finished.\n")
        print(f"Successes totally: {suc_cnt}")
        _save_q_table(q_table, table_path)
        print(f"Q-table saved: {gym_name}.dat")
    except KeyboardInterrupt:
        print(clear_screen)
        print("Training: stopped.\n")
        print(f"Successes totally: {suc_cnt}")
        _save_q_table(q_table, table_path)
        # Report the file that was actually written.
        print(f"Q-table saved: {gym_name}.dat")
        exit()
    return frames