Wednesday, October 17, 2018

[TensorFlow Grappler] Two ways to traverse every node's inputs and outputs in a graph using C++

Here I want to introduce two ways to traverse the inputs and outputs of every node in a graph using C++ in Grappler.
P.S.: Both ways assume your code already has access to the GrapplerItem and GraphDef objects.

First, here is the name of my example node as shown in TensorBoard:
conv1/Conv2D




Method 1: Use GraphView (this also gives you the input/output port IDs)
static void FindAllNodesOfInputandOutputViaGraphView(
    const GraphView& graph_view, const GraphDef* graph) {
  for (const auto& node : graph->node()) {
    VLOG(1) << "...[DEBUG2] Traverse via View, node name:" << node.name();
    // Second argument false: skip nodes connected only by control dependencies.
    // Each fanout is an input port of a node that consumes this node.
    for (const auto& fanout : graph_view.GetFanouts(node, false)) {
      VLOG(1) << ".......[DEBUG2] get output name:port_id = "
              << fanout.node->name() << ":" << fanout.port_id;
    }
    // Each fanin is an output port of a node that feeds this node.
    for (const auto& fanin : graph_view.GetFanins(node, false)) {
      VLOG(1) << ".......[DEBUG2] get input name:port_id = "
              << fanin.node->name() << ":" << fanin.port_id;
    }
    }
  }
}

Result:
...[DEBUG2] Traverse via View, node name:conv1/Conv2D
.......[DEBUG2] get output name:port_id = conv1/BiasAdd:0
.......[DEBUG2] get input name:port_id = conv1/kernel/read:0
.......[DEBUG2] get input name:port_id = conv1/Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer:0
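
The name:port_id pairs above mirror how a GraphDef stores its edges: each entry in a node's input list is a string of the form node, node:port (the :0 suffix may be omitted), or ^node for a control dependency. Here is a minimal standalone sketch of splitting such a string. It is not Grappler's own parser; ParsedInput and ParseInput are hypothetical names I made up for illustration, and using -1 as the port of a control input is just this sketch's choice.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper type for this sketch; not a Grappler class.
struct ParsedInput {
  std::string node_name;    // name of the producer node
  int port = 0;             // producer output index; -1 marks a control input here
  bool is_control = false;  // true for "^node" entries
};

// Splits an input string of the form "node", "node:port" or "^node".
ParsedInput ParseInput(const std::string& input) {
  assert(!input.empty());  // assumption: input strings are never empty
  ParsedInput parsed;
  if (input[0] == '^') {
    // Control dependency: no tensor flows along this edge.
    parsed.node_name = input.substr(1);
    parsed.port = -1;
    parsed.is_control = true;
    return parsed;
  }
  const std::string::size_type colon = input.rfind(':');
  if (colon == std::string::npos) {
    // The ":0" suffix was omitted, so the port stays 0.
    parsed.node_name = input;
    return parsed;
  }
  parsed.node_name = input.substr(0, colon);
  parsed.port = std::stoi(input.substr(colon + 1));
  return parsed;
}
```

For example, ParseInput("conv1/kernel/read") gives the node name conv1/kernel/read with port 0, which matches the conv1/kernel/read:0 line in the result above.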


Method 2: Use NodeMap

static void FindAllNodesOfInputandOutputViaNodeMap(
    const NodeMap& node_map, const GraphDef* graph) {
  for (const auto& node : graph->node()) {
    VLOG(1) << "...[DEBUG2] Traverse via Map, node name:" << node.name();
    for (const NodeDef* output_node : node_map.GetOutputs(node.name())) {
      VLOG(1) << ".......[DEBUG2] get output name:" << output_node->name();
    }
    for (const string& input_name : node.input()) {
      const NodeDef* input_node = node_map.GetNode(input_name);
      // GetNode() returns nullptr if the name is not in the map.
      if (input_node == nullptr) continue;
      VLOG(1) << ".......[DEBUG2] get input name:" << input_node->name();
    }
  }
}

Result:
...[DEBUG2] Traverse via Map, node name:conv1/Conv2D
.......[DEBUG2] get output name:conv1/BiasAdd
.......[DEBUG2] get input name:conv1/Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer
.......[DEBUG2] get input name:conv1/kernel/read
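
Conceptually, asking "which nodes consume this node?" by name is an inverted index over every node's input list. The toy sketch below shows that idea with standard containers only. It is not how Grappler implements NodeMap; ToyNode, ProducerName and BuildOutputs are hypothetical names for illustration.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for NodeDef: a name plus its raw input strings.
struct ToyNode {
  std::string name;
  std::vector<std::string> inputs;
};

// Strips a leading '^' (control input) and a trailing ":port" from an input string.
std::string ProducerName(const std::string& input) {
  assert(!input.empty());  // assumption: input strings are never empty
  const std::string::size_type start = (input[0] == '^') ? 1 : 0;
  const std::string::size_type colon = input.find(':', start);
  if (colon == std::string::npos) return input.substr(start);
  return input.substr(start, colon - start);
}

// Inverts the input lists: producer name -> names of the nodes that consume it.
std::map<std::string, std::set<std::string>> BuildOutputs(
    const std::vector<ToyNode>& nodes) {
  std::map<std::string, std::set<std::string>> outputs;
  for (const ToyNode& node : nodes) {
    for (const std::string& input : node.inputs) {
      outputs[ProducerName(input)].insert(node.name);
    }
  }
  return outputs;
}
```

With a toy graph where conv1/BiasAdd lists conv1/Conv2D as an input, BuildOutputs(...)["conv1/Conv2D"] contains conv1/BiasAdd, the same relationship as in the result above.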

Here you go!
P.S.: The input node conv1/Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer differs from the one in the graph picture because I turned on the Layout Optimizer, which modifies the graph before my code traverses it.

Reference:
For an explanation of what GraphDef and NodeDef are, see:
https://qiita.com/AtuNuka/items/8b6a9d632641fbc787d4



