I have a Keras model of the following form:

```python
import joblib
from tensorflow import keras


class MyAbstractModel(keras.Model):
    ...

    def predict(self, **kwargs):
        # Apply the model-specific preprocessing, then delegate to Keras
        kwargs["x"] = self.preprocess(kwargs["x"])
        return super().predict(**kwargs)


class MyModel(MyAbstractModel):
    def fit(self, **kwargs):
        kwargs["x"] = self.preprocess(kwargs["x"])
        kwargs["y"] = self.preprocess2(kwargs["y"])
        return super().fit(**kwargs)

    def save(self, path):
        # I have to save the weights because I need to retain the custom logic in this class
        self.save_weights(path / "weights", save_format="tf")
        # self.vectorizer is of type keras.layers.TextVectorization
        with open(path / "vectorizer", "wb") as file:
            joblib.dump({
                "config": self.vectorizer.get_config(),
                "weights": self.vectorizer.get_weights(),
            }, file)

    @classmethod
    def load(cls, path):
        instance = cls()
        instance.load_weights(path / "weights")
        with open(path / "vectorizer", "rb") as file:
            loaded_vec = joblib.load(file)
        # from_config is a classmethod that returns a new layer, so assign its result
        instance.vectorizer = keras.layers.TextVectorization.from_config(loaded_vec["config"])
        instance.vectorizer.set_weights(loaded_vec["weights"])
        return instance
```
Note that my model predicts multiple classes with the following head: `keras.layers.Dense(N_CLASSES, activation="softmax", name="classifier")`. For training, I convert my classes into the required one-hot format with `keras.utils.to_categorical([label_mapping[y] for y in response])`, where `label_mapping` is a dictionary like this: `{"A": 0, "B": 1, "C": 2}`.
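For reference, here is a minimal NumPy sketch of the encoding step described above, plus the inverse mapping needed to turn softmax outputs back into labels. It assumes hypothetical sample labels in `response` and uses `np.eye(n)[indices]` as a stand-in for `keras.utils.to_categorical(indices, num_classes=n)`, which produces the same float32 one-hot matrix. It also builds `label_mapping` from `sorted(...)`: if the mapping is ever derived from a `set` rather than hard-coded, its order can differ between Python processes, because string hashing is randomized by default.

```python
import numpy as np

# Hypothetical training labels; in the question these come from `response`.
response = ["B", "A", "C", "A", "B"]

# Sorting the unique labels gives every class the same index in every
# Python process, unlike iterating a plain set of strings.
classes = sorted(set(response))
label_mapping = {label: idx for idx, label in enumerate(classes)}
inverse_mapping = {idx: label for label, idx in label_mapping.items()}

# Integer-encode, then one-hot encode (same result as to_categorical
# called with num_classes=len(classes)).
indices = np.array([label_mapping[y] for y in response])
one_hot = np.eye(len(classes), dtype="float32")[indices]


def decode(probabilities):
    """Map each row of softmax outputs back to its class label."""
    return [inverse_mapping[int(i)] for i in np.argmax(probabilities, axis=1)]
```

Because the same `inverse_mapping` is used for decoding, `decode(one_hot)` recovers the original `response` list.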
Immediately after training the model, I can successfully use it to generate predictions. However, after saving and reloading it, the order of my classes is all over the place: the raw `predict` method gives me an array with N columns (where N is the number of classes), but the order in which the class probabilities appear changes every time I reload the model. I've found this issue, which seems to be relevant, and I've tried adding `manual_variable_initialization(True)` to my training script, but it didn't help.
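One way to narrow this down is to predict on the same fixed batch before saving and after reloading, then check whether the second set of outputs is merely a column permutation of the first. The helper below is a hypothetical NumPy sketch (`find_column_permutation`, `before` and `after` are illustrative names, not Keras APIs); it matches columns with `np.allclose`, so the tolerance may need adjusting.

```python
import numpy as np


def find_column_permutation(before, after, atol=1e-5):
    """Return a list `perm` such that after[:, j] matches before[:, perm[j]],
    or None if some column of `after` has no match in `before`.

    `before` and `after` are (n_samples, n_classes) prediction arrays
    computed on the same inputs.
    """
    before = np.asarray(before)
    after = np.asarray(after)
    perm = []
    for j in range(after.shape[1]):
        # Indices of all columns in `before` that match column j of `after`
        matches = [
            k for k in range(before.shape[1])
            if np.allclose(after[:, j], before[:, k], atol=atol)
        ]
        if not matches:
            return None
        perm.append(matches[0])  # take the first match if several are equal
    return perm
```

If it returns a permutation, the reloaded model computes the same probabilities in a different column order; if it returns `None`, the values themselves differ, which would suggest that some state (the weights or the vectorizer) is not being restored as expected.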
I've also noticed that I'm getting the warning `WARNING:tensorflow: Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values.`, followed by a series of messages such as `Value in checkpoint could not be found in the restored object: (root).optimizer's state 'm' for (root).classifier.kernel`, but I don't think that is related to the problem. I've traced that warning to this issue.
Many thanks in advance!
EDIT: I've also tried saving the whole model object with joblib and with `keras.models.save_model`, with the custom preprocessing function moved outside the class, but the problem remains!
from Loading model weights in tensorflow changes order of predicted classes