Friday 11 August 2023

Why does tensorflow.function (without jit_compile) speed up forward passes of a Keras model?

XLA can be enabled using model = tf.function(model, jit_compile=True). Some model types are faster that way, some are slower. So far, so good.

But why can model = tf.function(model, jit_compile=None) speed things up significantly (without TPU) in some cases?

The jit_compile docs state:

If None (default), compiles the function with XLA when running on TPU and goes through the regular function execution path when running on other devices.

I'm running my tests on two machines that have neither a TPU nor a GPU, both with the latest TensorFlow (2.13.0) installed.

import timeit

import numpy as np
import tensorflow as tf

model_plain = tf.keras.applications.efficientnet_v2.EfficientNetV2S()
model_jit_compile_true = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=True)
model_jit_compile_false = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=False)
model_jit_compile_none = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)


def run(model):
    model(np.random.random(size=(1, 384, 384, 3)))


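# Warm-up: the first call to each tf.function triggers tracing (and XLA
# compilation where enabled), so it is run once outside the timed region.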
run(model_plain)
run(model_jit_compile_true)
run(model_jit_compile_false)
run(model_jit_compile_none)

runs = 10
duration_plain = timeit.timeit(lambda: run(model_plain), number=runs) / runs
duration_jit_compile_true = timeit.timeit(lambda: run(model_jit_compile_true), number=runs) / runs
duration_jit_compile_false = timeit.timeit(lambda: run(model_jit_compile_false), number=runs) / runs
duration_jit_compile_none = timeit.timeit(lambda: run(model_jit_compile_none), number=runs) / runs

print(f"{duration_plain=}")
print(f"{duration_jit_compile_true=}")
print(f"{duration_jit_compile_false=}")
print(f"{duration_jit_compile_none=}")

Output (average seconds per call):

duration_plain=0.53095479644835
duration_jit_compile_true=1.5860380740836262
duration_jit_compile_false=0.09831228516995907
duration_jit_compile_none=0.09407951850444078
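
The script above reports a single average over 10 calls and generates a fresh random input inside every timed call. As a cross-check, here is a minimal sketch of a slightly more robust timing helper. The name `benchmark` and its parameters are hypothetical, not part of the original script. It runs untimed warm-up calls first and then reports the median over several repeats, which is less sensitive to one-off stalls than a single mean.

```python
import statistics
import timeit


def benchmark(fn, *, warmup=1, repeats=5, number=10):
    """Return the median per-call duration of fn() in seconds.

    `benchmark` is an illustrative helper, not a TensorFlow or timeit API.
    """
    for _ in range(warmup):
        fn()  # untimed: absorbs one-off costs such as tracing or compilation
    # timeit.repeat returns one total duration per repeat,
    # each covering `number` consecutive calls of fn.
    totals = timeit.repeat(fn, repeat=repeats, number=number)
    return statistics.median(totals) / number
```

With the models from the script, this could be used as `benchmark(lambda: model_plain(x))`, where `x` is an input array created once beforehand so that only the forward pass is timed.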


