From 8e38700009ad402658419e011277a68caddf4b11 Mon Sep 17 00:00:00 2001 From: KaiSD Date: Tue, 5 Nov 2019 01:38:46 +0300 Subject: [PATCH] ql --- README.md | 11 ++++++- Taxi-v3.dat | Bin 0 -> 24161 bytes ql/__init__.py | 14 +++++++++ ql/play.py | 59 +++++++++++++++++++++++++++++++++++++ ql/review.py | 32 ++++++++++++++++++++ ql/settings.py | 17 +++++++++++ ql/train.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 Taxi-v3.dat create mode 100644 ql/__init__.py create mode 100644 ql/play.py create mode 100644 ql/review.py create mode 100644 ql/settings.py create mode 100644 ql/train.py diff --git a/README.md b/README.md index e37bb2a..aee828c 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,12 @@ # openai-tests -My OpenAI playground \ No newline at end of file +My OpenAI playground + +# Q-Learning + +``` +import ql +frames = ql.train(10000) # run 10k training sessions +ql.review(frames) # review the training process +ql.play() # see the trained algorithm in action +``` \ No newline at end of file diff --git a/Taxi-v3.dat b/Taxi-v3.dat new file mode 100644 index 0000000000000000000000000000000000000000..141eb2ed5f68fcbf2ec32da678a3b6b8c6b870f7 GIT binary patch literal 24161 zcmbTec|29$_djkPF88|EJaf&8GLbC6$OMDYJxBhRRSOLzF}sNamuVoJA@m z%9J6MA@fY;`0d;4EuTMr|J?8W6iobFdqWnH#Ip zPDYbS#+__e-Pr#3w|1UBuJ$-Lb_*s328I(G+&FeJ>|_jf<1}-_DeYvoU_^JfxN+^? zz1w&R{r>)ECxfROw{55%97X>~=kL_s8ty+$Y+xGdwuk=#6X=Z!aM-u7!Hd>L6|ER1 z2)}=|QgrVew*UBIdwRG6%q;8;=vOYM6nmDMu(;1qk_~yMpP#Ix^tM{}KFFV;d`*47 zQ#+KCS{SL}829ydlB)_LO65%SkXNg7q??`D?7V-sMKms~=KHSmjOzS*^wO@G~ily1}i%-9~e1byGay6rFC|&O*m0a--QcIV>RdW#6d)zP>& zX++rVrjX*HF9m!o4-KEXkb#h9y-IG88&XL=54h#V>; z9ufc%|MxfZf&^egt$*XL6e67N(RoO!CBh~4xRPDyd0*J9({j>@;PphAQhEaID{one zb)Y2O-+NN*r{mI+JXK@6|Kgt|Rq7C#wP6HZpU{;UTR_*{xds|bOG4B?H0fE=q%NPo z;0>F5_FbOC6tpPC(pqC!#$Z&8Y4IH9_jCsX=Q?@N%+UMpcZUi3lq@4kfAdpQyLcY8 zPH|A>9=msY8B9^k;x=)^PEIPe^`*`+J|T$r;?}u&c7`%7?j)WoSVIwwJS@`E)raLW z-}vm6lS;7-!3AIO>Bmgi3@R?U%EGR5OFzz^pTp=q&wgKitWN8}fz~+ZLVN=bX8YRj zMmXVMMzK@U?J5&w+enqkQ>c2X2Eri5ISyK3@Gzj|fba+nbu|Ie`5h&eZ8C2L#wjynA|+4KkG&8)@{X+|VCu(WUOP75Tzd$Nk}5T)T|ly=60 za56*%jA&7jd^JPoeQQGkqz-LyUw?-HliAtl8#DW$Bqe@EN*qqkhUaSR7Nz@pIrZmHZA0&Wq22lClRtQn z-x!_ro^)8Q^PHZz%`elJ#v%? zq|Rp#3|)xuLrjbzfJ+=GT=gcd8>NAr&-Pp-(hp9Xgo`_6xdHckLUq?ZZt&Y$P<*6; z3mTsuFj;RW2C83opBIx5hOV*~6iqaLJAVemI1)jmb@Rti2O=1IN(4pw6QPvP_TE_; z8Q7k7%=iK${rY+DG)E3!z!F6vvsIJ%IwSRVui0AbOGA{lBW_uXNPguV*<}x$?8aPP z?%F)fRgR6++_yM{|3E2tU%@#x&IIbuQU^JdnBnU0$Iet7GkDrjc?Wdxz~nl$B?u+CdfqUUK0YbWfH z;evBr3G5-x($MJ>!f{?!nC|aA|KOS5H6L!!OI#;88O8;2r)ETH^GgQ^thI3Df+7`% z;=nTk@bSo(+KGp_YVzw}>+q1Eqg~ygAp`t;R6bQrQHXV+ zelu23R?e* zb*1KpaYS&-9I`y}0PSmj6gC}9gzF(T=htl`gRAV>`~$ZIVO65Jq&|}O)*okN8y?_+ zJ?x%IN)#TJ+VN8@4R{dAtF150k^nXL?Y`f0rQmZOW>5QG6%!Rg?io(laZ%McRF?yC zF8N6w2TriO_&DMP@?-QkzuwFx3IlK5&L=c|NKJz0w}S}~RgqvQDnoz=srqf|sRXzq z=C%;GN)WP?W6trj3evs5#5Gs8PJEfBtaAEx^z^+MtbeDHi2B+&%$9qvuf)q4?3Sx` zk(4+)Jc_zl2uNQH|NLENRbGef`#dCn<4Y~Z81f`p{TQCQe9iFpH;NI=M)}D~z!Z1PEQXFHDbwVWV9U%HO^Npznu>XExqSsJ9R9kr-eWI@-_oR6k&jZ+R?O727Y_Jin~ zSBLSi$6C^#8|huM0OQaegdf%!E3VC8ADHg=o{&zln??b?W-XYT4N za-Rsl)rpq3N@QT`!gRX`p9Gxv;+E<I5 zx#+5Y@6bCuJUw&vw$n8NB<*rwOvwwc3gU%eNh(F+JmEL{=nPL?Ie*4 zBKctB|`Rt8eqWMScl#5Sr6JR%6ZBg_n9?sc3?sX9%Kpwxgtd_I@OcI#O z?i8V?`68jgqAeX8Q0Cdmbwn1%RFzT;w#kCKbCm6M z;op?jS`n{++rKDpJ%rNDQzq6b=9YT(To2axA;-0}qz(&;C>x02uc6$@ zlSo%DTf%OJeqj!ZVT2eCNt@0vMkwefU|U!TfYU27usJLUACrGQ)2`>IW~6N2%+)c1 zv1~9C$U6ERvpRXPYUtqz=Bm{8@mAClR{hvdWQf2?_x@H+X!bCp_^&DPz1)5j4;F07 z;2A{mp?+V;$P5~<+$vDx*e(m-pElbbbCQ7^)rqUp1$Zb6`L5wjC4fl!46i>Y0lvJQ ztu#S;)oFI^(qx1Ha0%CR8^_DQNa(cddUQU;>{F80>64(>x+3m~74l!*mMS))`^N;g z=)Ji>29^rOT!#e6#H4bhhq}elrZh9-7)L?v8&u}8ML?|ciXAd{B>jqPp|eCGDUX9XMV<}l>kc* z&Azer5#Yr)EL9J!&(Olgz^yT2kSu*i#zIXH%mrVjU&7FOSSEe0tR;bS?#?@Z-jaY+ zRHEXK?td=ERS?tBIU~ESqpC^K{k<(YWSkXR=Undq-!f#+O9JOLWl_A4QB;>lyMHOf zZFQ!r2&CriRGxUM0OJw~n+?%>X6fK7Gv@JNwVpX7r3??2{Gv&uShT<2ZKNC(+^P^I>U1&oi-kzc}|%5+?JuH?eBa|GjV3+GGvI2C!>c z$`9L*&0$Aa4qkiZwSu!SCa&lFj&PRl{F9Gef|;O!#VBMJ#m7fKEI8BT*1e~BBP~uo zGPUl+47$%+r4qaF{}V5l4vO;4o}R;Y_b}TJ9!7CsP)fkIL;YCohz7|I z=-mDf55gJ#iKrEsp$JDc41D-FLvgQGSz@ner*>bnjCOr00Us1UJUw`|oWeloueP2k z(04oTpB)Mhm33p-hwR+sCc~>;Y;c57xyvG55oGRDs-l$?p=?%NuHfOyxt-bmA@0S2 zer%)M%?X|_Rg_sno2B~**C@Ci@8#xAhL6@MJF6O2-zxLWY=#uk-KM!25-1O}H zE*@C*W0Mna83AVBUUyq%DF?qQw12U@q~Cx3c%WH`79YIvd)`iA#Y6GdjYE(Oz? z24oR!K}gIfbHE=Y!uQ$C%(=@%@HP`Zxe@vIHN$daYpsdUZER|3%q#~{OWOoJHR;E% zat~{)($tjtA3mi@U0ZDGsEn>B_=fXnxE9}m!RC&oB}#Yty>ke_`iC#UnUGf>o#wEV z6up*BnV%@#QoU9rp?<9ERc=1xd09B3c(3YEs4OUJ(tc9Zq799GhxT()Wr8<9VqM^% zen?19yoB*nH;75DNti;o;4|NFu`xU>>-XYgGnQ-h`mz2eQe8PNGZcIS-{7$~Rh094 zLN<(Bk-^W)DuK!7{@j}gC)03!GBUCIEu~AjZ|liN{r~*RZ>1ghlX9|<7IxHGRZkXhpI!&L z?7+cYvmdb{!oAjb?!rJycA&Tm#I9OzpeJpd+}2p5v1J? zWBaO!u>EzBt*8PCSgrcL_6N&>)s8Rs>t4vxZI71c4{4VYyN~SA#ds-R`AH(!e+%M6jz(^_NFJ6y#$e7%=mII8ENesj2d+)aBikg#Ipy9)CnC~44=u7HwxZWoX zb@AJHX!>D+Z{&nUB@b*d`h2}Tod?!@Oj`(gPk>X@RS$|@Nkh*Tr{Zf*<>>xi{`ddA z|BT@w^21vM$l!3;s^fs{C~t=S!%>6}TjqzCAe`u?wsq9TO(>o<`R{!`)Q#U}I8210 zO9YweY##WPk(qqj0$uy*oML2CfXcmz@5fUV;GuuZU;U64aWFT5$OTzOHBTxo;((IR z!+nm(4YLaqyIDfTkzf9!H+@?C-~8<#qc|AMiExO&aIo$t4+wDyckMMLLfDN*gCprO zpyZ)4&{jx)d}N;6T*pS}r5K(rIe9~wm3lsYr)AeH7d2!f4RMyI%HmG$Cf*&zb_q$o z?dTuFqMY>7Y5voAdw*RewuEgF8YM7Xp2gIdzpL5pTg1rrqt2}6LQwuwqJRB8et6@{ z;a7t2=duh_*Di%t%sC#JDoiPkR^c}6g~OxCWVP!jlDW8K^|L!f(qS96{Et6#s5MOsdM zB|_^nLMvZ65&XMzoxge$q4&YBXKi2Q;X=6DeQE_6^6q~A8}D=#98F@$zyX(Je_;SW z7i7QI(lBP=gvs9Nx9K=(U|UChq-84kZ@=>Un|D=hCxW7mg!0lJB1Ft-7Uo9rfZ`is z;ud##$Xw5RHU5?yoLz7AtYwe@_hdZZKeWREmHglN%T`m}lQYi^n%6_fW+h5UPOgew zt5Kl)do}J@@wD^sfQqdRWu4{9{3eH3%NdYdTHmm4yPl;8zt`hAx*3$f=Nsvfaxfb# zeod6yG0O~N5zoeGxOrZ=()%PYR(SQC|MFumA=m|DSR|SEpC7~Fx50XN6bD8x1GCq4 zT>VPLc`Dh*8h%2LP{Si%8$9W#vT(V!EyWX`NwaG z@B-IlyZ8?2!;8F|2?ywxU!{PZfAvF!eTMh$V?;PhaZmO(Ai^TMG{0aM5BSb+${I&{ zBv*1*S(^Zo8y11T{=ojsgq~CD3BclP7<)sP8xB?l4~wwifp`DTV-m>zP98hZCYg%x zjm~X7smEBMSdjVA`C<;Z=3hpU6l|*pRZgqR0QZ*v_A@0! zHRU++n@vSRtvT)SFmr{&*u(}8LLK$rzM^xs?sQLG>Jt6(8-16Hl_oEnE!j;{HjQIj zPF}>js0?A@GxOiDt;5*bBK5N~?YywXM~^3BJqwU8b5C#`?WBxJ$XEwCv|;=3hQ__> zQq-C78)RsQkJ50P39)7 zQ2fCc0XrrL_|E0HkAsWu^Q>R{pS=@2*sRh?Cc(Vnv{~Xt5;#(=N>4o}0`vJ7UYcEU zQ1Y5^amGKe}J`=Z3&T5sr$$0l0f%-`?GnhMYQC#=aMlV63%= zD#J_!%S1TVmrDd*_4C{J=b-U_dV^cLic+{u9vVuI+|InG02!SuPEUSw!vXx0N)=J| zf6puV!$H|R%MB-2j1c4N_2nkVCcKru_^Sv|MM?C)UW2SmSP8mhr|ud zL)=i7)%&ThNfEeW*lM(q9IItqYE`H}arpo8)VtT%ZVu)j#J1R%Bz-7lg4}CR=7yJp z46=}tpQ9WUe7Ujn8lDrb?`^nxzyuEgf?HfX_T&FOPmEnjkKvFs*qk3dV}|g7Wq*J- zW_nYS56VmHG99f9M)AV-v$|rPL>PB|Zu>h{0lFgmZXaQwUtcY+^TWL&u>u_^pL~1{Ay-~8CrF`>8H@lcng_IfJ;4@x=`ZNK>Nuv#rWHEx{@EIk_h`C^QIe{$Jh^}D_y z*Rr+`v#q|d^M-6c=6Hn4d6j>LLh$4&mh&Wo$-2v5{SjVaZlE&$+i!_tw{|{;J&Q=4 zr=^9FoxWgSltkLZL|tQau;gXj81}+{cjKh}*gw3Z|Cckv)b@UCL)Lsou6RG@{Z`|) z(}@`han3{1-H8l(-1k1-^&rFh$kl)SZR?YGRxbUon14{CWe@2F<$a09;VywOir-na zhxy_xP}sngu5xODZv2_;+x2P(6#tQzZcGsZNx?8@x<@ddG1y$}_em)xJiWLMQ2{jAM)<-SXQNDO&R0U}gTg^x<0PICi*2 zh*|4pfrB`ZR8uOHgJB~k{WVA)83Q(3T~l5B7ay3`e^8|uPh+!{Tb+IQ7$8uDI#sn! z8nTjy-I^lBpzcIL;lNLX#}BmehKdp3eyg-~PCOo>Os9Nl>*M`7-p#323OpkN$E+8{ zVM?p3Mtm;`0%Ge9Dj>b>cSbqZ8Rezb)C_OaaA^_QVYY$Ka*$F)%;kF`LihLPzuwfc zFOLY0f3jMi6G*WCxnrhVB@q;^*%Z*$|1PzoIVOR=PT%TA{My!#QbKy8VdJd z?z1(o@R0aJfAMT5H`LcI@T&ixczM@o((gY=eh)+tU#p?}=Qp=uizrTZcoLe7WSdg? zqmygs(*I3=t46uyTuX%oN1J_0d;7q*KqtUj61z&NRH)WTE zkP~bkGCh*u-J6oU?c<8Pgo)>CWH#gAyA}B*u7eZmUx>CmHzH zNPx5bet=*``e&b%{2Y9J*il||H0b)RGh|qQ-eUjfVfyv+l(rP(vj7svBf(X-m=ivuW`fj z+={(Z`?|+BmL1{K4;NAg+3?{0*x`v0!VQMGFXk7UD8WET`3{^h{r<*sek@z`Rp%0# z?_rA;Whc^K>`I?|FLbj3|D3?Nnj}TI@nxN5b&w*&sSn7o>?1-lzZDO&1o2<|!^$>1 zHgTQ@V3XR{K{-j#h%+bX=SaZ~hNOJkK;(zP z`fEOuO%iaIW3=A{`SI_w*ST3}AU??d;%)PLEq8tl@j$sbcT#f(%EJww)a;I+-+xwe znMelYuP)nBSlj4B<4fJ@iSXAz@)*$MEvczFf#Djx?$X9d-3i|wl+WK=6?E+)8kg;N z&>5NRUd|ed%Cm#L>Oo|%QFOPyg~sK1-&{$2@iz|sa?zjmysgp6g*183+H5aRoA-{v z%dNDrgxtGz2Tu)U>3){&fI!VZexiTJ__TfV45ijrA!4pv0(`GcCVoXD1D(I#LZ`P* z2>)A-Gwqrje~ux3AXVP}`cFq39E{(6>eVeJxb5TBa{iMNsF*QbwkTWiCvP>APpQ9P z0*&>;kKFK-zvEW)!x&J&JPe_CWD;$f`jzW&;Bs=DcS-@0`YdPhm}mqutQguzdDZo z^XxZjBfoc!2y#UpH+S)@%pbmShA)bR1QO389@s1>L4OT??l!j~d=$IbqF2cc@<#)| zhyPePpMo>b_Ah+phQaiLV~;K>Ku2OS$Hg5Au5_~*gb}^oH#f?G_``M0E*p4YJ@eJvi7h-3{ax>_0Ci$yx75@6O)EARu!mr9%4dXB4{R1L@VKZinA=I`h+IKznM<+wOE%!S4!mh3y;B{P(nbLo8>T4pHA{7b8Vj-3+d z+Vr-%w&Z3Z_#@8vK^Kj5iZ4UBg@|P;!l{-{Ut}|LMf|;_@~#=gYg@)4UryfE|E{bA z=Ul(cd&?-p{L==Tgb6m-w14Bn5mk2hX&RvEG_#T)^1o6+$@?u2^)};!rab89whFp zf4g6e9b!$xb?&TT2e#;@@jndg|Kh}P2Bpi|G4k+jPRlafNdbEB@`ejni7>~V?3Z{E z#k((Z^OVrIy#DTnUIyZ-%5bcvg!%FRL%iSXo7Q?(j|XDPo_c4y@<8vp1+DETSLB9y zPh?I&o+4a-nS0fZNePDQ)a@Ve;9mN7Odta53#9p(hDExkB@E(#WdMg&QNX<{-in8%}`#d?zpw@ z|Hw^OkJW@3k`t!h?`dpkygMx5=7RpNye5&Oub-c~EMHlYAKXk4E{%(l@zh3tr)A3f zej6Gq$k|nW@6i2o3Gt_!W+>IW9+0#C&-^>e1x_DAf9HCKXKpVV_m|cXE+KuQCSd9O z{?dw^uhsj&9*^YSa@c%U2Fdxum*00-ppjnr4F#J|JQu#o$VTMyFkp#ML; ze@~v%hwRG*r#v^4YwGd9cP~3EZi@`8`LS)-%z=J9wwmSIwr^VqQ0lv;$-#je!j$>s zX!*;M%5pJj#520};((1@n-ru6#r0KWioxc*5A7)^f4Y1wrBvf|K1$AGx}1a!Hi%~; zJwNs(zfB(ACpG<96D0?GHte3=Z@~qdpR!gAb!Ae zrh8^T5pS)(rsko9gCcxkkS^%MrfXKO^^VeibP(l88PIBCU@6W;u%kQLei@fOrQx%RN_(FP>L^h5XlL9)V01{o?qMj=mYYDSrb1E*br(cf|PUPn&>&p^BX;? z?CMAJ?$>Ed>_FqPzb&~_X@4>Dv)jww-hPcnrqJ>oG`}9FF5X4+vv2%yy`zOjdgXTu zr|U_X9f-gGv(oA&!b^nrR%lf%5y41j?r$FUv#PRe(hC`gh>U11-6#S>Ox91Y8*{_q z-=uBNT)83Lh2>{<0T+y2AH6??&iTIG3#^@80>FQC{heHKF+0-o-~Pa>(1;bNG5^&``KT1P?X6%0&8HL193Cw2(Lt2+2u>857?wu&#*pDo zpZC-Da6B}{7Voz3Kx5cLh2|cV2i5v7pS~&f%+OhFGBj?z_(R=H7Q|<>93$;Wpmme? z-6_OZi4>fUW6KEF2JHt&!^3h!!Pjt}J?1heP|xlk-t?3UWT=VtiG4T_ z+})v0r3S=7O7|tP$x+oFG$1{#E-! z7#^8TTFB?n&u`Z?cez^ld>#|Ep6y(?IgcIPB~oNkFoiL+Tt6Fz>}jcV%kNgiXLap= z>ERJXgoj&sxoC0hvYegW;_!}!XW4%(A=BimtgLAbO}^fADjuPYC;YV(gyv_kQwx{J zXFe=ovf1_cX1f9G@lD^Sv^=u!9u|LX5nhNoy5Xj^l>l_EElwJg;iOvNO>MQ-aZ>HH zjH@&2`2XpN^Ng#G)gitm??j~Wp=mtb_LpO!q5hXK0SJF^{4_f>_)O$yt^ffpcX&K- zClbKHUB$`gy%YrZ9}aHlkcPq!qfPzgEMVl0TcFkbJ&C#bu;3;$T)+3IL-G--SCVR5 zBhDfTs&3h@i{lXQZZa*yzJVJynNIx8^BbRdNM1d`1@}$L+GG~FpxSifhWsniP#{P8 z5dU2gEcgACPO&Dz&h*t0H2igMArxhTE?N@f%A{oZBQZTPQ3Q%=JS8H*a0K->~ z{LQ0owMY(%xI%yftG29i{Yrqx3H2{O%?QB2`Jkc1N92Eeo0fx|FOFOFV`Jn#n*(<% zDb|nPkQ4o?C|O_MPQ)TSpO?|pv;7eHpFQHRcU5rNhV0!FytczA!j+EJ_)a#nfwpo$ z(5(PPNZ^0-?D=E*?a{K|uvtYVRIh^#Myd}R%c1(s#@!cvE$=gd+<>qLqlp3puFJV* zZmR%>8L1Ca0*G)7bNs7c@8OTTmmuCov+fyID}*~a-u?L-Hz)xyaW;vI`qFUt*%8bG z@w)3PTmIt5?`J-8l&5n;_qn95xX0XJJN?BH+NGhJB6-*I85u;~!%mWliJ-qB`0xCx zvL|o5qPmK2>eDsOQAC(+k;}(A_>vf`(9JihoxzlNW>KKr{- z1{&{k+`{mdm)K@dMW&&**;4HuXo85&Hmk&!a_h7s{(y z3+QlPaF&Jt>JC+@Y;sd9G#-}8Rh=WWd{fIe)A2<#c6M#|LkTF#&h;ZqG~7B+CZubh z!U|t;?5R<+w!kNfrbpke7b71j7;Ez-c9ds2`X`^0&6<8Z)_IC=Z=yI6gf0H#-K;yg zX7<-K4~)f^R+R?xfLji3>x0$O5R;pG3n-iQO=T*2M{o1YM66AjvWlDygNA*E#I5YKpJEdSW@{#iEYn*T<;^to- z)>Vh~PE|IXaGTQlpiPJqBBt~>QzNATC-Ry+kU)mOxU;`cyRZ1OTh5Ge&3X`_puFqc zetyI|Q~ex$GF}nv7kgse)BkiovY=N zhFt+p+FTpBAT#_w|Ki<1@9)_%PH?VP;Tf^QL6t#9sB{4tG>r2TCq9!QH=S`(?gYZA z>?TC!eiFb(-01If!p7X!yf)*3hyAOM9L`2OlA*)oH&gW6-}+9~n~%-gHe>lZUfC(jQZrqH?!G!tY1j8=RntQY*Xsn zgJ)^D@T)I>46jdPk6G8=K8A2LN2{O%Uw)uEpG*Uu9~G0>+(4}PT{OQF+paGib`x0X zr=D%*yggXf_I6)>R44dcFk{|!locinCX|Qte`8x8hxlJ11r{DCTs-Zqd71#0#%4Kq3o=B!XNf;*F9Wgt(WroUk0br~q|4onw}Z~C%-{RlAV^A? z2y6d*v^beego<$GuLcwuz`NYSSv@4f=4;G8+VNZ<8~Nj}|L`p7@TKHp9DFV^yEr({ z1@B9BetP^6hu3@C?g`mQ!|Z{o95%$eQcVB$cYW5B2FWTW;UFQ^-h-W&17iH+aXybl zp#OX|PgIRKWM6+S=7;oEc~tMOeTYA3J?2vO4%s7*w7@o+e|sZ}gmpTQVM{3aU?B3( z8R-0Fem6)9kVbj;@8>NPR`EdNn$TBvUOe#Jl`Wmd+vM`}dVWh-20V-YSf;%d1FeBw zBi3;Q(9QghC*`#(>y$RCBkbC^DslB^0`%s7?TOZsg#L%#d)BU!gWs}u>N;O?f~0uu z-|u~zX3MH!#Q`mMcKf6aal&PBN^n}N0DR5s>5`%GzI}>?=DAn!v$tw~yE91y=&aD{ zNmby1j~p@vx;y2dtS`rT$4mO{uVCHxBWikMSb3N}=pcM(pM1Aj_pTvqUSV|NS5PbV zYV{?jZ4xXHZV~gwSCn)PUDQ!;KoYVo#_w|Lij;tZF=n` z`;!AIk1nJYIFDm7pc-)v$$K(otcW9b0xNj^j=kJz2rJyd&9pjr5j*$$Kv=07{qp-# z|FO0bsvBJ90V%KPX?<-@f)w3f`8N=sPaRtlzmuo{t2rCKRUedriI1y_wxN2~WnA*t zboKj}&k--@j8mXPBo~ml>kXf*CqtFtAJe_-q@eoj=YzqBFChCrz7_9PHeG$hXYN{5 zy_*$?>^}3zgY|1=K-Y$!d&53yxL2zrxYmgPamgYi+WFFu`8hhW8r7qY=CH1=M0x&= zmtuUHg`gNaWcsdx{`!Thc=x^g19(8`>}v^+K0F+du=%T3zE4m)!~+Rne&W%m)4#}I zAo_zv8}aSa1~)(H-Np$U-2UT@vJFI2&wgixJB;BmPh{A^q~=TAiPMr$!~d;s$8|E? zPTY6B@-ZF+4s`yN_fDA_`KQr%_+EBG_7&of?LO9iVl&F8*UDXdyPYTomREV*JnpQl ze^063ouwNHPkdBZF8Bd5Qz!Qq6o+Yfr#;F>zeq8;l0((>~Cwh|u&kQqz08TD15TVK82o-C!0>VTH> zxQ~M(J5BEnW3P;SnB>+}QwryO#P6IVgTJ7JwiX(TFHYsXwpz*SY`XCv^71Ess#4E4 zmK&%(VN2L+jdB$sAf_y>dZ#D^Dv=h!-qe-(3%!Gi{=^Pq40TWIu5?sUmKFksrVzjD ztJ;s&eU{{Z{`Qq8W{tG`Go`ArRSoG01-*V{n!nx88Xhr&{A}$G>V4#!Q|MIB3@xRv z#_xKNAo{^Z*KJxPkh@`I6^!%@uhoga@u7F?TVbQ~^00{e%&ymxFZw`j_c00!p3+Qeb2f#d{vrMU~-uqpC$na7$O|%1-2`204!= ze4j=B@te7a=@_av8T#JMXCVV#BGT*A#bjXyUn+B-0}nYb{^R=u6@K|cP2>ilom@I| zSMd;0{rgZ}tpt2+y4$!G^}C?=dxd1!tN3BXpBx;YP78I#L;q9VW7{GSU&^p1|LInF zNE2t$*@OB8$(Ts79_nF(y>44S(s<7@_IorO@>pRsMmqY?5q9X45ofnR{>Ev?n_ROO zC19U%6{X=m0SbKZfAIvinA66Gk^JUwZiOsV_oH=dX=F!|41_WVY`_L3LFs2yhWH~A zaG%S3xeo1veRZf--7ONlI;`<`{heJ){mp!2pt|U5X8HyCm` zeHjd|UMJ0rAUl6oq?6Vs=TecSm{5`=_~q^}6GU|buUz_Lz20*`@2P%9@g+`B^u1ZB zFvbB-rduhrI&5x-d*6BnL_mBj`?2vxDUf5o;2Zmy2|l`i4{GFKhK!}*;C(xp;7sq{ z)@=wM`N_@ng>ac4tfISA4%+hrUd`jnm=Mwj{vP2s00$e^xt{$h$q8>${;R|O@@U}2 z;T^K@gE~K*dXE13Bhw#VW#85>VO4AO&b2L6P{ckvG~!+@VJ!jpqzfran3nRr=?4?s zuy{#KXjKNP&$QTg`_RV`%G`AXQz7}{us*nREzdCJxXZy$i<^ch4OBK}&H`32t_aJY zI6i|VapoPUFJlCvb!3FP?k5Vq{g02`B}O>2sMpJ&F^fHtoMWwgOag5~^_kWp#ACFR zvl2TtPmwx2ah-;Tg>%%{mvZz`sts~%Yj1H=zb{-?eL`i0im-ju4~dI(&Z;GxZ z=sa^KS_YmYvO@6HA3QrOS)q)#@^bqbJbb!#e*gtc5OJovw@w8Qo+a*Ajr5 z8Dl#k2^ko?#l`Xl)p=g(_mMc@^ohc>FDzN=_6$Yite3HK$Vz{(!3VoUJ98x9(0AW@ zm0TwHgyWaHhU`&YbIN2Ks!P^X)sE6b_|meyTHuN|dcIzPZhaLSyak7vd0=}ST{rNe zd=C3;O1S#9Ijr2@^!+1bFPC}UTYR-A&TfX$^ws?U93gyh?6e|(pIAra!E z=Hp|Ky@sC`{v6*X1DI56zw23PxaMV)oB0C|wJf(B?9lnCzpNd*YYq?bd%U*le8a>3 zy=e?@ctk+n#POioDiLU8Zkv*rME$i!Jo9M!T+(jf+bX0_@-DR2&W{m+tw?i0&_Nh% z@RxNxPtzYSL<>msUtGiaFOPdq|LZqt3MchKQAqf0TTW`-@y$K!#Dzd6@?f6+<{66o zo&v>}XnmH~;m-S~CBuGmSYo+%J?_FBw(&DVxHRg2rSwJ2F%a!bt&ZOXTAgsFUgC#2 zWIwyAw`b6<|o zpcV5F8LUSfW+XB>fSOUX!xQOaFJ|qyHD;WUzdy0i5$RFg>b_b;sD*mB^LuWO3xV*J z?4O({-dWRuAJ~KTm+gJ~7M&eD;2JJo)QsZd8vmJ3#m!O>#ko^mi5Jz;KEGBA$e%xP z_FzyZ>L=5ldPcYejiYIr=jOMQU{<%IyY7@2^!SuqS2Lu4f0cax9MITE0*<4cTA6D| z@bg1FmWIZ4&u#_mM!fH@s}ed(QQy5gF2WJ-_R2u`=wb*F)!i)KJr_wq?@vhC(M=n> zIBRpf?NGd;$r0dfEde$`H9EDe(!j)`7y0!h7dY2se?N=X=U163YpDnqn$OdG=T98e zrldLd?H7g)Q{3_FJOX@Dw=RqfC4y9Z#9usiSDV{GD`Y2kk|bYD-9Y_H&Rv;Tlpupj zfUWXIIr{Zg{|l7_aW@_~i};YVc;s3tA26Ukd7F9{*#B7aKydG4w=FudP~u}8YJ>VT zyk?p0RuAL>;|&rUX?3~QJ*yASgz&&_D#3GPoCjLF)qXCdqq=VQgaSJwX(%mz(fXhm z@uk0wH%B4;Ts|3?6r9WjhozeZW8ZMWY{(Gmo+AuR>i23@Wo2NmOnP|;FX|T*K%Cu$ z=Fj+H+jX)Hy@#-s0v{_8CL>4nYvaYBl3(nX?Hc;|&&z!-G?KHdTu!{h=Kb%RKYuoY zE#LpoUmL}&G?kuBn=N7X1&2hA_|IbM>)!p*c0%`OHKI+&XR#*p=E;_@xqo;RLHdh^ zw*ebWP>Zj5y_m!#3oPcnkGEpwm$sS=X#K|G<#HYTdM7bMqeB|=pV=VIcVEeQDdg8Z zD)`v*c@i7A>Kv5IF@jwJ^K0i4CNX=(m$jn(lmF&l#=YqEC%;XfAvhy>J-E%an})+C z5F>|FuMxqlz&!SqAHuQB_Z;$6lmyQC@=Z^2#2`ei;vnlYP6+)R!aR@ics8a9A31+; zLMgF-*uaAmJ{MY+Oo{V?SBe13eY^zJtS&BDgZe0HeYXgr*)PUb+*k0=kzL28yxH&Y zz_XHh@eO^#U{w4}xur+~3=71rZcRt|;!oLs^Ur%#M5FD(5NVhMI0Wu%_m>JBn&%-pRxw`6G6q1 z{XI302qV#cC&Cauu*Gjkf@Yuk!(!Sad86kP!!8y^*e;oN6 zahvPT+~2Zyf)!j+@;Mf=^#!NLBefErp~G`ory*LFAJpmEz=74;HY zhc{BD!NsCbdN}gdK2cF<9a{I{Gs3rX3X4=}eK)hV-O*1?-~oq?^0&j@@Ia;Ym^Ra0 zgwOL{+W2&h2;Kh1vYf&x9qA3?&6uS@w!eek2-fj5ez#-&Ft*Aq=wXG|a-XSDZ18Nt zuDkwo*xdsO@4PMo%v ztZx17BQFk2VNO@oaXAaWF@q*Ymn}+D|KxGGkDXrruh*yYC$MjY2S>DYCa~q}vi$6$ zeNvMc$9PU*d&Uwr_F375jWv&6#iV5lA`^$?wBd%v_L2L%F1%}|Lkzws~^K*?#T5{CBl z6J=wg0x)&|!^JR8)X&NGKR!t0(U*gI1|*oO-+;YAaoLHEzO_~ecYcUlRK@+0fe&Zj zh`&UAT5mII?meN&1NNa&wFl6c-%Sxnu;2l5%&9xEw?*OD!MukuL}W|o{O!BRmRVQM z16ysp?yX1Xx3gTRYUThBBn#ixQV!*TgTGe!9A)Q2as4^>$1SLD;@WH*_d*WXd_6jK z8l6v1g}r(w?Kt2hxhZ{V9@Tf^*4Ia#6^3z!z^B2X0+1cFBoKI<0B&8!*&m?w`DJFs z#D0(f?Hi-lU0zK9>r>&EtC%HWXP0FKyD}eaJaYL4&jS)<#p)(Ip?tTezM$BqCnN|I zTZ+$2BSH6_*^=8R&y;e&+(Z2Tpt4oS3OoE@AXE0w^wobF?)viZeG&EeOqAZ-2sptIOhuy3EHB@)AA5(3($mVBMOkTNig*Ds&F( zCun)NK4gNu>t2j#py!;JbOwKD&UA65Bn$e*Wkdl~o(70o-?R{T}hEf2n#49#J$ zH}JG)^Yvqb`a6nmgvi2<)8VmTHHTGl+zNOQwi4G6K6L2q6=bJwS{vxHS9*r>*fUJm z?xh4=IC3;mq^Oc&wsuV|U&u-yG<~yt@%gWF*yxUy5RYU17!%jEAA!eFU$a%LXY){; z@oa$TK)EVk)pFS{%}Pm-OwpGbx$K96sC(CW(U4e?vi{5WHM34IhtRek6^jB)cb)`vTtxmbuO6eOWF-fL7RhPFqx^t2I*bPn3BaegvX5dm!Z3K` zsv`Fb0yH`O(Wb3WOwv2qXPpE%!uP0N`622*d;RXkhDzi=#OBGi{S*VosU1y%og~2W z2O_Vb^)|3r_%t#^f+WKY4Q8KEyl}R3e-kDN{6m5g!>aV#U$Zp>JFcPiPf89GU5ong zcC^jSIiolspr|3u9j$-p*Vf{}9tkMd?`YpYED3$K(Zz2~*x}dyt4H8HJ7|b%#;J=V z9GbE1f;hrE-VvPl)tK_b>dAyMgAFoZplGOzJHLeGmBv&XyDeg}=O1MbOEALS(o<`a zP~T&l1F8*IT=?Ku)5W#qEFloTC2ROLk{#-Ua;|Is;edx<7W%C0IN{KXlpz{ESXe+9 ziJn6J5TdkB8ye8RzYVSZmjb34z`(tuSxlJ`+DmM8az_|I^0euk+hQM9SD0gZ3*m!N zSk|UR908X5K!zV3>eZ-3xO2{h@Lkp~Deg>M(@Wn5uqMZX+cy`O{-;MqF?MNvD%qb;oypA8LyY@X;tXQOs=Lo4E soB2)Y70ixI(#BKyl2}A({cQh}*HLUdF~O|Qu?KU*|9>kf max_iterations: + done = True + + ftext = "" + if use_ansi: + ftext = env.render(mode="ansi") + else: + ftext = str(info) + + i += 1 + print (u"{}[2J{}[;H".format(chr(27), chr(27))) + print(f"Evaluation: {100 * ep / episodes}%") + print(f"{ftext}") + sleep(.1) + + epochs += 1 + + total_epochs += epochs + total_rewards += reward + + + sleep(1) + except KeyboardInterrupt: + print(f"Results after {episodes} episodes:") + print(f"Average timesteps per episode: {total_epochs / episodes}") + print(f"Average rewards per episode: {total_rewards / episodes}") + + exit() + + print (u"{}[2J{}[;H".format(chr(27), chr(27))) + print("Evaluation: finished.\n") + + print(f"Results after {episodes} episodes:") + print(f"Average timesteps per episode: {total_epochs / episodes}") + print(f"Average rewards per episode: {total_rewards / episodes}") diff --git a/ql/review.py b/ql/review.py new file mode 100644 index 0000000..3c3fde7 --- /dev/null +++ b/ql/review.py @@ -0,0 +1,32 @@ +from . import * + +def review(frames): + sucs = "" + prevSess = -1 + rew = 0 + cnt = 1 + + for i, frame in enumerate(frames): + print (u"{}[2J{}[;H".format(chr(27), chr(27))) + print(f"Session: {frame['session']}") + print(frame['frame']) + print(f"Timestep: {i + 1}") + print(f"State: {frame['state']}") + print(f"Action: {frame['action']}") + print(f"Reward: {frame['reward']}") + + sess = frame['session'] + if sess != prevSess: + if rew > 0: + sucs += "+" + sleep(1) + elif rew < 0: + sucs += "-" + else: + sucs += "." + prevSess = frame['session'] + cnt += 1 + rew = frame['reward'] + + print(f"\nSuccesses: [{sucs}]") + sleep(.1) diff --git a/ql/settings.py b/ql/settings.py new file mode 100644 index 0000000..137a2d7 --- /dev/null +++ b/ql/settings.py @@ -0,0 +1,17 @@ +# OpenAI Gym settings + +gym_name = "Taxi-v3" + +# Q-Learning training settings + +alpha = 0.1 +gamma = 0.8 +epsilon = 0.1 + +# Q-learning player settings + +max_iterations = 1000 + +# Render settings + +use_ansi = True diff --git a/ql/train.py b/ql/train.py new file mode 100644 index 0000000..cb0de74 --- /dev/null +++ b/ql/train.py @@ -0,0 +1,77 @@ +from . import * + +def train(training_episodes=10000, resume=True): + if resume and os.path.exists(gym_name+".dat"): + q_table = pickle.load( open( gym_name+".dat", "rb" )) + else: + q_table = np.zeros([env.observation_space.n, env.action_space.n]) + + episodes = training_episodes + percentage = episodes / 100 + + frames = [] + suc_cnt = 0 + + try: + for i in range(1, episodes + 1): + state = env.reset() + + epochs, reward, = 0, 0 + done = False + + while not done: + if random.uniform(0, 1) < epsilon: + action = env.action_space.sample() + else: + action = np.argmax(q_table[state]) + + next_state, reward, done, info = env.step(action) + + old_value = q_table[state, action] + next_max = np.max(q_table[next_state]) + + new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max) + q_table[state, action] = new_value + + if reward > 0: + suc_cnt += 1 + + state = next_state + epochs += 1 + + if i % percentage == 0: + ftext = "" + if use_ansi: + ftext = env.render(mode="ansi") + else: + ftext = str(info) + + frames.append({ + 'frame': ftext, + 'state': state, + 'action': action, + 'reward': reward, + 'session': i + } + ) + + if i % percentage == 0: + print (u"{}[2J{}[;H".format(chr(27), chr(27))) + print(f"Training: {i/percentage}%") + print(f"Successes so far: {suc_cnt}") + sleep(.1) + print (u"{}[2J{}[;H".format(chr(27), chr(27))) + print("Training: finished.\n") + print(f"Successes totally: {suc_cnt}") + pickle.dump(q_table , open( gym_name+".dat", "wb" ) ) + print(f"Q-table saved: {gym_name}.dat") + + except KeyboardInterrupt: + print (u"{}[2J{}[;H".format(chr(27), chr(27))) + print("Training: stopped.\n") + print(f"Successes totally: {suc_cnt}") + pickle.dump(q_table , open( gym_name+".dat", "wb" ) ) + print(f"Q-table saved: {gym_name}_stopped.dat") + exit() + + return frames \ No newline at end of file