Sunday, 30 June 2019

Unable to go from tf.keras model -> quantized frozen graph -> .tflite with TOCO

I am new to all of these tools. I'm trying to get started with TensorFlow Lite, with the goal of running my own deep learning models on the Coral Edge TPU.

I have built a toy XOR network with the Keras API, written out the TensorFlow graph, and frozen it. Now I'm trying to use TOCO to convert the frozen model to TFLite format, but I get the following error:

ValueError: Input 0 of node dense_1/weights_quant/AssignMinLast was passed float from dense_1/weights_quant/min:0 incompatible with expected float_ref.

I have seen others report similar errors on GitHub, but I have not been able to find a solution.

Full code below:

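import numpy as np
import tensorflow as tf
import keras
from keras.models import Sequential
from keras.layers import Dense

# XOR truth table: inputs and expected outputs
training_data = np.array([[0,0],[0,1],[1,0],[1,1]], "uint8")
target_data = np.array([[0],[1],[1],[0]], "uint8")

model = Sequential()
model.add(Dense(16, input_dim=2, use_bias=False, activation='relu'))
model.add(Dense(1, use_bias=False, activation='sigmoid'))

# Use the same Keras session throughout, so training, saving and
# graph export all operate on the same variables
session = keras.backend.get_session()
# Rewrite the graph in place to insert fake-quantization ops
tf.contrib.quantize.create_training_graph(session.graph)
session.run(tf.global_variables_initializer())

model.compile(loss='mean_squared_error',
              optimizer='adam',
              metrics=['binary_accuracy'])

# 'epochs' replaces the deprecated 'nb_epoch' argument
model.fit(training_data, target_data, epochs=1000, verbose=2)
print(model.predict(training_data).round())
model.summary()

saver = tf.train.Saver()
saver.save(session, 'xor-keras.ckpt')

# Writes a text-format GraphDef (as_text defaults to True)
tf.io.write_graph(session.graph, '.', 'xor-keras.pb')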
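For context on the error: tf.contrib.quantize.create_training_graph rewrites the graph so that weights and activations pass through fake-quantization ops, which snap each float to one of a small set of 8-bit levels inside a tracked [min, max] range and convert it back to float. The dense_1/weights_quant/min variable named in the error holds one end of such a range. The following is a simplified pure-Python sketch of that quantize-then-dequantize step, loosely modeled on the range "nudging" TensorFlow performs so that 0.0 maps exactly to an integer code. The function names are hypothetical and this is not TensorFlow's actual kernel.

```python
import math


def nudged_range(rmin, rmax, num_bits=8):
    """Return (scale, zero_point, nudged_min, nudged_max) for [rmin, rmax]."""
    qmin, qmax = 0, (1 << num_bits) - 1
    if rmax <= rmin:
        raise ValueError("rmax must be greater than rmin")
    scale = (rmax - rmin) / (qmax - qmin)
    # Integer code that the real value 0.0 should map to, clamped to [qmin, qmax].
    zero_point_from_min = qmin - rmin / scale
    zero_point = int(min(qmax, max(qmin, math.floor(zero_point_from_min + 0.5))))
    # Shift the range so that 0.0 lands exactly on an integer code.
    nudged_min = (qmin - zero_point) * scale
    nudged_max = (qmax - zero_point) * scale
    return scale, zero_point, nudged_min, nudged_max


def fake_quantize(values, rmin, rmax, num_bits=8):
    """Quantize each float to an integer code, then map the code back to a float."""
    scale, zero_point, lo, hi = nudged_range(rmin, rmax, num_bits)
    result = []
    for v in values:
        clamped = min(hi, max(lo, v))  # values outside the range saturate
        code = zero_point + math.floor(clamped / scale + 0.5)
        result.append((code - zero_point) * scale)
    return result


weights = [-0.8, 0.05, 0.3]
print(fake_quantize(weights, -1.0, 1.0))
```

During training these ops let the network adapt to the rounding error; the min/max values they learn are what the converter later uses to pick real uint8 scales.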

Then I freeze the model:

python freeze_graph.py \
  --input_graph='xor-keras.pb' \
  --input_checkpoint='xor-keras.ckpt' \
  --output_graph='xor-keras-frozen.pb' \
  --output_node_names='dense_2/Sigmoid'
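The node in the error, dense_1/weights_quant/AssignMinLast, is an Assign op added by the training rewrite. Assign needs a reference-typed (float_ref) first input, which a variable provides; freeze_graph replaces variables with constants (plain float), so an Assign op that is still needed to compute dense_2/Sigmoid after freezing would fail to import in exactly this way. As a quick check, the sketch below lists Assign-type nodes in the text-format graph that tf.io.write_graph writes by default. It is a hypothetical helper that uses a regex instead of the protobuf parser, and it assumes each node prints its name field immediately followed by its op field.

```python
import re

# In a text-format GraphDef each node prints `name` and then `op`,
# so one regex can pair them up.
NODE_RE = re.compile(r'name:\s*"([^"]+)"\s*op:\s*"([^"]+)"')


def find_assign_nodes(graph_text, name_filter="_quant/"):
    """Return (name, op) pairs for Assign-type nodes whose name contains name_filter."""
    hits = []
    for name, op in NODE_RE.findall(graph_text):
        if op.startswith("Assign") and name_filter in name:
            hits.append((name, op))
    return hits


def scan_graph_file(path, name_filter="_quant/"):
    """Read a text-format GraphDef file and report its Assign-type nodes."""
    with open(path) as f:
        return find_assign_nodes(f.read(), name_filter)
```

For example, `for name, op in scan_graph_file("xor-keras.pb"): print(op, name)` would print one line per quantization Assign node still present in the exported training graph.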

Then I call toco like this:

toco \
  --graph_def_file=xor-keras-frozen.pb \
  --output_file=xor-keras-frozen.tflite \
  --input_shapes=1,2 \
  --input_arrays='dense_1_input' \
  --output_arrays='dense_2/Sigmoid' \
  --inference_type=QUANTIZED_UINT8
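With --inference_type=QUANTIZED_UINT8 the converter generally also expects to be told how the uint8 input maps to real values. The toco entry point in the traceback below dispatches to tflite_convert.py, which takes this as --mean_values and --std_dev_values, with real_value = (quantized_value - mean_value) / std_dev_value; the command above passes neither. The sketch below computes that pair for an assumed real input range (0 to 1 for the XOR inputs). The helper names are hypothetical.

```python
def input_quant_params(real_min, real_max, qmin=0, qmax=255):
    """Return (mean_value, std_dev_value) so that codes [qmin, qmax] cover
    [real_min, real_max] under real = (q - mean) / std_dev."""
    if real_max <= real_min:
        raise ValueError("real_max must be greater than real_min")
    std_dev = (qmax - qmin) / (real_max - real_min)
    mean = qmin - real_min * std_dev
    return mean, std_dev


def dequantize_input(q, mean, std_dev):
    """Map a uint8 input code back to the real value the model sees."""
    return (q - mean) / std_dev


def converter_flags(real_min, real_max):
    """Format the pair as tflite_convert-style command-line flags."""
    mean, std_dev = input_quant_params(real_min, real_max)
    return ["--mean_values=%g" % mean, "--std_dev_values=%g" % std_dev]


print(converter_flags(0.0, 1.0))
```

The choice of real range is an assumption about how inputs will be fed at inference time; it has to match whatever scaling the application applies before handing uint8 data to the model.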

Here is the full output from TOCO:

2019-06-26 15:31:17.374904: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA
2019-06-26 15:31:17.404237: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2600000000 Hz
2019-06-26 15:31:17.407613: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55bbcf9a5ed0 executing computations on platform Host. Devices:
2019-06-26 15:31:17.407741: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): <undefined>, <undefined>
Traceback (most recent call last):
  File "/home/redacted/.local/bin/toco", line 11, in <module>
    sys.exit(main())
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/lite/python/tflite_convert.py", line 503, in main
    app.run(main=run_main, argv=sys.argv[:1])
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/home/redacted/.local/lib/python2.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/redacted/.local/lib/python2.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/lite/python/tflite_convert.py", line 499, in run_main
    _convert_tf1_model(tflite_flags)
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/lite/python/tflite_convert.py", line 124, in _convert_tf1_model
    converter = _get_toco_converter(flags)
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/lite/python/tflite_convert.py", line 111, in _get_toco_converter
    return converter_fn(**converter_kwargs)
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/lite/python/lite.py", line 628, in from_frozen_graph
    _import_graph_def(graph_def, name="")
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 431, in import_graph_def
    raise ValueError(str(e))
ValueError: Input 0 of node dense_1/weights_quant/AssignMinLast was passed float from dense_1/weights_quant/min:0 incompatible with expected float_ref.