Whitepaper: Implementation of Control Flow in TensorFlow
http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf
There are five basic control-flow ops used to implement the while-loop function:
Switch : A Switch operator forwards the input tensor d to one of its outputs depending on the
boolean tensor of the control input p. A Switch is enabled for execution when both its inputs are
available.
Merge : A Merge operator forwards one of its available inputs to its output. A Merge is enabled
for execution when any of its inputs is available. It is unspecified which available input it outputs if there are multiple inputs available.
Enter(name) : An Enter operator forwards its input to the execution frame that is uniquely
identified by the given name. This Enter op is used to pass a tensor in one execution frame to a
child execution frame. There can be multiple Enter ops to the same child execution frame, each
making a tensor available (asynchronously) in that child execution frame. An Enter is enabled
for execution when its input is available. A new execution frame is instantiated in the
TensorFlow runtime when the first Enter op to that frame is executed.
Exit : An Exit operator forwards a value from an execution frame to its parent execution frame.
This Exit op is used to return a tensor computed in a child execution frame back to its parent
frame. There can be multiple Exit ops to the parent frame, each asynchronously passing a
tensor back to the parent frame. An Exit is enabled when its input is available.
NextIteration: A NextIteration operator forwards its input to the next iteration in the current
execution frame.
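As a quick sanity check, we can build a trivial loop with tf.while_loop and look for these op types in the resulting graph. This is only a sketch: it assumes a TF 1.x-style graph (via tf.compat.v1 on TensorFlow 2.x) with the classic control-flow lowering, since the newer control-flow-v2 path emits a single functional While op instead of these primitives.

# Sketch: observe the five control-flow primitives in a while-loop graph.
# Assumes TensorFlow 2.x with the v1 compatibility APIs; on TensorFlow 1.x
# a plain `import tensorflow as tf` works and the disable_* calls are not needed.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.disable_control_flow_v2()   # keep the classic Enter/Merge/Switch lowering

g = tf.Graph()
with g.as_default():
    i = tf.constant(0)
    # while i < 10: i = i + 1
    tf.while_loop(cond=lambda i: i < 10,
                  body=lambda i: i + 1,
                  loop_vars=[i])

op_types = {op.type for op in g.get_operations()}
# Besides the five primitives, the implementation also inserts a LoopCond op
# between the cond subgraph and the Switch ops.
for t in ["Enter", "Merge", "Switch", "Exit", "NextIteration", "LoopCond"]:
    print(t, t in op_types)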
We can also look at the C++ implementation of the while-loop builder:
tensorflow/cc/ops/while_loop.cc
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
                      const CondGraphBuilderFn& cond,
                      const BodyGraphBuilderFn& body, const string& frame_name,
                      OutputList* outputs, bool create_while_ctx,
                      Output* cond_output)
A while loop with a single loop variable looks like this:
 (output)
     ^    +---------------+
     |    | body subgraph +-------------+
    Exit  +---------------+             |
      ^    ^                            |
      |    |                            |
      Switch<--------+                  v
        ^            |             NextIteration
        |     +------+--------+         |
        +---->| cond subgraph |         |
        |     +---------------+         |
       Merge<---------------------------+
       ^
       |
      Enter
        ^
        |
     (input)
If there are multiple loop variables, each of the control flow ops is
duplicated for each loop variable.
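To see that duplication, a rough sketch (same graph-mode assumptions as the snippet above) can count the control-flow nodes in a loop with two loop variables:

import collections
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.disable_control_flow_v2()

g = tf.Graph()
with g.as_default():
    # Two loop variables: a counter i and a running total.
    tf.while_loop(cond=lambda i, total: i < 10,
                  body=lambda i, total: [i + 1, total + i],
                  loop_vars=[tf.constant(0), tf.constant(0)])

counts = collections.Counter(op.type for op in g.get_operations())
# Roughly one Enter/Merge/Switch/Exit/NextIteration per loop variable;
# extra Enter ops may appear for loop-invariant tensors captured into the frame.
for t in ["Enter", "Merge", "Switch", "Exit", "NextIteration"]:
    print(t, counts[t])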
The paper covers many points; here I only want to highlight the part related to memory swapping:
"To reuse forward values in backprop loop, we automatically detect, during the construction of
the backprop while loop, the forward values that are needed in the backprop. For each such
forward value x, we automatically introduce a stack and add nodes in the forward loop to save
its value at each iteration to the stack. The backprop loop uses the values from the stack in the
reverse order. The stack lives outside the forward and backprop loops and is shared by the
two loops."
These stack push ops are inserted into the forward while loop, and TensorFlow generates the corresponding stack pop ops in the backpropagation phase. If you check the source code, this logic can be found in the GradLoopState class.
tensorflow/python/ops/control_flow_ops.py
class GradLoopState(object):
  """The state used for constructing the gradient graph for a while loop.

  We create a GradLoopState for each while loop in forward and its
  corresponding while loop in backprop. This gives us access to both
  the forward and the backprop WhileContexts.

  During the construction of gradient graph, any time when we detect
  a forward value that is needed for backprop, we create a history
  accumulator and add it to `history_map`. Any time when we backprop
  a loop switch op (in _SwitchGrad), we add the grad merge op in
  `switch_map`.
  """
  ...
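To make this concrete, a small sketch (same graph-mode assumptions as the earlier snippets) can take the gradient through a while loop whose backprop needs a forward value and then look for the stack ops that the gradient construction introduces. The exact op names (e.g. StackV2/StackPushV2/StackPopV2) are an implementation detail and may vary across TensorFlow versions.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.disable_control_flow_v2()

g = tf.Graph()
with g.as_default():
    x = tf.constant(2.0)
    # y = x ** 5 computed iteratively; the gradient of acc * x with respect
    # to x needs the forward value of acc at every iteration, so a history
    # accumulator (stack) is created for it.
    _, y = tf.while_loop(cond=lambda i, acc: i < 5,
                         body=lambda i, acc: [i + 1, acc * x],
                         loop_vars=[tf.constant(0), tf.constant(1.0)])
    grad = tf.gradients(y, x)

stack_ops = sorted({op.type for op in g.get_operations() if "Stack" in op.type})
# Expect push ops in the forward loop and pop ops in the backprop loop.
print(stack_ops)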
For more explanation and experiments on dynamic control flow, please refer to the paper: Dynamic Control Flow in Large-Scale Machine Learning.