XLA can be enabled using model = tf.function(model, jit_compile=True). Some model types are faster that way, some are slower. So far, so good.
But why can model = tf.function(model, jit_compile=None) speed things up significantly (without TPU) in some cases?
The jit_compile docs state:
If None (default), compiles the function with XLA when running on TPU and goes through the regular function execution path when running on other devices.
I'm running my tests on two non-TPU (and even non-GPU) machines, both with the latest TensorFlow (2.13.0) installed.
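(As a sanity check, something along these lines should confirm that no accelerator is visible to TensorFlow; on a CPU-only machine both lists are expected to be empty.)

import tensorflow as tf
# Environment check: neither a GPU nor a TPU should be visible on the test machines.
print(tf.config.list_physical_devices("GPU"))
print(tf.config.list_physical_devices("TPU"))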
import timeit
import numpy as np
import tensorflow as tf
model_plain = tf.keras.applications.efficientnet_v2.EfficientNetV2S()
model_jit_compile_true = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=True)
model_jit_compile_false = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=False)
model_jit_compile_none = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)
def run(model):
    model(np.random.random(size=(1, 384, 384, 3)))
# warmup
run(model_plain)
run(model_jit_compile_true)
run(model_jit_compile_false)
run(model_jit_compile_none)
runs = 10
duration_plain = timeit.timeit(lambda: run(model_plain), number=runs) / runs
duration_jit_compile_true = timeit.timeit(lambda: run(model_jit_compile_true), number=runs) / runs
duration_jit_compile_false = timeit.timeit(lambda: run(model_jit_compile_false), number=runs) / runs
duration_jit_compile_none = timeit.timeit(lambda: run(model_jit_compile_none), number=runs) / runs
print(f"{duration_plain=}")
print(f"{duration_jit_compile_true=}")
print(f"{duration_jit_compile_false=}")
print(f"{duration_jit_compile_none=}")
Output:
duration_plain=0.53095479644835
duration_jit_compile_true=1.5860380740836262
duration_jit_compile_false=0.09831228516995907
duration_jit_compile_none=0.09407951850444078
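Since jit_compile=None is the default, a bare tf.function wrapper should behave exactly like model_jit_compile_none. As an extra sanity check, reusing run, runs and the imports from the snippet above, something like this should show the same behaviour:

# Same benchmark, but without passing jit_compile at all (defaults to None).
model_graph_default = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S())
run(model_graph_default)  # warmup / first trace
duration_graph_default = timeit.timeit(lambda: run(model_graph_default), number=runs) / runs
print(f"{duration_graph_default=}")

I'd expect this to land in the same ~0.09 s range as the jit_compile=False/None variants. (To see whether XLA compiles anything at all in the jit_compile=None case, one could also set the standard XLA dump flag, e.g. XLA_FLAGS=--xla_dump_to=/tmp/xla_dump, before importing TensorFlow and check whether the dump directory stays empty.)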
From: Why does tensorflow.function (without jit_compile) speed up forward passes of a Keras model?