P.S.: If you build your RNN model with Keras, you don't need to deal with these details.
A short example of a static RNN
Please pay attention to the output shapes in the following picture.
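import tensorflow as tf  # TensorFlow 1.x
from tensorflow.contrib import rnn
batch_size = 32
time_step = 5
input_size = 4
rnn_cell = 20  # number of units in the LSTM cell
X = tf.placeholder(tf.float32, shape=[batch_size, time_step, input_size])
# static_rnn expects a Python list of time_step tensors, each [batch_size, input_size]
x = tf.unstack(X, axis=1)
lstm_cell = rnn.BasicLSTMCell(rnn_cell)
# outputs: list of time_step tensors, each [batch_size, rnn_cell]
outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
# last time step: [batch_size, rnn_cell]
output = outputs[-1]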
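The static version relies on tf.unstack(X, axis=1) turning one [batch_size, time_step, input_size] tensor into a Python list of time_step tensors of shape [batch_size, input_size]; static_rnn then returns a list of the same length, so outputs[-1] is the last time step with shape [batch_size, rnn_cell]. The sketch below mimics only that shape bookkeeping with nested Python lists, not TensorFlow. The helpers shape and unstack_axis1, and the all-zeros stand-in for the LSTM cell, are hypothetical and exist only for illustration.

```python
# Shape-only sketch of the static RNN data flow using nested Python lists.
# No TensorFlow involved; the "cell" below is a hypothetical stand-in.

def shape(x):
    """Return the shape of a regular nested list, e.g. [2, 3] for a 2x3 list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return dims

def unstack_axis1(x):
    """Mimic tf.unstack(x, axis=1) on a [batch, time, feature] nested list:
    return a list of `time` items, each of shape [batch, feature]."""
    return [[sample[t] for sample in x] for t in range(len(x[0]))]

batch_size, time_step, input_size, rnn_cell = 32, 5, 4, 20

# Each value encodes its (batch, time, feature) index so the reshuffle is visible.
X = [[[b * 100 + t * 10 + i for i in range(input_size)]
      for t in range(time_step)]
     for b in range(batch_size)]

x = unstack_axis1(X)

# Hypothetical stand-in for the LSTM: maps each [batch, input_size] step to
# a [batch, rnn_cell] output of zeros, just to show the resulting shapes.
outputs = [[[0.0] * rnn_cell for _ in step] for step in x]
output = outputs[-1]

print(len(x), shape(x[0]))           # 5 [32, 4]
print(len(outputs), shape(output))   # 5 [32, 20]
```

Note that x[t][b] is the same row as X[b][t]: unstacking along axis 1 moves time to the front, which is why the list index selects a time step rather than a sample.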
A short example of a dynamic RNN
Please pay attention to the output shapes in the following picture.
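import tensorflow as tf  # TensorFlow 1.x
from tensorflow.contrib import rnn
batch_size = 32
time_step = 5
input_size = 4
rnn_cell = 20  # number of units in the LSTM cell
X = tf.placeholder(tf.float32, shape=[batch_size, time_step, input_size])
lstm_cell = rnn.BasicLSTMCell(rnn_cell)
# dynamic_rnn takes the 3-D tensor directly (no unstack needed);
# outputs: [batch_size, time_step, rnn_cell]
outputs, states = tf.nn.dynamic_rnn(lstm_cell, X, dtype=tf.float32)
# reorder to [time_step, batch_size, rnn_cell] so index -1 selects the last time step
outputs = tf.transpose(outputs, [1, 0, 2])
# last time step: [batch_size, rnn_cell]
output = outputs[-1]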
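With the default time_major=False, dynamic_rnn returns a single batch-major tensor of shape [batch_size, time_step, rnn_cell], so outputs[-1] on it would pick the last sample in the batch, not the last time step. The transpose with permutation [1, 0, 2] swaps the first two axes to make time the leading axis. The sketch below, using plain nested lists rather than TensorFlow, checks that transposing and then indexing [-1] gives the same rows as taking the last time step of every sample directly. The helpers shape and transpose_102 and the stand-in outputs values are hypothetical.

```python
# Shape-only sketch of the dynamic RNN post-processing using nested lists.
# No TensorFlow involved; `outputs` is a hypothetical stand-in tensor.

def shape(x):
    """Return the shape of a regular nested list, e.g. [2, 3] for a 2x3 list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return dims

def transpose_102(x):
    """Mimic tf.transpose(x, [1, 0, 2]): swap the first two axes,
    turning a [d0, d1, d2] nested list into [d1, d0, d2]."""
    return [[x[i][j] for i in range(len(x))] for j in range(len(x[0]))]

batch_size, time_step, rnn_cell = 32, 5, 20

# Stand-in for dynamic_rnn outputs ([batch, time, units]); each value
# encodes its (batch, time, unit) index.
outputs = [[[b * 1000 + t * 100 + u for u in range(rnn_cell)]
            for t in range(time_step)]
           for b in range(batch_size)]

time_major = transpose_102(outputs)             # [time, batch, units]
last_via_transpose = time_major[-1]             # [batch, units]
last_via_slice = [seq[-1] for seq in outputs]   # like outputs[:, -1, :]

print(shape(outputs), shape(time_major), shape(last_via_transpose))
# [32, 5, 20] [5, 32, 20] [32, 20]
print(last_via_transpose == last_via_slice)     # True
```

In TensorFlow, slicing the batch-major tensor as outputs[:, -1, :] should give the same [batch_size, rnn_cell] result without the explicit transpose; the transpose form above simply mirrors the static version, where the time step is the leading index.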
RNN API:
https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn
https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn