Tuesday, 7 September 2021

AssertionError: Tried to export a function which references untracked resource

I wrote a unit test to save a model after noticing that I am no longer able to do so during training.

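import shutil
import tempfile
from typing import Tuple

import numpy as np
import pytest

# TransducerBase and SpeechFeaturesConfig are defined in the project under test.


@pytest.mark.usefixtures("maybe_run_functions_eagerly")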
def test_save_model(speech_model: Tuple[TransducerBase, SpeechFeaturesConfig]):
    model, speech_features_config = speech_model
    speech_features_config: SpeechFeaturesConfig
    channels = 3 if speech_features_config.add_delta_deltas else 1
    num_mel_bins = speech_features_config.num_mel_bins
    enc_inputs = np.random.rand(1, 50, num_mel_bins, channels)
    dec_inputs = np.expand_dims(np.random.randint(0, 25, size=10), axis=1)
    inputs = enc_inputs, dec_inputs
    model(inputs)

    # Throws KeyError:
    # graph = tf.compat.v1.get_default_graph()
    # tensor = graph.get_tensor_by_name("77040:0")

    directory = tempfile.mkdtemp(prefix=f"{model.__class__.__name__}_")
    try:
        model.save(directory)
    finally:
        shutil.rmtree(directory)

Trying to save the model will always throw the following error:

E         AssertionError: Tried to export a function which references untracked resource Tensor("77040:0", shape=(), dtype=resource). TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.
E         
E         Trackable Python objects referring to this tensor (from gc.get_referrers, limited to two hops):
E         <tf.Variable 'transformer_transducer/transducer_encoder/inputs_embedding/convolution_stack/conv2d/kernel:0' shape=(3, 3, 3, 32) dtype=float32>

Note: as the commented-out lines in the test above show, I am also unable to retrieve this tensor with tf.compat.v1.get_default_graph().get_tensor_by_name("77040:0"); the call raises a KeyError.
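For reference, the rule the error message states can be reproduced in isolation. The following is a minimal sketch assuming TensorFlow 2.x; Tracked, Untracked and try_save are hypothetical names, not part of my project. A tf.function that captures a variable held only in a module-level name fails to export, while the same kind of variable assigned to an attribute of the saved object exports fine:

```python
import tempfile

import tensorflow as tf

# A variable that is NOT assigned to any attribute of the exported object.
untracked_scale = tf.Variable(2.0)


class Untracked(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def __call__(self, x):
        # Captures the module-level variable; the saver cannot find it in
        # the object graph rooted at this module.
        return x * untracked_scale


class Tracked(tf.Module):
    def __init__(self):
        super().__init__()
        # Assigning the variable to an attribute makes it tracked.
        self.scale = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.scale


def try_save(module: tf.Module) -> str:
    """Return "ok" if the module exports, otherwise the exception message."""
    with tempfile.TemporaryDirectory() as directory:
        try:
            tf.saved_model.save(module, directory)
            return "ok"
        except Exception as error:  # exact type and wording vary by TF version
            return str(error)


print(try_save(Tracked()))
print(try_save(Untracked())[:80])
```

The exact exception type and message differ between TensorFlow versions, which is why try_save catches Exception and returns the message instead of asserting on a specific type.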

What I do not understand is why I am getting this error at all, because the affected layer is tracked by Keras, as the screenshot below shows. I took it during a debug session inside the call() function.

[Screenshot: debug session inside call() showing that the affected layer is tracked by Keras]

I have no explanation for this and am running out of ideas about what the issue might be.

The transformations list in the screenshot is an attribute of the InputsEmbedding layer, which constructs it like so:

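from tensorflow.keras import layers

# TimeReduction and _get_layer are helpers from the project under test.


class InputsEmbedding(layers.Layer, TimeReduction):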
    def __init__(self, config: InputsEmbeddingConfig, **kwargs):
        super().__init__(**kwargs)

        if config.transformations is None or not len(config.transformations):
            raise RuntimeError("No transformations provided.")

        self.config = config

        self.transformations = list()
        for transformation in self.config.transformations:
            layer_name, layer_params = list(transformation.items())[0]
            layer = _get_layer(layer_name, layer_params)
            self.transformations.append(layer)

        self.init_time_reduction_layer()

    def get_config(self):
        return self.config.dict()
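Keras tracks layers stored in a plain Python list attribute of a Layer, which matches what the screenshot shows. As a minimal sketch (assuming TensorFlow 2.x with tf.keras; ListHolder is a hypothetical stand-in for InputsEmbedding, and Dense layers replace whatever _get_layer returns), the tracking can also be checked programmatically instead of in the debugger:

```python
import tensorflow as tf
from tensorflow.keras import layers


class ListHolder(layers.Layer):
    """Hypothetical stand-in for InputsEmbedding: sub-layers kept in a list."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Keras wraps a list assigned to a Layer attribute in a tracking
        # container, so layers appended later are tracked as well.
        self.transformations = []
        for units in (4, 2):
            self.transformations.append(layers.Dense(units))

    def call(self, x):
        for layer in self.transformations:
            x = layer(x)
        return x


holder = ListHolder()
holder(tf.zeros((1, 3)))  # builds the Dense layers and creates their variables

# Identity check: is every sub-layer kernel one of the holder's weights?
tracked_ids = {id(w) for w in holder.weights}
all_tracked = all(id(layer.kernel) in tracked_ids for layer in holder.transformations)

print(len(holder.trainable_weights))  # 4: a kernel and a bias per Dense layer
print(all_tracked)  # True
```

The same identity check could be run against the variable named in the error message and the full model's model.weights.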

To verify that InputsEmbedding is not the problem, I wrote a unit test that saves a model using only this layer.

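import shutil
import tempfile

import pytest
from tensorflow import keras
from tensorflow.keras import layers

# InputsEmbedding and InputsEmbeddingConfig are defined in the project under test.


@pytest.mark.usefixtures("maybe_run_functions_eagerly")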
def test_inputs_embedding_save_model():
    convolutions = [
        "filters=2, kernel_size=(3, 3), strides=(2, 1)",
        "filters=4, kernel_size=(3, 3), strides=(2, 1)",
        "filters=8, kernel_size=(3, 4), strides=(1, 1)",
    ]

    config = InputsEmbeddingConfig()
    config.transformations = [dict(conv2d_stack=dict(convolutions=convolutions)), dict(stack_frames=dict(n=2))]

    num_features = 8
    num_channels = 3

    inputs = layers.Input(shape=(None, num_features, num_channels))
    x = inputs
    x, _ = InputsEmbedding(config)(x)
    model = keras.Model(inputs=inputs, outputs=x)
    model.build(input_shape=(1, 20, num_features, num_channels))

    directory = tempfile.mkdtemp(prefix=f"{model.__class__.__name__}_")
    try:
        model.save(directory)
    finally:
        shutil.rmtree(directory)

In this case, I can save the model without any issues:

[Screenshot: the model is saved without errors]
