
Conversation


@LuJunru LuJunru commented Jan 8, 2026

What does this PR do?

This PR adds the implementation for the released Youtu-LLM model (a brief usage sketch follows the feature list below). The model has the following features:

  • Type: Autoregressive Causal Language Models with Dense MLA
  • Release versions: Base and Instruct
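
A minimal usage sketch, assuming the tencent/Youtu-LLM-2B-Base checkpoint referenced later in this thread; the prompt and generation settings are illustrative only:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tencent/Youtu-LLM-2B-Base")
model = AutoModelForCausalLM.from_pretrained("tencent/Youtu-LLM-2B-Base", dtype=torch.float16, device_map="auto")

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=16, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))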

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Cyrilvallez

@LuJunru LuJunru mentioned this pull request Jan 8, 2026

LuJunru commented Jan 8, 2026

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43166&sha=5dab39

Hi @ArthurZucker @Cyrilvallez

May I ask if it is possible to run the tests only on Youtu-LLM (the new model)? The summary here seems to report errors raised by other models.

LuJunru and others added 7 commits January 8, 2026 19:14
…ition_embedding in DiT (huggingface#43068)

* qwen2_5_omni: make max_mel_frames an inference-time knob

* not fail with raising ValueError, instead make it continue to run by choosing a target_duration that's capped and aligned

* added unit tests for Token2WavShape shape mismatch

Signed-off-by: Dong Wang <dongw2019@gmail.com>

* make fixup

* remove unit test which takes too much GPU memory

Signed-off-by: Dong Wang <dongw2019@gmail.com>

* reduce gpu memory usage from the unit test

* addressed comments

Signed-off-by: Dong Wang <dongw2019@gmail.com>

---------

Signed-off-by: Dong Wang <dongw2019@gmail.com>

LuJunru commented Jan 9, 2026

Hi @ArthurZucker @Cyrilvallez

It seems the Youtu-LLM-related code has passed the auto review. The remaining check failures are raised by other models.

@molbap molbap self-assigned this Jan 12, 2026
@molbap molbap self-requested a review January 12, 2026 14:13

molbap commented Jan 12, 2026

run-slow: youtu_llm

@github-actions

This comment contains run-slow, running the specified jobs:

models: ["models/youtu_llm"]
quantizations: []

@github-actions

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@molbap molbap left a comment

Seems clean, good modular file with simply Llama + MLA, beautiful. Asked a few questions, let me know and I'll re-review!

Contributor

is the official name YoutuLLM or Youtu as in the prefixes here?

Contributor Author

We chose to use Youtu as the prefix for modules, as it is more suitable for extension (e.g., we plan to introduce YoutuVL in the near future). Youtu-LLM is more of a brand name.

Contributor

Then everything that has to do with the model name (youtu) should be named as such, like the model directory.


model_sdpa = YoutuForCausalLM.from_pretrained(
    "tencent/Youtu-LLM-2B-Base",
    dtype=torch.float16,
Contributor

let's make sdpa explicit here
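
Presumably something along these lines (a sketch assuming the standard attn_implementation argument, not the PR's exact final diff):

import torch
from transformers import YoutuForCausalLM

model_sdpa = YoutuForCausalLM.from_pretrained(
    "tencent/Youtu-LLM-2B-Base",
    dtype=torch.float16,
    attn_implementation="sdpa",
)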



class YoutuModel(LlamaModel):
    _keys_to_ignore_on_load_unexpected = [""]
Contributor

is this to remove the Llama attribute? if so, ok

Contributor Author

For the current version of the model (Youtu-LLM-2B family), this line of code could be removed.

@require_torch_accelerator
@pytest.mark.torch_compile_test
@require_read_token
def test_compile_static_cache(self):
Contributor

Thanks for adding an integration test! However, naming-wise, it seems to measure dynamic and static cache, no? By the way, could we have a simple no-compile integration test that works in the simplest setting, just to avoid regressions?

Contributor Author

We have provided inference tests below based on a no-compile dynamic cache and a no-compile static cache. Basically, I implemented this test function by referencing the test function of DeepSeek V3.

Contributor

Sure, can we update the name though to make it clearer, and separate it into two tests? That way, if it breaks at some point, it's easier to debug.

Contributor Author

Sure, are there any official examples that I can follow?
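
For reference, a minimal sketch (not taken from the PR) of how this could be split into two separately named integration tests, one per cache type; the class name, prompt, and generation settings are placeholders:

import torch
from transformers import AutoTokenizer, YoutuForCausalLM


class YoutuIntegrationSketch:
    checkpoint = "tencent/Youtu-LLM-2B-Base"

    def _generate(self, **generate_kwargs):
        tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
        model = YoutuForCausalLM.from_pretrained(self.checkpoint, dtype=torch.float16, device_map="auto")
        inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
        return model.generate(**inputs, max_new_tokens=8, do_sample=False, **generate_kwargs)

    def test_model_2b_generation_dynamic_cache(self):
        # the default generate() path uses the dynamic cache
        self._generate()

    def test_model_2b_generation_static_cache(self):
        # explicitly request the static cache implementation, still without torch.compile
        self._generate(cache_implementation="static")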

Comment on lines 238 to 265
@parameterized.expand([("random",), ("same",)])
@unittest.skip("Youtu-LLM is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
    pass

@unittest.skip("Youtu-LLM is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
    pass

@unittest.skip("Youtu-LLM is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
    pass

@unittest.skip("Youtu-LLM uses MLA so it is not compatible with the standard cache format")
def test_beam_search_generate_dict_outputs_use_cache(self):
    pass

@unittest.skip("Youtu-LLM uses MLA so it is not compatible with the standard cache format")
def test_greedy_generate_dict_outputs_use_cache(self):
    pass

@unittest.skip(reason="SDPA can't dispatch on flash due to unsupported head dims")
def test_sdpa_can_dispatch_on_flash(self):
    pass

@unittest.skip(reason="Youtu-LLM is not suitable for testing with extreme small vocabulary")
def test_resize_tokens_embeddings(self):
    pass
Contributor

are all these tests indeed not working?

Contributor Author

Yes, exactly.

Contributor

Let's check if we can fix the majority by moving the tests under the CausalLM wrapper classes


LuJunru commented Jan 13, 2026

Hi @molbap

I've updated the code according to the discussion above. Can you help start a new solo test of Youtu-LLM (run-slow: youtu_llm)?

@vasqu vasqu left a comment

Pushing some quick fixes, so some comments are outdated. It will still need another look at the tests, especially with regard to the causal LM tester (I did something quickly, but shapes are not working as expected for some tests, so we either need to fix the init properly or overwrite some tests that check shapes).

I hope that clears up some of the confusion.

Youtu_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class YoutuConfig(PreTrainedConfig):
Contributor

You could remove unrelated attributes after the super with self.attr
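
Presumably this refers to the usual pattern of dropping inherited attributes right after the parent __init__; a minimal sketch, where the parent class and the deleted attribute names are chosen purely for illustration:

from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config


class YoutuConfigSketch(DeepseekV3Config):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # drop parent attributes that are unrelated to this model (illustrative names)
        del self.n_routed_experts
        del self.moe_intermediate_size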

Comment on lines +122 to +123
initializer_range: float | None = None,
embedding_initializer_range: float | None = None,
Contributor

We could at least avoid the else branch, but yes, fair, I overlooked that the embedding was also dependent there.
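
For what it's worth, a minimal sketch of what avoiding the else branch could look like for these two arguments; the 0.02 fallback and the dependency between the two ranges are assumptions for illustration, not the PR's actual defaults:

def resolve_init_ranges(
    initializer_range: float | None = None,
    embedding_initializer_range: float | None = None,
) -> tuple[float, float]:
    # one conditional expression per value instead of an if/else ladder
    init_range = initializer_range if initializer_range is not None else 0.02
    # the embedding range falls back to the general range when not set explicitly
    emb_range = embedding_initializer_range if embedding_initializer_range is not None else init_range
    return init_range, emb_range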

Comment on lines +148 to +163
# for backward compatibility
if num_key_value_heads is None:
    num_key_value_heads = num_attention_heads
Contributor

Shouldn't be needed; we should have this directly as an arg if we already have the power to change it.

Contributor Author

Now YoutuConfig inherits from DeepseekV3Config, so most of the code has been removed.

Comment on lines 39 to 49
class YoutuDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: YoutuConfig, layer_idx: int):
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        self.self_attn = YoutuAttention(config=config, layer_idx=layer_idx)

        self.mlp = YoutuMLP(config)

        self.input_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Contributor

Suggested change:

-class YoutuDecoderLayer(LlamaDecoderLayer):
-    def __init__(self, config: YoutuConfig, layer_idx: int):
-        nn.Module.__init__(self)
-        self.hidden_size = config.hidden_size
-        self.self_attn = YoutuAttention(config=config, layer_idx=layer_idx)
-        self.mlp = YoutuMLP(config)
-        self.input_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+class YoutuDecoderLayer(LlamaDecoderLayer):
+    pass

Should now be the same as Llama, no? The init should overlap now too

Contributor Author

sure

Comment on lines 61 to 64
if isinstance(module, nn.Linear):
    init.normal_(module.weight, mean=0.0, std=std)
    if module.bias is not None:
        init.zeros_(module.bias)
Contributor

But you use initializer_range to init linear layers, so we handle this already, see:

# get initializer range
if hasattr(self.config, "initializer_range"):
    std = self.config.initializer_range or 0.02

# update linear layer
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
    if getattr(module, "weight", None) is not None:
        init.normal_(module.weight, mean=0.0, std=std)
    if module.bias is not None:
        init.zeros_(module.bias)

Hence, you can leave that part out and only update the embedding part.
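
A minimal sketch of an embedding-only override along those lines, assuming the parent's generic _init_weights already covers the linear layers as quoted above; the class name and the 0.02 fallback are illustrative, not taken from the PR diff:

import torch.nn as nn
from torch.nn import init
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel


class YoutuPreTrainedModelSketch(LlamaPreTrainedModel):
    def _init_weights(self, module):
        # linear/conv layers are already initialized by the parent using config.initializer_range
        super()._init_weights(module)
        # only customize the embedding init, using embedding_initializer_range
        if isinstance(module, nn.Embedding):
            std = getattr(self.config, "embedding_initializer_range", None) or 0.02
            init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                init.zeros_(module.weight[module.padding_idx])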

"parakeet": "ParakeetCTCConfig",
"lasr": "LasrCTCConfig",
"wav2vec2-with-lm": "Wav2Vec2Config",
"youtu": "YoutuConfig",
Contributor

Why is this needed? I would like to leave this out if possible.

Contributor Author

removed

"YolosImageProcessor",
"YolosModel",
"YosoConfig",
"YoutuConfig",
Contributor

Please run make fix-repo (without having this added to any exceptions); you will see it updates a few strings like I showed above, namely initializer_range and embedding_initializer_range. You might need to left-tab them to have the same spacing as before.

Comment on lines +345 to +347
if model.config.tie_word_embeddings:
    # Youtu-LLM-2B-Base contains extra repeated weights for the tied embeddings, we can tie weights here according to its config
    model.tie_weights()
Contributor

Yes gotcha


LuJunru commented Jan 27, 2026

Hi @vasqu

I have fixed most of the issues mentioned above, please check again.

There is one specific issue related to check_docstring.py. After using make fix-repo, the original docstring will be changed from:

    Args:
            vocab_size (`int`, *optional*, defaults to 128256):
                Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
                `inputs_ids` passed when calling [`YoutuModel`]
            hidden_size (`int`, *optional*, defaults to 2048):
                Dimension of the hidden representations.
            intermediate_size (`int`, *optional*, defaults to 6144):
                Dimension of the MLP representations.
            num_hidden_layers (`int`, *optional*, defaults to 32):
                Number of hidden layers in the Transformer decoder.
            num_attention_heads (`int`, *optional*, defaults to 16):
                Number of attention heads for each attention layer in the Transformer decoder.
            num_key_value_heads (`int`, *optional*, defaults to 16):
                In MLA, num_key_value_heads=num_attention_heads.
            kv_lora_rank (`int`, *optional*, defaults to 512):
                Rank of the LoRA matrices for key and value projections.
            q_lora_rank (`int`, *optional*, defaults to 1536):
                Rank of the LoRA matrices for query projections.
            qk_rope_head_dim (`int`, *optional*, defaults to 64):
                Dimension of the query/key heads that use rotary position embeddings.
            v_head_dim (`int`, *optional*, defaults to 128):
                Dimension of the value heads.
            qk_nope_head_dim (`int`, *optional*, defaults to 128):
                Dimension of the query/key heads that don't use rotary position embeddings.
            hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
                The non-linear activation function (function or string) in the decoder.
            max_position_embeddings (`int`, *optional*, defaults to 131072):
                The maximum sequence length that this model might ever be used with.
            initializer_range (`float`, *optional*):
                The standard deviation of the truncated_normal_initializer for initializing all weight matrices, except embedding matrices.
            embedding_initializer_range (`float`, *optional*):
                The standard deviation of the truncated_normal_initializer for initializing all embedding matrices.
            rms_norm_eps (`float`, *optional*, defaults to 1e-06):
                The epsilon used by the rms normalization layers.
            use_cache (`bool`, *optional*, defaults to `True`):
                Whether or not the model should return the last key/values attentions (not used by all models). Only
                relevant if `config.is_decoder=True`.
            pad_token_id (`int`, *optional*):
                Padding token id.
            bos_token_id (`int`, *optional*, defaults to 128000):
                Beginning of stream token id.
            eos_token_id (`int`, *optional*, defaults to 128001):
                End of stream token id.
            tie_word_embeddings (`bool`, *optional*, defaults to `True`):
                Whether to tie weight embeddings
            rope_parameters (`RopeParameters`, *optional*):
                Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
                a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
                with longer `max_position_embeddings`.
            rope_interleave (`bool`, *optional*, defaults to `True`):
                Whether to interleave the rotary position embeddings.
            attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
                Whether to use a bias in the query, key, value and output projection layers during self-attention.
            attention_dropout (`float`, *optional*, defaults to 0.0):
                The dropout ratio for the attention probabilities.

to

    Args:
            vocab_size (`int | None`, *optional*, defaults to 128256): <fill_docstring>
            hidden_size (`int | None`, *optional*, defaults to 2048): <fill_docstring>
            intermediate_size (`int | None`, *optional*, defaults to 6144): <fill_docstring>
            num_hidden_layers (`int | None`, *optional*, defaults to 32): <fill_docstring>
            num_attention_heads (`int | None`, *optional*, defaults to 16): <fill_docstring>
            num_key_value_heads (`int | None`, *optional*, defaults to 16): <fill_docstring>
            kv_lora_rank (`int | None`, *optional*, defaults to 512): <fill_docstring>
            q_lora_rank (`int | None`, *optional*, defaults to 1536): <fill_docstring>
            qk_rope_head_dim (`int | None`, *optional*, defaults to 64): <fill_docstring>
            v_head_dim (`int | None`, *optional*, defaults to 128): <fill_docstring>
            qk_nope_head_dim (`int | None`, *optional*, defaults to 128): <fill_docstring>
            hidden_act (`str | None`, *optional*, defaults to `"silu"`): <fill_docstring>
            max_position_embeddings (`int | None`, *optional*, defaults to 131072): <fill_docstring>
            initializer_range (`float | None`, *optional*): <fill_docstring>
            embedding_initializer_range (`float | None`, *optional*): <fill_docstring>
            rms_norm_eps (`int | None`, *optional*, defaults to 1e-06): <fill_docstring>
            use_cache (`bool | None`, *optional*, defaults to `True`): <fill_docstring>
            pad_token_id (`int | None`, *optional*): <fill_docstring>
            bos_token_id (`int | None`, *optional*, defaults to 128000): <fill_docstring>
            eos_token_id (`int | None`, *optional*, defaults to 128001): <fill_docstring>
            tie_word_embeddings (`bool | None`, *optional*, defaults to `True`): <fill_docstring>
            rope_parameters (`transformers.modeling_rope_utils.RopeParameters | dict[str, transformers.modeling_rope_utils.RopeParameters]`, *optional*): <fill_docstring>
            rope_interleave (`bool | None`, *optional*, defaults to `True`): <fill_docstring>
            attention_bias (`bool | None`, *optional*, defaults to `False`): <fill_docstring>
            attention_dropout (`float | None`, *optional*, defaults to 0.0): <fill_docstring>

I noticed many <fill_docstring> placeholders here; is this correct? Meanwhile, even though I use the docstring updated by make fix-repo, it still raises check_repository_consistency errors.

@vasqu vasqu left a comment

I fixed a few things regarding the config and tests; that seems to resolve your issue as well. Lmk if not!

Just a few last nits, but approving since it's nothing major. Quickly checking with our slow CI in a second (might need to adjust values because of GPU differences).

logger = logging.get_logger(__name__)


class YoutuConfig(DeepseekV3Config):
Contributor

Sorry, that was confusing on my side; I meant to say to add this to the modular file then. You can see that it will unfold the inherited attributes in the config file (which also solves the consistency issues), but better double-check that I haven't missed something.
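
A sketch of the modular setup being described (the file path follows the transformers convention; keeping only model_type in the body is an assumption for illustration):

# src/transformers/models/youtu/modular_youtu.py (sketch)
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config


class YoutuConfig(DeepseekV3Config):
    model_type = "youtu"

# The modular converter (utils/modular_model_converter.py) then regenerates
# configuration_youtu.py with the inherited attributes written out explicitly,
# which is what resolves the repo-consistency check.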

Contributor Author

ok, got it

Comment on lines +202 to +203
def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs):
    raise AttributeError("Not overwritten for the Youtu model!")
Contributor

FYI, this is how you can disable inheriting a function from the parent class.

Comment on lines +45 to +85
class YoutuModelTester(CausalLMModelTester):
    if is_torch_available():
        base_model_class = YoutuModel

    def __init__(
        self,
        parent,
        kv_lora_rank=16,
        q_lora_rank=32,
        qk_rope_head_dim=32,
        qk_nope_head_dim=32,
        v_head_dim=32,
    ):
        super().__init__(parent=parent)
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim


@require_torch
class YoutuModelTest(CausalLMModelTest, unittest.TestCase):
    model_tester_class = YoutuModelTester

    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
        """Needs to be overridden as youtu-llm has special MLA cache format (though we don't really use the MLA)"""
        self.assertIsInstance(past_key_values, Cache)

        # (batch, head, seq_length, head_features)
        expected_common_shape = (
            batch_size,
            getattr(config, "num_key_value_heads", config.num_attention_heads),
            seq_length,
        )
        expected_key_shape = expected_common_shape + (config.qk_nope_head_dim + config.qk_rope_head_dim,)
        expected_value_shape = expected_common_shape + (config.v_head_dim,)

        for layer in past_key_values.layers:
            self.assertEqual(layer.keys.shape, expected_key_shape)
            self.assertEqual(layer.values.shape, expected_value_shape)
Contributor

Refactored this to use our causal LM class; it makes it easier for us to refactor tests in the future.

Comment on lines +90 to +91
def tearDown(self):
    cleanup(torch_device, gc_collect=False)
Contributor

Let's also clean up on setup, e.g.:

def setup(self):
    cleanup(torch_device, gc_collect=True)

def tearDown(self):
    # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
    # some memory allocated in the cache, which means some object is not being released properly. This causes some
    # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
    # Investigate the root cause.
    cleanup(torch_device, gc_collect=True)

(no need to copy the comment)


vasqu commented Jan 27, 2026

It does seem we need to fix at least: FAILED tests/models/youtu/test_modeling_youtu.py::YoutuModelTest::test_config - ZeroDivisionError: float division by zero


vasqu commented Jan 27, 2026

run-slow: youtu

@github-actions

This comment contains run-slow, running the specified jobs:

models: ["models/youtu"]
quantizations: []

@github-actions

CI Results

Workflow Run ⚙️

Model CI Report

❌ Failed tests

  • youtu:
    tests/models/youtu/test_modeling_youtu.py::YoutuModelTest::test_config
    tests/models/youtu/test_modeling_youtu.py::YoutuModelTest::test_sdpa_can_dispatch_on_flash


LuJunru commented Jan 27, 2026

It does seem we need to fix at least: FAILED tests/models/youtu/test_modeling_youtu.py::YoutuModelTest::test_config - ZeroDivisionError: float division by zero

Alright, I guess this was the error that led me to add the hidden_size = 0 conditions in configuration_youtu.py previously.


vasqu commented Jan 27, 2026

test_sdpa_can_dispatch_on_flash seems to fail because of some hidden-dim incompatibility; you can either skip it with the proper reason or try to adjust the values in the tester.


vasqu commented Jan 28, 2026

run-slow: youtu

@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, youtu

@github-actions

This comment contains run-slow, running the specified jobs:

models: ["models/youtu"]
quantizations: []

@github-actions

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@vasqu vasqu merged commit be87564 into huggingface:main Jan 28, 2026
21 of 24 checks passed

vasqu commented Jan 28, 2026

@LuJunru I just updated some last few nits, everything passes (the other tests are unrelated) so I merged

Thanks for iterating and gz on the model addition 🤗


LuJunru commented Jan 28, 2026

@LuJunru I just updated some last few nits, everything passes (the other tests are unrelated) so I merged

Thanks for iterating and gz on the model addition 🤗

@molbap @vasqu @xenova 🤗 Thank you for the professional suggestions! I'm going to update the usage in our official repos.
