Wednesday, November 28, 2018

[Reinforcement Learning] Get started to learn DQN for reinforcement learning

The previous post about Q-Learning is here:
[Reinforcement Learning] Get started to learn Q-Learning for reinforcement learning

Basically, Deep Q-Network (DQN) is an upgrade of the Q-Learning algorithm in which the Q-table is replaced by a neural network. For this DQN tutorial, I referred to the following pages (sorry, they are written in Chinese):
https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/4-1-A-DQN/
https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/4-1-DQN1/
https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/4-2-DQN2/
https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/4-3-DQN3/



The DQN demo example is on GitHub:
https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/5_Deep_Q_Network

First, let us recall what Q-Learning is. Please refer to the picture below. Assume we are at State1 right now and there are 2 actions. According to the Q-table, we take action 2 because it has the larger potential reward. Then we update the Q-table entry Q(s1, a2) based on the Q-Learning algorithm.
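
To make the update concrete, here is a minimal sketch of one tabular Q-Learning step. The Q-table values, the learning rate alpha, the zero reward, and the state and action names are made up for illustration; only the update rule is the standard Q-Learning one.

```python
# Hypothetical Q-table: the values are invented for this example.
q_table = {
    "s1": {"a1": -2.0, "a2": 1.0},
    "s2": {"a1": -4.0, "a2": 2.0},
}

alpha = 0.1  # learning rate (assumed value)
gamma = 0.9  # discount factor


def q_learning_update(q_table, s, a, r, s_next, alpha, gamma):
    """Apply one Q-Learning update to Q(s, a) and return the new value."""
    # Target: immediate reward plus the discounted best value of the next state.
    td_target = r + gamma * max(q_table[s_next].values())
    # Move Q(s, a) a fraction alpha of the way toward the target.
    q_table[s][a] += alpha * (td_target - q_table[s][a])
    return q_table[s][a]


# At s1, greedily pick the action with the larger Q value (a2 here).
action = max(q_table["s1"], key=q_table["s1"].get)

# Assume the action leads to s2 with reward 0.
new_q = q_learning_update(q_table, "s1", action, 0.0, "s2", alpha, gamma)
```

With these made-up numbers, Q(s1, a2) starts at 1.0 and moves one tenth of the way toward the target 0.9 * 2.0 = 1.8.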


DQN works a little differently. At state1, we pick the action with the max Q value from the Evaluation NN. Then we use the Target NN to get the next state’s Q values, and use them to build the target for updating the Evaluation NN. (NN means neural network.)
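
The roles of the two networks can be sketched as follows. This is not the demo's TensorFlow code: each "network" is replaced by a hypothetical linear map with hand-picked weights, the exploration step is left out, and the names (q_values, sync_target, eval_params, target_params) are my own. The shapes (2 state features, 4 actions) match the debug output further down.

```python
import numpy as np

gamma = 0.9  # discount factor

# Stand-in for the Evaluation NN: a linear map from 2 features to 4 Q values.
# The weights are invented for illustration.
eval_params = np.array([
    [0.1, 0.4, 0.3, -0.2],
    [-0.1, 0.2, 0.5, 0.0],
])


def q_values(params, state):
    """Return one Q value per action for the given state."""
    return state @ params


def sync_target(eval_params):
    """Copy the evaluation parameters into a frozen target copy."""
    return eval_params.copy()


# The Target NN is a periodically refreshed copy of the Evaluation NN.
target_params = sync_target(eval_params)

s = np.array([0.25, 0.25])      # current state
s_next = np.array([0.25, 0.0])  # next state (assumed transition)
r = 0.0                         # reward (assumed)

# 1) Act greedily on the Evaluation NN.
action = int(np.argmax(q_values(eval_params, s)))

# 2) Build the learning target from the Target NN.
td_target = r + gamma * np.max(q_values(target_params, s_next))
```

Only the Evaluation NN would be trained toward td_target; the target copy stays fixed between calls to sync_target, which keeps the target from shifting on every training step.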

For the DQN demo example, I drew this diagram to help understand it better:


Below I dump some of the data from the DQN training process. (It can be compared with the diagram above.)

('[DEBUG] s, a, r, _s:', array([0.25, 0.25]), 1, 0, array([0.25, 0.25]))
('[DEBUG] s, a, r, _s:', array([0.25, 0.25]), 1, 0, array([0.25, 0.25]))
('[DEBUG] s, a, r, _s:', array([0.25, 0.25]), 0, 0, array([0.25, 0.  ]))
('[DEBUG] s, a, r, _s:', array([0.25, 0.  ]), 1, 0, array([0.25, 0.25]))
('[DEBUG] s, a, r, _s:', array([0.25, 0.25]), 1, 0, array([0.25, 0.25]))
('[DEBUG] batch_memory:', array([[ 0.  ,  0.25,  0.  ,  1.  ,  0.  ,  0.  ],
       [ 0.  ,  0.25,  1.  ,  0.  ,  0.  ,  0.25],
       [ 0.  ,  0.25,  1.  ,  0.  ,  0.  ,  0.25],
       [ 0.  ,  0.25,  1.  ,  0.  ,  0.  ,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  3.  ,  0.  ,  0.  ,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.  ,  0.25,  1.  ,  0.  ,  0.  ,  0.25],
       [ 0.  ,  0.25,  1.  ,  0.  ,  0.  ,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.  , -0.5 ,  2.  ,  0.  ,  0.25, -0.5 ],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.  ,  0.25,  1.  ,  0.  ,  0.  ,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.  ,  0.25,  1.  ,  0.  ,  0.  ,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.  ,  0.25,  2.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.25,  0.25,  1.  ,  0.  ,  0.25,  0.25],
       [ 0.  ,  0.25,  1.  ,  0.  ,  0.  ,  0.25]]))
('[DEBUG] q_next, q_eval:', array([[-0.06391249,  0.19887918,  0.22256142, -0.02985281],
       [-0.09851839,  0.26159775,  0.2543864 , -0.04858976],
       [-0.09851839,  0.26159775,  0.2543864 , -0.04858976],
       [-0.09851839,  0.26159775,  0.2543864 , -0.04858976],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.09851839,  0.26159775,  0.2543864 , -0.04858976],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.09851839,  0.26159775,  0.2543864 , -0.04858976],
       [-0.09851839,  0.26159775,  0.2543864 , -0.04858976],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.01857544,  0.11471414,  0.10633466, -0.11460548],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.09851839,  0.26159775,  0.2543864 , -0.04858976],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.09851839,  0.26159775,  0.2543864 , -0.04858976],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.04373299,  0.18115267,  0.1792323 , -0.08237151],
       [-0.09851839,  0.26159775,  0.2543864 , -0.04858976]],
      dtype=float32), array([[-0.09573826,  0.2598316 ,  0.25124246, -0.04765747],
       [-0.09573826,  0.2598316 ,  0.25124246, -0.04765747],
       [-0.09573826,  0.2598316 ,  0.25124246, -0.04765747],
       [-0.09573826,  0.2598316 ,  0.25124246, -0.04765747],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.09573826,  0.2598316 ,  0.25124246, -0.04765747],
       [-0.09573826,  0.2598316 ,  0.25124246, -0.04765747],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.05664021,  0.16240907,  0.21497497, -0.04578716],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.09573826,  0.2598316 ,  0.25124246, -0.04765747],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.09573826,  0.2598316 ,  0.25124246, -0.04765747],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.09573826,  0.2598316 ,  0.25124246, -0.04765747],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.04104728,  0.17953533,  0.17629224, -0.08140446],
       [-0.09573826,  0.2598316 ,  0.25124246, -0.04765747]],
      dtype=float32))
('[DEBUG] batch_index, eval_act_index:', array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
      dtype=int32), array([0, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 2, 1, 1, 1]))
('[DEBUG] reward, gamma:', array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 0.9)
('[DEBUG] q_target[batch_index, eval_act_index]:', array([1.2003052 , 0.23543797, 0.23543797, 0.23543797, 0.1630374 ,
       0.1630374 , 0.1630374 , 0.23543797, 0.1630374 , 0.1630374 ,
       0.1630374 , 0.1630374 , 0.23543797, 0.23543797, 0.1630374 ,
       0.1630374 , 0.1630374 , 0.10324273, 0.1630374 , 0.1630374 ,
       0.1630374 , 0.1630374 , 0.1630374 , 0.23543797, 0.1630374 ,
       0.1630374 , 0.23543797, 0.1630374 , 0.1630374 , 0.1630374 ,
       0.1630374 , 0.23543797], dtype=float32))


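If I choose batch_index[1]=1, eval_act_index[1]=1, the target uses the max Q value of the next state, which for this row happens to sit at action index 1 as well:
max(q_next[1, :]) = q_next[1,1] = 0.26159775

reward[1] + gamma * max(q_next[1, :]) = 0 + 0.9 * 0.26159775 = 0.23543797
q_target[1,1] = 0.23543797

So the q_target[1,1] value is verified.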
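
The same check can be run for several rows at once. The sketch below copies rows 0, 1 and 17 of batch_memory, q_next and q_eval from the dump above (the choice of rows is mine) and assumes each memory row is laid out as [s (2 values), a, r, s_ (2 values)], which is what the dump suggests. The variable names follow the debug labels; this is a reconstruction of the target step, not necessarily the demo's exact code.

```python
import numpy as np

gamma = 0.9
n_features = 2  # assumed: each state has 2 values

# Rows 0, 1 and 17 of the dumped batch_memory: [s, a, r, s_]
batch_memory = np.array([
    [0.0, 0.25, 0.0, 1.0, 0.0, 0.0],
    [0.0, 0.25, 1.0, 0.0, 0.0, 0.25],
    [0.0, -0.5, 2.0, 0.0, 0.25, -0.5],
])

# Matching rows of q_next (Target NN output for s_).
q_next = np.array([
    [-0.06391249, 0.19887918, 0.22256142, -0.02985281],
    [-0.09851839, 0.26159775, 0.2543864, -0.04858976],
    [-0.01857544, 0.11471414, 0.10633466, -0.11460548],
], dtype=np.float32)

# Matching rows of q_eval (Evaluation NN output for s).
q_eval = np.array([
    [-0.09573826, 0.2598316, 0.25124246, -0.04765747],
    [-0.09573826, 0.2598316, 0.25124246, -0.04765747],
    [-0.05664021, 0.16240907, 0.21497497, -0.04578716],
], dtype=np.float32)

batch_index = np.arange(len(batch_memory))
eval_act_index = batch_memory[:, n_features].astype(int)  # action column
reward = batch_memory[:, n_features + 1]                  # reward column

# Start from q_eval so that only the taken action contributes to the error.
q_target = q_eval.copy()
q_target[batch_index, eval_act_index] = reward + gamma * q_next.max(axis=1)

# Compare the overwritten entries with the dumped q_target values.
expected = np.array([1.2003052, 0.23543797, 0.10324273])
print(np.allclose(q_target[batch_index, eval_act_index], expected, atol=1e-6))
```

Row 0 is the only one of the three with a non-zero reward, which is why its target (about 1.2) stands out in the dump.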
