[Reinforcement Learning] Getting started with Q-Learning and Deep Q-Learning (DQN)
Basically, Deep Q-Learning (DQN) upgrades the Q-Learning algorithm by replacing the Q-table with a neural network. For this DQN tutorial, I referred to the following pages (sorry, they are written in Chinese):
https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/4-1-A-DQN/
https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/4-1-DQN1/
https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/4-2-DQN2/
https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/4-3-DQN3/
The DQN demo example is on GitHub:
https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/5_Deep_Q_Network
First, let us recall what Q-Learning is. Please refer to the picture below. Assume we are now at state s1 and there are two actions. According to the Q-table, we take action a2 because it has the higher potential reward. Then we update the Q-table entry Q(s1, a2) using the Q-Learning update rule.
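The tabular update described above can be sketched in a few lines. This is a minimal illustration with made-up numbers, not code from the demo; the variable names and the small 2x2 table are my own assumptions.

```python
import numpy as np

alpha, gamma = 0.1, 0.9          # learning rate, discount factor (illustrative)
q_table = np.zeros((2, 2))       # Q(s, a): 2 states x 2 actions
q_table[0] = [0.1, 0.5]          # pretend s1 already has some estimates

s, s_next, r = 0, 1, 1.0         # current state, next state, reward
a = int(np.argmax(q_table[s]))   # pick action a2 (index 1): larger Q value

# Q-Learning update: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
q_table[s, a] += alpha * (r + gamma * q_table[s_next].max() - q_table[s, a])
print(a, round(q_table[s, a], 3))  # 1 0.55
```

Here the new Q(s1, a2) moves one step of size alpha toward the TD target r + gamma * max Q(s2, ·).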
For DQN, things are slightly different. At state s1, we take the action with the maximum Q value output by the evaluation NN. Then we use the target NN to compute the next state's Q value for updating the network. (NN means neural network.)
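To make the two-network split concrete, here is a sketch with the networks stubbed out by fixed arrays (the real demo uses two TensorFlow networks; these numbers are invented for illustration):

```python
import numpy as np

gamma = 0.9
q_eval_s1      = np.array([0.2, 0.7])  # evaluation NN output at state s1
q_target_s2    = np.array([0.4, 0.3])  # target NN output at next state s2

a = int(np.argmax(q_eval_s1))          # act greedily w.r.t. the evaluation NN
r = 1.0
target = r + gamma * q_target_s2.max() # TD target comes from the target NN
print(a, round(target, 2))             # 1 1.36
```

The evaluation NN is trained toward this target, while the target NN is only synced from it every few hundred steps, which keeps the training target stable.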
In the DQN demo example, I drew this diagram to help understand it better. Adding debug prints to a learning step gives output like this:
('[DEBUG] s, a, r, _s:', array([0.25, 0.25]), 1, 0, array([0.25, 0.25]))
('[DEBUG] s, a, r, _s:', array([0.25, 0.25]), 1, 0, array([0.25, 0.25]))
('[DEBUG] s, a, r, _s:', array([0.25, 0.25]), 0, 0, array([0.25, 0. ]))
('[DEBUG] s, a, r, _s:', array([0.25, 0. ]), 1, 0, array([0.25, 0.25]))
('[DEBUG] s, a, r, _s:', array([0.25, 0.25]), 1, 0, array([0.25, 0.25]))
('[DEBUG] batch_memory:', array([[ 0. , 0.25, 0. , 1. , 0. , 0. ],
[ 0. , 0.25, 1. , 0. , 0. , 0.25],
[ 0. , 0.25, 1. , 0. , 0. , 0.25],
[ 0. , 0.25, 1. , 0. , 0. , 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 3. , 0. , 0. , 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0. , 0.25, 1. , 0. , 0. , 0.25],
[ 0. , 0.25, 1. , 0. , 0. , 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0. , -0.5 , 2. , 0. , 0.25, -0.5 ],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0. , 0.25, 1. , 0. , 0. , 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0. , 0.25, 1. , 0. , 0. , 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0. , 0.25, 2. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0.25, 0.25, 1. , 0. , 0.25, 0.25],
[ 0. , 0.25, 1. , 0. , 0. , 0.25]]))
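Each row of batch_memory above has six columns: the two state features, the action index, the reward, and the two next-state features. This layout can be read off the transition prints; a row is assembled roughly like this (matching the hstack layout used in the demo's store_transition):

```python
import numpy as np

s  = np.array([0.25, 0.25])   # current state (2 features)
a, r = 1, 0                   # action index, reward
s_ = np.array([0.25, 0.25])   # next state (2 features)

row = np.hstack((s, [a, r], s_))
print(row)  # six columns: s, a, r, s_
```

So for a batch of 32 transitions, batch_memory is a 32 x 6 array, and the state / next-state slices are the first and last two columns.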
('[DEBUG] q_next, q_eval:', array([[-0.06391249, 0.19887918, 0.22256142, -0.02985281],
[-0.09851839, 0.26159775, 0.2543864 , -0.04858976],
[-0.09851839, 0.26159775, 0.2543864 , -0.04858976],
[-0.09851839, 0.26159775, 0.2543864 , -0.04858976],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.09851839, 0.26159775, 0.2543864 , -0.04858976],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.09851839, 0.26159775, 0.2543864 , -0.04858976],
[-0.09851839, 0.26159775, 0.2543864 , -0.04858976],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.01857544, 0.11471414, 0.10633466, -0.11460548],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.09851839, 0.26159775, 0.2543864 , -0.04858976],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.09851839, 0.26159775, 0.2543864 , -0.04858976],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.04373299, 0.18115267, 0.1792323 , -0.08237151],
[-0.09851839, 0.26159775, 0.2543864 , -0.04858976]],
dtype=float32), array([[-0.09573826, 0.2598316 , 0.25124246, -0.04765747],
[-0.09573826, 0.2598316 , 0.25124246, -0.04765747],
[-0.09573826, 0.2598316 , 0.25124246, -0.04765747],
[-0.09573826, 0.2598316 , 0.25124246, -0.04765747],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.09573826, 0.2598316 , 0.25124246, -0.04765747],
[-0.09573826, 0.2598316 , 0.25124246, -0.04765747],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.05664021, 0.16240907, 0.21497497, -0.04578716],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.09573826, 0.2598316 , 0.25124246, -0.04765747],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.09573826, 0.2598316 , 0.25124246, -0.04765747],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.09573826, 0.2598316 , 0.25124246, -0.04765747],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.04104728, 0.17953533, 0.17629224, -0.08140446],
[-0.09573826, 0.2598316 , 0.25124246, -0.04765747]],
dtype=float32))
('[DEBUG] batch_index, eval_act_index:', array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
dtype=int32), array([0, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 2, 1, 1, 1]))
('[DEBUG] reward, gamma:', array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 0.9)
('[DEBUG] q_target[batch_index, eval_act_index]:', array([1.2003052 , 0.23543797, 0.23543797, 0.23543797, 0.1630374 ,
0.1630374 , 0.1630374 , 0.23543797, 0.1630374 , 0.1630374 ,
0.1630374 , 0.1630374 , 0.23543797, 0.23543797, 0.1630374 ,
0.1630374 , 0.1630374 , 0.10324273, 0.1630374 , 0.1630374 ,
0.1630374 , 0.1630374 , 0.1630374 , 0.23543797, 0.1630374 ,
0.1630374 , 0.23543797, 0.1630374 , 0.1630374 , 0.1630374 ,
0.1630374 , 0.23543797], dtype=float32))
If I choose batch_index[1] = 1 and eval_act_index[1] = 1, then:
q_next[1, 1] = 0.26159775
q_next[1, 1] * gamma + reward = 0.26159775 * 0.9 + 0 = 0.23543797
q_target[1, 1] = 0.23543797
So the q_target[1, 1] value is verified.
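The same check can be reproduced in NumPy using the values copied from the debug output above (only row 1 of q_next and q_eval is shown here):

```python
import numpy as np

gamma = 0.9
q_next_row = np.array([-0.09851839, 0.26159775, 0.2543864,  -0.04858976])
q_eval_row = np.array([-0.09573826, 0.2598316,  0.25124246, -0.04765747])
reward, act = 0.0, 1          # reward[1] and eval_act_index[1]

# q_target starts as a copy of q_eval; only the taken action is overwritten,
# so the loss is non-zero only at that action's entry.
q_target_row = q_eval_row.copy()
q_target_row[act] = reward + gamma * q_next_row.max()
print(q_target_row[act])      # ~0.23543797, matching the debug output
```

Copying q_eval first is the trick that makes the squared-error loss zero for the actions that were not taken.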