Thursday, 7 September 2023

GridSearchCV: You must compile your model before training/testing. But my model is already compiled

I am trying to use GridSearchCV to tune a hyper-parameter (epochs) of my model. Here is a minimal, reproducible example:

# Load MNIST as (train, test) pairs of images and labels.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Assemble the VAE from its encoder and decoder sub-models and compile it
# with an Adam optimizer (no loss argument is passed to compile()).
latent_dimension = 25
encoder = Encoder(latent_dimension, (28, 28, 1))
decoder = Decoder()
vae = VAE(encoder, decoder)
vae.compile(Adam())  # Compiled here

# Grid-search the `epochs` parameter with 2-fold cross-validation.
# NOTE(review): GridSearchCV presumably fits clones of `vae` rather than this
# exact instance, so the compile() call above may not carry over to the
# objects that are actually fitted - confirm against sklearn.base.clone.
# NOTE(review): `scoring` receives a plain metric function here; scikit-learn
# scorers are normally called as scorer(estimator, X, y) - verify.
param_grid = {'epochs': [10, 20, 30]}
grid = GridSearchCV(vae, param_grid, scoring=mean_absolute_error, cv=2)
grid.fit(x_train, y_train)

Unfortunately, grid.fit(x_train, y_train) raises:

RuntimeError: You must compile your model before training/testing. Use model.compile(optimizer, loss).

But I already compiled my model. How can I fix the problem?

This is the implementation of VAE, Encoder and Decoder:

import tensorflow as tf
from keras import layers
from keras.optimizers.legacy import Adam
from sklearn.base import BaseEstimator
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import GridSearchCV
from tensorflow import keras

def sample(z_mean, z_log_var):
    """Draw a latent sample ``z_mean + exp(0.5 * z_log_var) * eps``.

    ``eps`` is standard-normal noise whose shape is taken from the first two
    dimensions of ``z_mean``.
    """
    mean_shape = tf.shape(z_mean)
    noise = tf.random.normal(shape=(mean_shape[0], mean_shape[1]))
    sigma = tf.exp(0.5 * z_log_var)
    return z_mean + sigma * noise

