
Keras Fails to load quantized model #21378

Open

Description

@pctablet505

Whenever we save a keras-hub model after quantization, we are unable to load the quantized model back. I've tried the from_preset() method for that model as well as keras.models.load_model(); neither works.

I've attached a notebook:
https://colab.research.google.com/gist/pctablet505/b5ef8ab36dceb58527e992b571aefb70/keras-quantized-model-not-loading.ipynb
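For reference, here is a minimal sketch of the save/load path that fails, assuming quantization via model.quantize("int8"). The initial preset name and the .keras file name are placeholders; "stored_gemma_int8" is the local preset directory that appears in the traceback below.

```python
import keras
import keras_hub

# Load a Gemma 3 preset (placeholder preset name; the notebook may use a
# different one), quantize it in place, and save it as a local preset.
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
gemma_lm.quantize("int8")
gemma_lm.save_to_preset("stored_gemma_int8")

# Path 1: reloading the saved preset fails with the ValueError below.
gemma_lm_quantized = keras_hub.models.Gemma3CausalLM.from_preset(
    "stored_gemma_int8"
)
gemma_lm_quantized.generate("hello, what is your name?")

# Path 2: saving and reloading a .keras archive fails the same way
# (placeholder file name).
gemma_lm.save("gemma3_int8.keras")
restored = keras.models.load_model("gemma3_int8.keras")
```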

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-5-3241078411>](https://localhost:8080/#) in <cell line: 0>()
      8 import keras_hub
      9 
---> 10 gemma_lm_quantized = keras_hub.models.Gemma3CausalLM.from_preset("stored_gemma_int8")
     11 gemma_lm_quantized.generate("hello, what is your name?")

4 frames
[/content/keras_hub_repo/keras_hub/src/models/task.py](https://localhost:8080/#) in from_preset(cls, preset, load_weights, **kwargs)
    196         # images, audio).
    197         load_task_weights = "num_classes" not in kwargs
--> 198         return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
    199 
    200     def load_task_weights(self, filepath):

[/content/keras_hub_repo/keras_hub/src/utils/preset_utils.py](https://localhost:8080/#) in load_task(self, cls, load_weights, load_task_weights, **kwargs)
    701             else:
    702                 jax_memory_cleanup(task.backbone)
--> 703             self._load_backbone_weights(task.backbone)
    704         return task
    705 

[/content/keras_hub_repo/keras_hub/src/utils/preset_utils.py](https://localhost:8080/#) in _load_backbone_weights(self, backbone)
    754                 # Download the sharded weights.
    755                 _ = get_file(self.preset, sharded_filename)
--> 756         backbone.load_weights(filepath)
    757 
    758 

[/content/keras_repo/keras/src/utils/traceback_utils.py](https://localhost:8080/#) in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

[/content/keras_repo/keras/src/saving/saving_lib.py](https://localhost:8080/#) in _raise_loading_failure(error_msgs, warn_only)
    648         warnings.warn(msg)
    649     else:
--> 650         raise ValueError(msg)
    651 
    652 

ValueError: A total of 183 objects could not be loaded. Example error message for object <ReversibleEmbedding name=token_embedding, built=True>:

Layer 'token_embedding' expected 1 variables, but received 0 variables during loading. Expected: ['embeddings']

List of objects that could not be loaded:
[<ReversibleEmbedding name=token_embedding, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gat...

I get a similar error for other models such as Llama. I've tested with gemma2, gemma3, llama3, llama3.1, llama3.2, and more.
