Wednesday, August 8, 2018

[ONNX] Train in Tensorflow and export to ONNX (Part I)

As I see it, ONNX is a model description spec: an ONNX model cannot run on its own and needs a deep learning framework, backend tool, or compiler that supports the format.
As far as I know, its main advantage is that models become portable and exchangeable between DL frameworks.
Here I follow this tutorial to convert a TensorFlow model to an ONNX model myself:

https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowExport.ipynb




First, use their example script to train a simple CNN model on the MNIST dataset:
https://github.com/onnx/tutorials/blob/7b549ae622ff8d74a5f5e0c32e109267f4c9ccae/tutorials/assets/tf-train-mnist.py
python tf-train-mnist.py
Second, freeze the model with TensorFlow's freeze_graph tool. The input graph here is in binary format, hence --input_binary=True.
You can compare this step with the post: Train in Tensorflow and export to ONNX (Part II)
bazel-bin/tensorflow/python/tools/freeze_graph \
    --input_graph=/home/liudanny/workspace/tf_to_onnx/graph.proto \
    --input_checkpoint=/home/liudanny/workspace/tf_to_onnx/ckpt/model.ckpt \
    --output_graph=/home/liudanny/workspace/tf_to_onnx/frozen_graph.pb \
    --output_node_names=fc2/add \
    --input_binary=True
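
If you would rather drive this step from a script, the same command line can be assembled in Python. This is only a minimal sketch: the helper names build_freeze_graph_cmd and run_freeze_graph are hypothetical, and the file layout (graph.proto, ckpt/model.ckpt, frozen_graph.pb inside one working directory) is assumed to match the command above.

```python
import subprocess


def build_freeze_graph_cmd(workdir, output_node_names,
                           binary="bazel-bin/tensorflow/python/tools/freeze_graph",
                           input_binary=True):
    """Assemble the freeze_graph argument list (hypothetical helper).

    Assumes the layout used in this post: graph.proto and ckpt/model.ckpt
    live under `workdir`, and the frozen graph is written next to them.
    """
    return [
        binary,
        "--input_graph={}/graph.proto".format(workdir),
        "--input_checkpoint={}/ckpt/model.ckpt".format(workdir),
        "--output_graph={}/frozen_graph.pb".format(workdir),
        # Multiple output nodes are passed as a comma-separated list.
        "--output_node_names={}".format(",".join(output_node_names)),
        "--input_binary={}".format(input_binary),
    ]


def run_freeze_graph(cmd):
    """Run the assembled command and raise if freeze_graph exits non-zero."""
    subprocess.run(cmd, check=True)
```

For the model in this post, that would be something like run_freeze_graph(build_freeze_graph_cmd("/home/liudanny/workspace/tf_to_onnx", ["fc2/add"])), run from the TensorFlow source tree where bazel-bin lives.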
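Third, convert the model to ONNX format:

import tensorflow as tf
# Converter used in the linked ONNX tutorial (onnx-tensorflow frontend).
from onnx_tf.frontend import tensorflow_graph_to_onnx_model

# Load the frozen GraphDef produced by freeze_graph.
with tf.gfile.GFile("frozen_graph.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# "fc2/add" is the output node, the same name passed to --output_node_names.
onnx_model = tensorflow_graph_to_onnx_model(graph_def, "fc2/add")

# Serialize the ONNX model to disk.
with open("mnist.onnx", "wb") as out:
    out.write(onnx_model.SerializeToString())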
This gives us the file "mnist.onnx", converted from the TensorFlow model.
Next, I use Netron to visualize the model as a graph in the web browser. It looks like this:






