Cannot save and load custom model after updating to keras 3 #4753

Closed
3zhang opened this issue Aug 3, 2024 · 2 comments

3zhang commented Aug 3, 2024

Describe the current behavior
This is an example from the official Keras documentation:

```python
import keras
import numpy as np


@keras.saving.register_keras_serializable(package="MyLayers")
class CustomLayer(keras.layers.Layer):
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def call(self, x):
        return x * self.factor

    def get_config(self):
        return {"factor": self.factor}


@keras.saving.register_keras_serializable(package="my_package", name="custom_fn")
def custom_fn(x):
    return x**2

def get_model():
    inputs = keras.Input(shape=(4,))
    mid = CustomLayer(0.5)(inputs)
    outputs = keras.layers.Dense(1, activation=custom_fn)(mid)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="rmsprop", loss="mean_squared_error")
    return model

def train_model(model):
    input = np.random.random((4, 4))
    target = np.random.random((4, 1))
    model.fit(input, target)
    return model

test_input = np.random.random((4, 4))
test_target = np.random.random((4, 1))

model = get_model()
model = train_model(model)
model.save("custom_model.keras")

# Note: at this point I restart the Colab session and then run:
import keras

reconstructed_model = keras.models.load_model("custom_model.keras")
```

Error:


```
TypeError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/keras/src/saving/serialization_lib.py in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
717 try:
--> 718 instance = cls.from_config(inner_config)
719 except TypeError as e:

10 frames
TypeError: Could not locate class 'CustomLayer'. Make sure custom classes are decorated with @keras.saving.register_keras_serializable(). Full object config: {'module': None, 'class_name': 'CustomLayer', 'config': {'factor': 0.5}, 'registered_name': 'MyLayers>CustomLayer', 'build_config': {'input_shape': [None, 4]}, 'name': 'custom_layer', 'inbound_nodes': [{'args': [{'class_name': 'keras_tensor', 'config': {'shape': [None, 4], 'dtype': 'float32', 'keras_history': ['input_layer', 0, 0]}}], 'kwargs': {}}]}

During handling of the above exception, another exception occurred:

TypeError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/keras/src/saving/serialization_lib.py in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
718 instance = cls.from_config(inner_config)
719 except TypeError as e:
--> 720 raise TypeError(
721 f"{cls} could not be deserialized properly. Please"
722 " ensure that components that are Python object"

TypeError: <class 'keras.src.models.functional.Functional'> could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by get_config() are explicitly deserialized in the model's from_config() method.

config={'module': 'keras.src.models.functional', 'class_name': 'Functional', 'config': {'name': 'functional', 'trainable': True, 'layers': [{'module': 'keras.layers', 'class_name': 'InputLayer', 'config': {'batch_shape': [None, 4], 'dtype': 'float32', 'sparse': False, 'name': 'input_layer'}, 'registered_name': None, 'name': 'input_layer', 'inbound_nodes': []}, {'module': None, 'class_name': 'CustomLayer', 'config': {'factor': 0.5}, 'registered_name': 'MyLayers>CustomLayer', 'build_config': {'input_shape': [None, 4]}, 'name': 'custom_layer', 'inbound_nodes': [{'args': [{'class_name': 'keras_tensor', 'config': {'shape': [None, 4], 'dtype': 'float32', 'keras_history': ['input_layer', 0, 0]}}], 'kwargs': {}}]}, {'module': 'keras.layers', 'class_name': 'Dense', 'config': {'name': 'dense', 'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}, 'units': 1, 'activation': {'module': 'builtins', 'class_name': 'function', 'config': 'my_package>custom_fn', 'registered_name': 'function'}, 'use_bias': True, 'kernel_initializer': {'module': 'keras.initializers', 'class_name': 'GlorotUniform', 'config': {'seed': None}, 'registered_name': None}, 'bias_initializer': {'module': 'keras.initializers', 'class_name': 'Zeros', 'config': {}, 'registered_name': None}, 'kernel_regularizer': None, 'bias_regularizer': None, 'kernel_constraint': None, 'bias_constraint': None}, 'registered_name': None, 'build_config': {'i...

Exception encountered: Could not locate class 'CustomLayer'. Make sure custom classes are decorated with @keras.saving.register_keras_serializable(). Full object config: {'module': None, 'class_name': 'CustomLayer', 'config': {'factor': 0.5}, 'registered_name': 'MyLayers>CustomLayer', 'build_config': {'input_shape': [None, 4]}, 'name': 'custom_layer', 'inbound_nodes': [{'args': [{'class_name': 'keras_tensor', 'config': {'shape': [None, 4], 'dtype': 'float32', 'keras_history': ['input_layer', 0, 0]}}], 'kwargs': {}}]}
```

Describe the expected behavior
The model should load successfully.

What web browser are you using?
Chrome

3zhang added the bug label on Aug 3, 2024
@cperry-goog

You'll need to upgrade your code to support Keras 3 in the medium term (see #4744). In the short term, you can use the fallback version as described there.


3zhang commented Aug 12, 2024

> You'll need to upgrade your code to support Keras 3 in the medium term (see #4744). In the short term, you can use the fallback version as described there.

Well, this code example is from the official Keras 3 documentation.
