Friday, 17 May 2019

How to use tf.keras with bfloat16

I'm trying to run a tf.keras model on a TPU with mixed precision, and I'm not sure how to build the Keras model so that it uses bfloat16. Is it something like this?

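import tensorflow as tf  # TensorFlow 1.x only: tf.contrib was removed in TensorFlow 2.0

# Build the input and the Dense layer inside the bfloat16 scope.
with tf.contrib.tpu.bfloat16_scope():
    inputs = tf.keras.layers.Input(shape=(2,), dtype=tf.bfloat16)
    logits = tf.keras.layers.Dense(2)(inputs)

# Cast the output back to float32 so the loss is computed in float32. Wrapping
# tf.cast in a Lambda layer keeps the output a Keras layer output, as Model expects.
logits = tf.keras.layers.Lambda(lambda t: tf.cast(t, tf.float32))(logits)
model = tf.keras.models.Model(inputs=inputs, outputs=logits)
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='mean_absolute_error', metrics=[])

# Convert the compiled Keras model to run on the TPU named 'my_tpu_name'.
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(tpu='my_tpu_name')))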
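The snippet casts the logits back to float32 before the loss. To see why such a cast is commonly recommended, here is a small pure-Python sketch (no TensorFlow needed; the helper names `float32_bits`, `bits_to_float` and `to_bfloat16` are hypothetical, not part of any TensorFlow API) that emulates rounding a float32 value to bfloat16. bfloat16 keeps the sign bit and all 8 exponent bits of float32 but only 7 of its 23 mantissa bits. The sketch assumes round-to-nearest-even, which may not match every TensorFlow or TPU conversion exactly, and it does not handle NaN.

```python
import struct


def float32_bits(x: float) -> int:
    """Return the 32-bit IEEE 754 single-precision bit pattern of x."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_float(bits: int) -> float:
    """Interpret a 32-bit pattern as a float32 and return it as a Python float."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_bfloat16(x: float) -> float:
    """Round a finite, float32-range value to the nearest bfloat16 value.

    bfloat16 is the upper 16 bits of a float32 bit pattern: 1 sign bit,
    the same 8 exponent bits, and the top 7 mantissa bits.
    Assumption: ties round to even. NaN inputs are not handled.
    """
    bits = float32_bits(x)
    lsb = (bits >> 16) & 1   # lowest bit that survives dropping the low 16 bits
    bits += 0x7FFF + lsb     # round-to-nearest-even on the 16 dropped bits
    return bits_to_float(bits & 0xFFFF0000)


if __name__ == "__main__":
    print(to_bfloat16(3.14159265))   # 3.140625
    print(to_bfloat16(256.0 + 1.0))  # 256.0: bfloat16 values in [256, 512) are 2 apart
```

Because `to_bfloat16(256.0 + 1.0)` returns 256.0, a running sum kept in bfloat16 (for example a loss accumulated over many examples) can stop growing. Keeping the loss in float32 avoids that, while the large 8-bit exponent still lets bfloat16 activations cover the same range as float32.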