class VAE(keras.Model, BaseEstimator):
    """Variational autoencoder that chains an encoder and a decoder.

    The encoder must return ``(z_mean, z_log_var, z)`` for a batch and the
    decoder must map ``z`` to a reconstruction. The optimised loss is the sum
    of a binary cross-entropy reconstruction term and a KL term computed from
    ``z_mean`` and ``z_log_var``.

    Args:
        encoder: Model returning ``(z_mean, z_log_var, z)``.
        decoder: Model mapping ``z`` to a reconstruction of the input.
        epochs: Stored as an attribute only; nothing in this class reads it,
            so changing it does not by itself change how long ``fit`` trains.
        **kwargs: Forwarded to ``keras.Model.__init__``.

    NOTE(review): scikit-learn's ``clone`` presumably rebuilds the estimator
    from its constructor parameters, so a compiled instance handed to
    ``GridSearchCV`` is likely replaced by an uncompiled copy - confirm
    against the ``sklearn.base.clone`` documentation.
    """

    def __init__(self, encoder, decoder, epochs=None, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.epochs = epochs
        # Running means reported by train_step / test_step.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def call(self, inputs, training=None, mask=None):
        """Encode ``inputs`` and decode the sampled latent vector."""
        # Forward the training flag so both sub-models run in the same mode
        # as this model.
        _, _, z = self.encoder(inputs, training=training)
        return self.decoder(z, training=training)

    @property
    def metrics(self):
        """Trackers exposed as this model's metrics."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @staticmethod
    def _inputs_only(data):
        """Return the input tensor when ``data`` arrives as an ``(x, y, ...)`` tuple."""
        # fit(x, y) hands each step a tuple; the VAE only needs the inputs.
        return data[0] if isinstance(data, tuple) else data

    def _compute_losses(self, data, training):
        """Run encoder and decoder and return ``(total, reconstruction, kl)`` losses."""
        z_mean, z_log_var, z = self.encoder(data, training=training)
        reconstruction = self.decoder(z, training=training)

        # Per-sample cross-entropy summed over axes (1, 2), then batch-averaged.
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            )
        )
        # KL term summed over the latent axis, then batch-averaged.
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

    def _track(self, total_loss, reconstruction_loss, kl_loss):
        """Update the three trackers and return their current results."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step(self, data):
        """Run one optimisation step and return the tracked losses."""
        data = self._inputs_only(data)
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._compute_losses(
                data, training=True
            )

        # Compute gradients and update weights.
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return self._track(total_loss, reconstruction_loss, kl_loss)

    def test_step(self, data):
        """Evaluate the losses on one batch without updating weights."""
        data = self._inputs_only(data)
        total_loss, reconstruction_loss, kl_loss = self._compute_losses(
            data, training=False
        )
        return self._track(total_loss, reconstruction_loss, kl_loss)


class Encoder(keras.Model):
    """Convolutional encoder producing ``(z_mean, z_log_var, z)``.

    Three stride-2 Conv2D + BatchNormalization blocks (64, 128, 256 filters)
    feed a flatten layer, a 100-unit dense layer, and two dense heads of size
    ``latent_dimension``; ``z`` is drawn from those heads via ``sample``.
    """

    def __init__(self, latent_dimension, input_shape):
        super().__init__()
        self.latent_dim = latent_dimension

        # Only the first block declares the input shape.
        self.conv_block1 = keras.Sequential(
            [layers.Input(shape=input_shape), *self._conv_layers(64)]
        )
        self.conv_block2 = keras.Sequential(self._conv_layers(128))
        self.conv_block3 = keras.Sequential(self._conv_layers(256))

        self.flatten = layers.Flatten()
        self.dense = layers.Dense(units=100, activation="relu")
        self.z_mean = layers.Dense(latent_dimension, name="z_mean")
        self.z_log_var = layers.Dense(latent_dimension, name="z_log_var")
        self.sampling = sample

    @staticmethod
    def _conv_layers(filters):
        """Return a stride-2 3x3 ReLU convolution followed by batch norm."""
        return [
            layers.Conv2D(filters=filters, kernel_size=3, activation="relu", strides=2, padding="same"),
            layers.BatchNormalization(),
        ]

    def call(self, inputs, training=None, mask=None):
        """Return the latent mean, log-variance and a sampled latent vector."""
        features = inputs
        for stage in (self.conv_block1, self.conv_block2, self.conv_block3):
            features = stage(features)
        features = self.dense(self.flatten(features))

        mean = self.z_mean(features)
        log_var = self.z_log_var(features)
        return mean, log_var, self.sampling(mean, log_var)


class Decoder(keras.Model):
    """Decoder mapping a latent vector to a single-channel sigmoid output.

    Three Dense + BatchNormalization blocks (4096, 1024, 4096 units) are
    reshaped to (4, 4, 256) and passed through five Conv2DTranspose +
    BatchNormalization blocks and a final 1-filter sigmoid Conv2DTranspose.

    NOTE(review): by the usual transposed-convolution size arithmetic the
    spatial size appears to go 4 -> 8 -> 8 -> 17 -> 19 -> 39 -> 40, i.e. a
    40x40x1 output - confirm this matches the intended input size.
    """

    def __init__(self):
        super().__init__()
        self.dense1 = self._dense_block(4096)
        self.dense2 = self._dense_block(1024)
        self.dense3 = self._dense_block(4096)
        self.reshape = layers.Reshape((4, 4, 256))
        self.deconv1 = self._deconv_block(256, strides=2, padding="same")
        self.deconv2 = self._deconv_block(128, strides=1, padding="same")
        self.deconv3 = self._deconv_block(128, strides=2, padding="valid")
        self.deconv4 = self._deconv_block(64, strides=1, padding="valid")
        self.deconv5 = self._deconv_block(64, strides=2, padding="valid")
        self.deconv6 = layers.Conv2DTranspose(filters=1, kernel_size=2, activation="sigmoid", padding="valid")

    @staticmethod
    def _dense_block(units):
        """Return a ReLU dense layer followed by batch norm."""
        return keras.Sequential([
            layers.Dense(units=units, activation="relu"),
            layers.BatchNormalization(),
        ])

    @staticmethod
    def _deconv_block(filters, strides, padding):
        """Return a 3x3 ReLU transposed convolution followed by batch norm."""
        return keras.Sequential([
            layers.Conv2DTranspose(filters=filters, kernel_size=3, activation="relu", strides=strides, padding=padding),
            layers.BatchNormalization(),
        ])

    def call(self, inputs, training=None, mask=None):
        """Apply every stage in order and return the final layer's output."""
        pipeline = (
            self.dense1,
            self.dense2,
            self.dense3,
            self.reshape,
            self.deconv1,
            self.deconv2,
            self.deconv3,
            self.deconv4,
            self.deconv5,
            self.deconv6,
        )
        decoder_outputs = inputs
        for stage in pipeline:
            decoder_outputs = stage(decoder_outputs)
        return decoder_outputs

And this is the full traceback of the error:

Traceback (most recent call last):
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/del.py", line 195, in <module>
    grid.fit(x_train, y_train)
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 874, in fit
    self._run_search(evaluate_candidates)
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 1388, in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_search.py", line 851, in evaluate_candidates
    _warn_or_raise_about_fit_failures(out, self.error_score)
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_validation.py", line 367, in _warn_or_raise_about_fit_failures
    raise ValueError(all_fits_failed_message)
ValueError: 
All the 6 fits failed.
It is very likely that your model is misconfigured.
You can try to debug the error by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
6 fits failed with the following error:
Traceback (most recent call last):
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/sklearn/model_selection/_validation.py", line 686, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/alex/PycharmProjects/VAE-EEG-XAI/venv/lib/python3.11/site-packages/keras/src/engine/training.py", line 3875, in _assert_compile_was_called
    raise RuntimeError(
RuntimeError: You must compile your model before training/testing. Use `model.compile(optimizer, loss)`. 


from GridSearchCV: You must compile your model before training/testing. But my model is already compiled

No comments:

Post a Comment