Skip to content

[Bug] Qwen3Moe(ForCausalLM) does not work with LoRA at int4 because some data is float32 #3001

@Killusions

Description

@Killusions

Qwen3Moe(ForCausalLM) does not work with LoRA at int4 because some data is float32.

Newest versions of everything, python 3.11, cuda 12.6.

For int8 I was able to specify float32 as the dtype, but for int4 that is not possible.

Error log:

Traceback (most recent call last):
…
    outputs = model(**batch)
              ^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/peft/peft_model.py", line 1845, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 216, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py", line 758, in forward
    return Qwen3MoeForCausalLM_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, output_router_logits, cache_position, logits_to_keep, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py", line 540, in Qwen3MoeForCausalLM_forward
    outputs: MoeModelOutputWithPast = self.model(
                                      ^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/transformers/models/qwen3_moe/modeling_qwen3_moe.py", line [547](https://code.siemens.com/linus/baa/-/jobs/282984733#L547), in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/transformers/modeling_layers.py", line 83, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/transformers/models/qwen3_moe/modeling_qwen3_moe.py", line 350, in forward
    hidden_states, self_attn_weights = self.self_attn(
                                       ^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py", line 319, in forward
    return Qwen3MoeAttention_forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py", line 249, in Qwen3MoeAttention_forward
    query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "unsloth_compiled_cache/Linear4bit_peft_forward.py", line 78, in unsloth_forward
    return lora_forward(result, lora_A, lora_B, dropout, x, scaling)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "unsloth_compiled_cache/Linear4bit_peft_forward.py", line 25, in lora_forward
    output = torch_addmm(
             ^^^^^^^^^^^^
RuntimeError: self and mat2 must have the same dtype, but got BFloat16 and Float

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions