Note: your code must already have access to the GrapplerItem and GraphDef objects.
First, look up the name of the example node in TensorBoard. Mine is:
conv1/Conv2D
Method 1: Use GraphView (it also gives you the input/output port IDs)
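#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/platform/logging.h"

// Place inside namespace tensorflow::grappler (e.g. next to your optimizer).
// Usage (item is your GrapplerItem): GraphView graph_view(&item.graph);
static void FindAllNodesOfInputandOutputViaGraphView(
    const GraphView& graph_view, const GraphDef* graph) {
  for (const auto& node : graph->node()) {
    VLOG(1) << "...[DEBUG2] Traverse via View, node name:" << node.name();
    // false: skip control dependencies and list data edges only.
    for (const auto& fanout : graph_view.GetFanouts(node, false)) {
      // fanout.port_id is the input port on the consuming node.
      VLOG(1) << ".......[DEBUG2] get output name:port_id = "
              << fanout.node->name() << ":" << fanout.port_id;
    }
    for (const auto& fanin : graph_view.GetFanins(node, false)) {
      // fanin.port_id is the output port on the producing node.
      VLOG(1) << ".......[DEBUG2] get input name:port_id = "
              << fanin.node->name() << ":" << fanin.port_id;
    }
  }
}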
Result:
...[DEBUG2] Traverse via View, node name:conv1/Conv2D
.......[DEBUG2] get output name:port_id = conv1/BiasAdd:0
.......[DEBUG2] get input name:port_id = conv1/kernel/read:0
.......[DEBUG2] get input name:port_id = conv1/Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer:0
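If the port IDs in that output look mysterious, here is a small self-contained sketch of the bookkeeping behind them. It does not use TensorFlow at all: MiniNode, Port, ParseInput, MiniGraphView and BuildView are hypothetical stand-ins of my own, and the include_control flag plays the role of the false argument passed to GetFanouts/GetFanins in the code above. It assumes TensorFlow's input-string convention ("producer:port", a bare name meaning port 0, "^producer" for a control input) and is not how Grappler's GraphView is actually implemented.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for NodeDef: a name plus its input strings.
struct MiniNode {
  std::string name;
  std::vector<std::string> inputs;
};

// (node name, port id), mirroring the "name:port_id" lines in the log.
using Port = std::pair<std::string, int>;

// "producer:2" -> {"producer", 2}; a bare "producer" means port 0.
// Control inputs written as "^producer" get port -1
// (TensorFlow also uses -1 for control slots).
Port ParseInput(const std::string& input) {
  if (!input.empty() && input[0] == '^') return {input.substr(1), -1};
  const auto colon = input.rfind(':');
  if (colon == std::string::npos) return {input, 0};
  return {input.substr(0, colon), std::stoi(input.substr(colon + 1))};
}

struct MiniGraphView {
  // consumer name -> producer output ports it reads (its fanins)
  std::map<std::string, std::set<Port>> fanins;
  // producer name -> consumer input ports that read it (its fanouts)
  std::map<std::string, std::set<Port>> fanouts;
};

// Builds both directions in one pass over every node's inputs.
MiniGraphView BuildView(const std::vector<MiniNode>& graph,
                        bool include_control) {
  MiniGraphView view;
  for (const MiniNode& node : graph) {
    int input_port = 0;  // index of the next data input on this consumer
    for (const std::string& input : node.inputs) {
      const Port producer = ParseInput(input);
      const bool is_control = producer.second < 0;
      if (is_control && !include_control) continue;
      // Control edges get port -1 on the consumer side as well.
      const int consumer_port = is_control ? -1 : input_port++;
      view.fanins[node.name].insert(producer);
      view.fanouts[producer.first].insert({node.name, consumer_port});
    }
  }
  return view;
}
```

With include_control set to false, an input such as "^init" is simply skipped, so only data edges show up, just like the Method 1 output.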
Method 2: Use NodeMap
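#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/logging.h"

// Place inside namespace tensorflow::grappler, like Method 1.
static void FindAllNodesOfInputandOutputViaNodeMap(
    const NodeMap& node_map, const GraphDef* graph) {
  for (const auto& node : graph->node()) {
    VLOG(1) << "...[DEBUG2] Traverse via Map, node name:" << node.name();
    // Consumers of this node (control dependencies included).
    for (const NodeDef* output_node : node_map.GetOutputs(node.name())) {
      VLOG(1) << ".......[DEBUG2] get output name:" << output_node->name();
    }
    // node.input() holds "name:port" strings; control inputs start with '^'.
    for (const string& input_name : node.input()) {
      const NodeDef* input_node = node_map.GetNode(input_name);
      if (input_node == nullptr) continue;  // name not found in the map
      VLOG(1) << ".......[DEBUG2] get input name:" << input_node->name();
    }
  }
}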
Result:
...[DEBUG2] Traverse via Map, node name:conv1/Conv2D
.......[DEBUG2] get output name:conv1/BiasAdd
.......[DEBUG2] get input name:conv1/Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer
.......[DEBUG2] get input name:conv1/kernel/read
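Method 2 relies on two indexes: node name to node, and node name to the nodes that consume it. The sketch below rebuilds that idea without TensorFlow, using hypothetical names of my own (MiniNode, BareName, MiniNodeMap), so it is not Grappler's NodeMap. It assumes the same input-string convention as above and shows why an input like "conv1/kernel/read:0" can still be resolved to a node (the port suffix and any leading '^' are stripped first), and why a lookup can come back empty for a name that is not in the graph, so it is worth checking for nullptr before dereferencing.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for NodeDef.
struct MiniNode {
  std::string name;
  std::vector<std::string> inputs;
};

// Strips a leading '^' (control input) and a trailing ":port", so that
// "conv1/kernel/read:0" and "^init" map back to plain node names.
std::string BareName(const std::string& input) {
  std::string name = input;
  if (!name.empty() && name[0] == '^') name.erase(0, 1);
  const auto colon = name.rfind(':');
  if (colon != std::string::npos) name.erase(colon);
  return name;
}

// Hypothetical NodeMap-like index. It stores pointers into `graph`,
// so the vector must outlive the map.
class MiniNodeMap {
 public:
  explicit MiniNodeMap(const std::vector<MiniNode>& graph) {
    for (const MiniNode& node : graph) nodes_[node.name] = &node;
    for (const MiniNode& node : graph)
      for (const std::string& input : node.inputs)
        outputs_[BareName(input)].insert(&node);
  }

  // Returns nullptr when the name is not in the graph.
  const MiniNode* GetNode(const std::string& input) const {
    auto it = nodes_.find(BareName(input));
    return it == nodes_.end() ? nullptr : it->second;
  }

  // Nodes that consume any output of `name` (data or control).
  std::set<const MiniNode*> GetOutputs(const std::string& name) const {
    auto it = outputs_.find(BareName(name));
    if (it == outputs_.end()) return {};
    return it->second;
  }

 private:
  std::map<std::string, const MiniNode*> nodes_;
  std::map<std::string, std::set<const MiniNode*>> outputs_;
};
```

Here "^init" names a node that does not exist in the tiny graph, so GetNode("init") returns nullptr; a real graph can hit the same situation if it is inconsistent.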
Here you go!
P.S.: The input node conv1/Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer differs from the one in the graph picture because I turned on the layout optimizer, which modifies the graph before my code traverses it.
Reference:
For an explanation of what GraphDef and NodeDef are, see:
https://qiita.com/AtuNuka/items/8b6a9d632641fbc787d4