Conversation

@Giyu917 Giyu917 commented Dec 17, 2025

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

This PR fixes a precision loss issue (on the order of $10^{-3}$) in the weight conversion process for the Qwen3-Next model between Hugging Face (HF) and Megatron-Core (MCore) formats. During a round-trip conversion (HF -> MCore -> HF), we observed a considerable discrepancy in the LayerNorm weights. This precision drift (up to $10^{-3}$) can be enough to cause numerical instability during inference, leading to corrupted outputs (e.g., the model generating repetitive "!" in some corner cases).

Root Cause

The issue stems from a subtle implementation difference between the standard RMSNorm in Megatron-Core and the native implementation used in Qwen3-Next: the HF Qwen3-Next norm is zero-centered and scales the normalized activations by (1 + weight), while MCore's default RMSNorm scales by the weight directly, so the converter has to add or subtract 1 from the layernorm weights in each direction.

Result

When MCore's default RMSNorm is used to load the Qwen3-Next weights, the round-trip transformation is not strictly bijective, leading to the observed loss.
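To see why the offset-based conversion is lossy, here is a minimal, illustrative round trip, assuming the layernorm weights are stored in bf16 and that the converter adds/subtracts 1 to translate between the two conventions (a sketch, not the actual converter code):

```python
import torch

torch.manual_seed(0)
# HF zero-centered layernorm weights are small values near 0.
w_hf = torch.randn(4096, dtype=torch.bfloat16) * 0.02
# HF -> MCore: add the +1 offset so MCore's plain-weight RMSNorm behaves the same.
w_mcore = (w_hf.float() + 1.0).to(torch.bfloat16)
# MCore -> HF: subtract the offset again.
w_back = (w_mcore.float() - 1.0).to(torch.bfloat16)
print((w_hf - w_back).abs().max())  # nonzero, typically on the order of 1e-3
```

Because the bf16 spacing around 1.0 is roughly 8e-3, adding and then subtracting the offset rounds the small zero-centered weights, which is consistent with the drift described above.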

Solution

I have replaced the generic Megatron RMSNorm with the Qwen3-Next native RMSNorm implementation within the MCore model definition for Qwen3-Next. This ensures that the weight mapping and the mathematical operations during inference are perfectly aligned with the original Hugging Face implementation.
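For reference, here is a minimal sketch of the two scaling conventions described above (illustrative class names, computed in fp32 for simplicity; not the actual MCore or HF code):

```python
import torch
from torch import nn

class PlainRMSNorm(nn.Module):
    """Convention of MCore's default RMSNorm: y = weight * x_norm, weight initialized to 1."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        var = x.float().pow(2).mean(-1, keepdim=True)
        return (self.weight * (x.float() * torch.rsqrt(var + self.eps))).to(x.dtype)

class ZeroCenteredRMSNorm(nn.Module):
    """Convention of the Qwen3-Next norm: y = (1 + weight) * x_norm, weight initialized to 0."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x):
        var = x.float().pow(2).mean(-1, keepdim=True)
        return ((1 + self.weight) * (x.float() * torch.rsqrt(var + self.eps))).to(x.dtype)
```

With the zero-centered convention on the MCore side as well, the layernorm weights can be copied verbatim in both directions, so no +1/-1 offset (and no rounding of that offset) is involved.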

@gemini-code-assist
Contributor

Summary of Changes

Hello @Giyu917, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in the Qwen3-Next model's conversion process between Hugging Face and Megatron-Core, where a subtle difference in RMSNorm implementations led to significant precision loss and potential model instability. By integrating a custom RMSNorm layer that precisely mirrors the Hugging Face specification, the PR ensures accurate weight mapping and robust numerical stability, preventing corrupted outputs during inference.

Highlights

  • Precision Loss Fix: Resolves a 1e-3 precision loss issue during Qwen3-Next model conversion between Hugging Face (HF) and Megatron-Core (MCore), which previously caused numerical instability and corrupted outputs.
  • Custom RMSNorm Implementation: Introduces Qwen3NextRMSNorm, a custom zero-centered RMSNorm layer designed to exactly match the Hugging Face implementation, thereby eliminating discrepancies.
  • Weight Conversion Alignment: The new Qwen3NextRMSNorm uses a (1 + weight) scaling approach, removing the need for +1/-1 offsets during layernorm weight conversion.
  • Module Integration: Integrates Qwen3NextRMSNorm across all relevant layernorm instances within the get_qwen3_next_transformer_layer_spec function, ensuring consistent application (a rough sketch of this pattern follows this list).
  • Code Refinements: Updates imports, re-adds a utility function get_local_layer_specs, and removes an outdated assertion.
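
As a rough sketch of the integration pattern mentioned in the highlights (the stand-in spec object and helper below are hypothetical; only the input_layernorm attribute name is taken from the diff quoted in the review, and the real changes live in get_qwen3_next_transformer_layer_spec):

```python
from types import SimpleNamespace

# Hypothetical stand-in for a Megatron transformer layer spec; the real one comes from
# get_qwen3_next_transformer_layer_spec and has many more submodules.
layer_spec = SimpleNamespace(submodules=SimpleNamespace(input_layernorm=None))

def apply_norm_impl(spec, layer_norm_impl):
    # Point the layernorm slot at the HF-aligned norm class instead of the default TENorm;
    # the actual change does this for every relevant layernorm instance in the spec.
    # e.g. apply_norm_impl(layer_spec, Qwen3NextRMSNorm)
    spec.submodules.input_layernorm = layer_norm_impl
    return spec
```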

Contributor

@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request effectively addresses a precision loss issue during weight conversion for the Qwen3-Next model by introducing a custom Qwen3NextRMSNorm implementation that aligns with the Hugging Face version. The changes are well-contained and logical, correctly removing the need for weight offsets during conversion. I've made a couple of minor suggestions to improve code clarity and robustness. Overall, this is a solid fix.

def __init__(self, config: TransformerConfig, submodules: SelfAttentionSubmodules, layer_number: int, **kwargs):
assert config.context_parallel_size == 1, 'Qwen3Next currently does not support context parallel.'
assert _Qwen3NextGatedDeltaNet is not object, 'please update the `transformers` version.'
_Qwen3NextGatedDeltaNet.__init__(self, config, layer_number)
medium

The assertion assert _Qwen3NextGatedDeltaNet is not object, 'please update the transformers version.' was removed. This check provided a clear error message if Qwen3NextGatedDeltaNet failed to import, guiding the user to update their transformers installation. Without it, a missing dependency will cause a less intuitive TypeError inside _Qwen3NextGatedDeltaNet.__init__. It would be beneficial to restore this assertion for better error handling and developer experience. You could add it back before this line.

if layer_type == 'linear_attention':
layer_spec.submodules.input_layernorm = TENorm
layer_spec.submodules.self_attention.submodules.linear_qkv = TEColumnParallelLinear
layer_spec.submodules.input_layernorm = layer_norm_impl
medium

This assignment to layer_spec.submodules.input_layernorm is redundant because the same assignment is performed unconditionally on line 527 for all layer types. This line can be safely removed to improve code clarity.

@Jintao-Huang
Collaborator

thanks a lot

@Giyu917
Author

Giyu917 commented Dec 17, 2025

thanks a lot

It is an awesome framework, and I enjoy using it. Thank you for your continuous efforts.

@Jintao-Huang
Collaborator

Hi, I've tested it and found another bug. The linear_qkv in full_attention still contains a layer_norm.

I will submit a new PR based on this one.

@Jintao-Huang
Collaborator

#7097

@Giyu917
Author

Giyu917 commented Dec 18, 2025

#7097

Thank you for your time and effort.
