[XLA Research] How to use XLA AOT compilation in TensorFlow

This document explains how to use AOT compilation in TensorFlow. We will use tfcompile, a standalone tool that ahead-of-time (AOT) compiles TensorFlow graphs into executable code. It can reduce the total binary size and avoid some runtime overhead. A typical use case of tfcompile is compiling an inference graph into executable code for mobile devices. The steps are as follows:

1. Build the tfcompile tool:
> bazel build --config=opt --config=cuda //tensorflow/compiler/aot:tfcompile

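2. Run make_simple_test_graph.py with --out_dir to generate the graph file. Note that the script writes only the graph (test_graph_tfmatmul.pb); the config file used in step 3 (test_graph_tfmatmul.config.pbtxt) has to be created separately. The script is as follows: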
import argparse
import os
import sys

# Only the modules this script actually uses are imported.
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import app

FLAGS = None

def tfmatmul(_):
  x = array_ops.placeholder(dtypes.float32, name='x_hold')
  y = array_ops.placeholder(dtypes.float32, name='y_hold')
  math_ops.matmul(x, y, name='x_y_prod')

def tfmatmulandadd(_):
  # This tests multiple outputs.
  x = array_ops.placeholder(dtypes.float32, name='x_hold')
  y = array_ops.placeholder(dtypes.float32, name='y_hold')
  math_ops.matmul(x, y, name='x_y_prod')
  math_ops.add(x, y, name='x_y_sum')

def write_graph(build_graph, out_dir):
  """Build a graph using build_graph and write it out."""
  g = ops.Graph()
  with g.as_default():
    build_graph(out_dir)
    filename = os.path.join(out_dir, 'test_graph_%s.pb' % build_graph.__name__)
    with open(filename, 'wb') as f:
      f.write(g.as_graph_def().SerializeToString())

def main(_):
  write_graph(tfmatmul, FLAGS.out_dir)

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.register('type', 'bool', lambda v: v.lower() == 'true')
  parser.add_argument(
      '--out_dir',
      type=str,
      default='',
      help='Output directory for graphs, checkpoints and savers.')
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
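
The script above does not write the config file that step 3 expects. Here is a minimal sketch of one way to produce it, in plain Python. The feed/fetch layout follows the tf2xla Config text format; the shapes (x is 2x3, y is 3x2) are my assumption, chosen to match the 6 + 6 input values used in step 4, and the helper names (render_feed, render_fetch, write_config) are hypothetical.

```python
import os
import tempfile


def render_feed(node_name, dims):
    """Render one feed entry: a placeholder name plus its fixed shape."""
    dim_lines = "".join("    dim { size: %d }\n" % d for d in dims)
    return ('feed {\n  id { node_name: "%s" }\n  shape {\n%s  }\n}\n'
            % (node_name, dim_lines))


def render_fetch(node_name):
    """Render one fetch entry: the name of an output node."""
    return 'fetch {\n  id { node_name: "%s" }\n}\n' % node_name


def render_config(feeds, fetches):
    """Concatenate all feed entries followed by all fetch entries."""
    parts = [render_feed(name, dims) for name, dims in feeds]
    parts += [render_fetch(name) for name in fetches]
    return "".join(parts)


# Node names come from tfmatmul() above. The shapes are an assumption:
# x is 2x3 and y is 3x2, matching the inputs in my_code.cc.
CONFIG_TEXT = render_config(
    feeds=[("x_hold", [2, 3]), ("y_hold", [3, 2])],
    fetches=["x_y_prod"],
)


def write_config(out_dir):
    """Write the config next to the graph file and return its path."""
    path = os.path.join(out_dir, "test_graph_tfmatmul.config.pbtxt")
    with open(path, "w") as f:
        f.write(CONFIG_TEXT)
    return path


# Demo: write into a scratch directory. In practice, pass the directory
# that holds the BUILD file and test_graph_tfmatmul.pb.
print(write_config(tempfile.mkdtemp()))
print(CONFIG_TEXT)
```

Every feed needs a fully specified shape because AOT compilation fixes the buffer sizes at build time.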

3. Add a tf_library rule to BUILD and build it to generate the C++ header file:
> vi BUILD

load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
tf_library(
    name = "test_graph_tfmatmul",
    graph = "test_graph_tfmatmul.pb",
    config = "test_graph_tfmatmul.config.pbtxt",
    cpp_class = "foo::bar::MatMulComp"
)
> bazel build :test_graph_tfmatmul

4. Write test C++ code:
> vi my_code.cc

#include <algorithm>  // std::copy
#include <iostream>   // std::cout

#include "tensorflow/compiler/aot/tests/myaot/test_graph_tfmatmul.h"

int main(int argc, char** argv) {
  foo::bar::MatMulComp matmul;

  // Set up args and run the computation.
  const float args[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  std::copy(args + 0, args + 6, matmul.arg0_data());
  std::copy(args + 6, args + 12, matmul.arg1_data());
  matmul.Run();

  // Check result
  if (matmul.result0(0, 0) == 58) {
    std::cout << "Success" << std::endl;
  } else {
    std::cout << "Failed. Expected value 58 at 0,0. Got:"
              << matmul.result0(0, 0) << std::endl;
  }

  return 0;
}
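
To see why the test expects 58 at position (0, 0), the same multiplication can be checked in plain Python. This is only a sketch of the arithmetic: it assumes the inputs are laid out row-major, with x as a 2x3 matrix holding the first six values and y as a 3x2 matrix holding the last six.

```python
def matmul(a, b):
    """Plain-Python matrix product of two matrices given as lists of rows."""
    inner = len(b)
    return [[sum(a[i][k] * b[k][j] for k in range(inner))
             for j in range(len(b[0]))]
            for i in range(len(a))]


args = list(range(1, 13))  # same values as args[12] in my_code.cc

# Assumption: row-major layout, x is 2x3 and y is 3x2.
x = [args[0:3], args[3:6]]                # [[1, 2, 3], [4, 5, 6]]
y = [args[6:8], args[8:10], args[10:12]]  # [[7, 8], [9, 10], [11, 12]]

result = matmul(x, y)
print(result)        # [[58, 64], [139, 154]]
print(result[0][0])  # 1*7 + 2*9 + 3*11 = 58
```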

5. Add a cc_binary rule to BUILD and build my_code.cc:
> vi BUILD

cc_binary(
    name = "my_binary",
    srcs = [ "my_code.cc" ],
    deps = [ ":test_graph_tfmatmul", "//third_party/eigen3" ],
    linkopts = [ "-lpthread" ]
)
> bazel build :my_binary

6. Run the resulting binary:
> bazel run my_binary

I got the result "Success", which shows that the MatMul computation produced by AOT compilation is correct!




