ql
This commit is contained in:
11
README.md
11
README.md
@@ -1,3 +1,12 @@
|
|||||||
# openai-tests
|
# openai-tests
|
||||||
|
|
||||||
My OpenAI playground
|
My OpenAI playground
|
||||||
|
|
||||||
|
# Q-Learning
|
||||||
|
|
||||||
|
```
|
||||||
|
import ql
|
||||||
|
frames = ql.train(10000) # run 10k training sessions
|
||||||
|
ql.review(frames) # review the training process
|
||||||
|
ql.play() # see the trained algorithm in action
|
||||||
|
```
|
||||||
BIN
Taxi-v3.dat
Normal file
BIN
Taxi-v3.dat
Normal file
Binary file not shown.
14
ql/__init__.py
Normal file
14
ql/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import gym
|
||||||
|
from time import sleep
|
||||||
|
import numpy as np
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
import os
|
||||||
|
|
||||||
|
from .settings import *
|
||||||
|
|
||||||
|
env = gym.make(gym_name).env
|
||||||
|
|
||||||
|
from .train import *
|
||||||
|
from .play import *
|
||||||
|
from .review import *
|
||||||
59
ql/play.py
Normal file
59
ql/play.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from . import *
|
||||||
|
|
||||||
|
def play(player_episodes=1):
|
||||||
|
q_table = pickle.load( open( gym_name+".dat", "rb" ))
|
||||||
|
|
||||||
|
total_epochs = 0
|
||||||
|
episodes = player_episodes
|
||||||
|
|
||||||
|
print(f"Evaluation: 0%")
|
||||||
|
|
||||||
|
total_epochs, total_rewards = 0, 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
for ep in range(episodes):
|
||||||
|
state = env.reset()
|
||||||
|
epochs, reward = 0, 0
|
||||||
|
|
||||||
|
done = False
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while not done:
|
||||||
|
action = np.argmax(q_table[state])
|
||||||
|
state, reward, done, info = env.step(action)
|
||||||
|
|
||||||
|
if i > max_iterations:
|
||||||
|
done = True
|
||||||
|
|
||||||
|
ftext = ""
|
||||||
|
if use_ansi:
|
||||||
|
ftext = env.render(mode="ansi")
|
||||||
|
else:
|
||||||
|
ftext = str(info)
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
print (u"{}[2J{}[;H".format(chr(27), chr(27)))
|
||||||
|
print(f"Evaluation: {100 * ep / episodes}%")
|
||||||
|
print(f"{ftext}")
|
||||||
|
sleep(.1)
|
||||||
|
|
||||||
|
epochs += 1
|
||||||
|
|
||||||
|
total_epochs += epochs
|
||||||
|
total_rewards += reward
|
||||||
|
|
||||||
|
|
||||||
|
sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print(f"Results after {episodes} episodes:")
|
||||||
|
print(f"Average timesteps per episode: {total_epochs / episodes}")
|
||||||
|
print(f"Average rewards per episode: {total_rewards / episodes}")
|
||||||
|
|
||||||
|
exit()
|
||||||
|
|
||||||
|
print (u"{}[2J{}[;H".format(chr(27), chr(27)))
|
||||||
|
print("Evaluation: finished.\n")
|
||||||
|
|
||||||
|
print(f"Results after {episodes} episodes:")
|
||||||
|
print(f"Average timesteps per episode: {total_epochs / episodes}")
|
||||||
|
print(f"Average rewards per episode: {total_rewards / episodes}")
|
||||||
32
ql/review.py
Normal file
32
ql/review.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from . import *
|
||||||
|
|
||||||
|
def review(frames):
|
||||||
|
sucs = ""
|
||||||
|
prevSess = -1
|
||||||
|
rew = 0
|
||||||
|
cnt = 1
|
||||||
|
|
||||||
|
for i, frame in enumerate(frames):
|
||||||
|
print (u"{}[2J{}[;H".format(chr(27), chr(27)))
|
||||||
|
print(f"Session: {frame['session']}")
|
||||||
|
print(frame['frame'])
|
||||||
|
print(f"Timestep: {i + 1}")
|
||||||
|
print(f"State: {frame['state']}")
|
||||||
|
print(f"Action: {frame['action']}")
|
||||||
|
print(f"Reward: {frame['reward']}")
|
||||||
|
|
||||||
|
sess = frame['session']
|
||||||
|
if sess != prevSess:
|
||||||
|
if rew > 0:
|
||||||
|
sucs += "+"
|
||||||
|
sleep(1)
|
||||||
|
elif rew < 0:
|
||||||
|
sucs += "-"
|
||||||
|
else:
|
||||||
|
sucs += "."
|
||||||
|
prevSess = frame['session']
|
||||||
|
cnt += 1
|
||||||
|
rew = frame['reward']
|
||||||
|
|
||||||
|
print(f"\nSuccesses: [{sucs}]")
|
||||||
|
sleep(.1)
|
||||||
17
ql/settings.py
Normal file
17
ql/settings.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# OpenAI Gym settings
|
||||||
|
|
||||||
|
gym_name = "Taxi-v3"
|
||||||
|
|
||||||
|
# Q-Learning training settings
|
||||||
|
|
||||||
|
alpha = 0.1
|
||||||
|
gamma = 0.8
|
||||||
|
epsilon = 0.1
|
||||||
|
|
||||||
|
# Q-learning player settings
|
||||||
|
|
||||||
|
max_iterations = 1000
|
||||||
|
|
||||||
|
# Render settings
|
||||||
|
|
||||||
|
use_ansi = True
|
||||||
77
ql/train.py
Normal file
77
ql/train.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
from . import *
|
||||||
|
|
||||||
|
def train(training_episodes=10000, resume=True):
|
||||||
|
if resume and os.path.exists(gym_name+".dat"):
|
||||||
|
q_table = pickle.load( open( gym_name+".dat", "rb" ))
|
||||||
|
else:
|
||||||
|
q_table = np.zeros([env.observation_space.n, env.action_space.n])
|
||||||
|
|
||||||
|
episodes = training_episodes
|
||||||
|
percentage = episodes / 100
|
||||||
|
|
||||||
|
frames = []
|
||||||
|
suc_cnt = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
for i in range(1, episodes + 1):
|
||||||
|
state = env.reset()
|
||||||
|
|
||||||
|
epochs, reward, = 0, 0
|
||||||
|
done = False
|
||||||
|
|
||||||
|
while not done:
|
||||||
|
if random.uniform(0, 1) < epsilon:
|
||||||
|
action = env.action_space.sample()
|
||||||
|
else:
|
||||||
|
action = np.argmax(q_table[state])
|
||||||
|
|
||||||
|
next_state, reward, done, info = env.step(action)
|
||||||
|
|
||||||
|
old_value = q_table[state, action]
|
||||||
|
next_max = np.max(q_table[next_state])
|
||||||
|
|
||||||
|
new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)
|
||||||
|
q_table[state, action] = new_value
|
||||||
|
|
||||||
|
if reward > 0:
|
||||||
|
suc_cnt += 1
|
||||||
|
|
||||||
|
state = next_state
|
||||||
|
epochs += 1
|
||||||
|
|
||||||
|
if i % percentage == 0:
|
||||||
|
ftext = ""
|
||||||
|
if use_ansi:
|
||||||
|
ftext = env.render(mode="ansi")
|
||||||
|
else:
|
||||||
|
ftext = str(info)
|
||||||
|
|
||||||
|
frames.append({
|
||||||
|
'frame': ftext,
|
||||||
|
'state': state,
|
||||||
|
'action': action,
|
||||||
|
'reward': reward,
|
||||||
|
'session': i
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if i % percentage == 0:
|
||||||
|
print (u"{}[2J{}[;H".format(chr(27), chr(27)))
|
||||||
|
print(f"Training: {i/percentage}%")
|
||||||
|
print(f"Successes so far: {suc_cnt}")
|
||||||
|
sleep(.1)
|
||||||
|
print (u"{}[2J{}[;H".format(chr(27), chr(27)))
|
||||||
|
print("Training: finished.\n")
|
||||||
|
print(f"Successes totally: {suc_cnt}")
|
||||||
|
pickle.dump(q_table , open( gym_name+".dat", "wb" ) )
|
||||||
|
print(f"Q-table saved: {gym_name}.dat")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print (u"{}[2J{}[;H".format(chr(27), chr(27)))
|
||||||
|
print("Training: stopped.\n")
|
||||||
|
print(f"Successes totally: {suc_cnt}")
|
||||||
|
pickle.dump(q_table , open( gym_name+".dat", "wb" ) )
|
||||||
|
print(f"Q-table saved: {gym_name}_stopped.dat")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
return frames
|
||||||
Reference in New Issue
Block a user