Conversation

@shpgy-shpgy
Contributor

@shpgy-shpgy shpgy-shpgy commented Jan 8, 2026

Overview:

This fix changes the initialization logic so that a single processor is created once and reused across requests, which eliminates the per-request initialization cost.
It also enables the router to handle plain-text inputs in multimodal scenarios.

Details:

Created a new default_multimodal_input_loader function in Dynamo that mirrors the necessary steps originally handled by the external trtllm module. Its signature accepts a pre-initialized processor object as a parameter, so the caller can instantiate the processor once and pass the same instance for all subsequent requests.
Plain-text inputs are handled the same way as in tensorrt_llm.
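
A minimal sketch of the reuse pattern described above, not the PR's actual code. `ExpensiveProcessor`, `load_inputs`, and `RequestHandler` are hypothetical stand-ins; the real change passes a pre-initialized processor into default_multimodal_input_loader.

```python
class ExpensiveProcessor:
    """Hypothetical stand-in for a processor that is costly to construct."""

    instances_created = 0  # counts constructions so reuse can be observed

    def __init__(self, model_name: str) -> None:
        ExpensiveProcessor.instances_created += 1
        self.model_name = model_name

    def apply(self, text: str) -> str:
        return f"[{self.model_name}] {text}"


def load_inputs(prompt: str, processor: ExpensiveProcessor) -> dict:
    """Hypothetical loader: receives the processor instead of building one."""
    return {"prompt": processor.apply(prompt)}


class RequestHandler:
    """Builds the processor once at startup and reuses it for every request."""

    def __init__(self, model_name: str) -> None:
        self.processor = ExpensiveProcessor(model_name)

    def handle(self, prompt: str) -> dict:
        return load_inputs(prompt, self.processor)


handler = RequestHandler("demo-model")
results = [handler.handle(p) for p in ("hello", "world", "again")]
```

Three requests are served here, yet the constructor runs only once, because the handler owns the instance and the loader only borrows it.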

Where should the reviewer start?

components/src/dynamo/trtllm/multimodal_processor.py
lib/llm/src/preprocessor.rs

Related Issues: (use one of the action keywords Closes / Fixes / Resolves / Relates to)

Summary by CodeRabbit

Release Notes

  • New Features
    • Added comprehensive multimodal input support, enabling processing of images, videos, audio, and mixed content types in conversations.
    • Enhanced chat message handling with improved placeholder management for media content.
    • Refined request processing pipeline for seamless multimodal content assembly and streaming response generation.

@shpgy-shpgy shpgy-shpgy requested review from a team as code owners January 8, 2026 11:54
@copy-pr-bot

copy-pr-bot bot commented Jan 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions

github-actions bot commented Jan 8, 2026

👋 Hi shpgy-shpgy! Thank you for contributing to ai-dynamo/dynamo.

Just a reminder: The NVIDIA Test GitHub Validation CI runs an essential subset of the testing framework to quickly catch errors. Your PR reviewers may elect to test the changes comprehensively before approving your changes.

🚀

@github-actions github-actions bot added the external-contribution (Pull request is from an external contributor) and fix labels Jan 8, 2026
@coderabbitai
Contributor

coderabbitai bot commented Jan 8, 2026

Walkthrough

The changes introduce comprehensive multimodal processing support by adding three new functions (default_multimodal_input_loader, get_multimodal_inputs, parse_chat_messages_coroutines), extending MultimodalRequestProcessor with conditional AutoProcessor initialization, and refactoring request handling to branch between pure-text and multimodal pathways. A separate fix ensures extra_args is unconditionally prepared in the Rust preprocessor.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Multimodal Processing Refactor<br>components/src/dynamo/trtllm/multimodal_processor.py | Introduces three new public functions for multimodal input handling: default_multimodal_input_loader (converts prompts and media into conversation messages), get_multimodal_inputs (assembles payload from chat messages), and parse_chat_messages_coroutines (parses messages with media placeholders). Adds TokenizerProtocol typing. Updates MultimodalRequestProcessor to initialize optional self.processor attribute and refactors process_openai_request to delegate to new multimodal pathways. Enhances error handling and logging for tensor loading and file access. |
| Control Flow Fix<br>lib/llm/src/preprocessor.rs | Moves logic for preserving original messages into extra_args from conditional block (only when multimodal data exists) to unconditional block (always executes), ensuring extra_args is consistently prepared regardless of media presence. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 wiggles nose excitedly
Images, videos, audio align at last!
New processors weave each modality fast,
Through payloads and streams they dance and play,
Rust and Python coordinate the way—
Multimodal magic, the rabbit's display! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 44.44% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main changes: processor reuse optimization and support for plain text input in multimodal mode. |
| Description check | ✅ Passed | The description covers all required template sections with sufficient detail about what was changed, why, and which files to review. |


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
components/src/dynamo/trtllm/multimodal_processor.py (1)

1-490: Fix pipeline failures: formatting issues block CI.

The pre-commit hooks for isort and black have failed, indicating that the code does not meet the project's formatting standards. These must be resolved before the PR can be merged.

Run the following commands to fix the formatting:

#!/bin/bash
# Fix import sorting and code formatting
pre-commit run isort --files components/src/dynamo/trtllm/multimodal_processor.py
pre-commit run black --files components/src/dynamo/trtllm/multimodal_processor.py
🧹 Nitpick comments (3)
components/src/dynamo/trtllm/multimodal_processor.py (3)

343-359: Add logging for exception handling in image_audio modality detection.

The silent exception handling in the image_audio modality detection logic makes debugging difficult when media loading fails. Consider logging the exceptions before passing.

♻️ Proposed fix to add logging
                 if _modal is None:
                     try:
                         data = load_image(m,
                                           format=image_data_format,
                                           device=device)
                         _modal = "image"
-                    except Exception:
+                    except Exception as e:
+                        logging.debug(f"Failed to load as image: {e}")
                         pass
                 if _modal is None:
                     try:
                         data = load_audio(m, device=device)
                         _modal = "audio"
-                    except Exception:
+                    except Exception as e:
+                        logging.debug(f"Failed to load as audio: {e}")
                         pass
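
The same idea can be sketched in a self-contained form, under stated assumptions: `detect_modality` and the two `fake_load_*` helpers are hypothetical and stand in for the load_image / load_audio calls in the diff. Each loader is tried in turn, and a failure is logged at debug level rather than discarded.

```python
import logging
from typing import Any, Callable, Optional, Sequence, Tuple

logger = logging.getLogger(__name__)

Loader = Callable[[str], Any]


def detect_modality(
    item: str, loaders: Sequence[Tuple[str, Loader]]
) -> Tuple[Optional[str], Any]:
    """Try each loader in order; log failures instead of silently passing."""
    for name, loader in loaders:
        try:
            return name, loader(item)
        except Exception as exc:  # broad on purpose: loaders may raise anything
            logger.debug("Failed to load %r as %s: %s", item, name, exc)
    return None, None


def fake_load_image(path: str) -> bytes:
    """Hypothetical image loader used only for this illustration."""
    if not path.endswith(".png"):
        raise ValueError("not an image")
    return b"image-bytes"


def fake_load_audio(path: str) -> bytes:
    """Hypothetical audio loader used only for this illustration."""
    if not path.endswith(".wav"):
        raise ValueError("not audio")
    return b"audio-bytes"


LOADERS = [("image", fake_load_image), ("audio", fake_load_audio)]
modality, data = detect_modality("clip.wav", LOADERS)
```

Passing lazy `%s` arguments to `logger.debug` avoids building the message when debug logging is disabled, which matters on a per-request path.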

396-397: Address static analysis warnings for loop variable and zip.

The loop variable prompt_idx is unused, and zip() should include an explicit strict= parameter for better error detection.

♻️ Proposed fix
-    for prompt_idx, (prompt,
-                     media) in enumerate(zip(prompts, media_or_embeddings)):
+    for prompt, media in zip(prompts, media_or_embeddings, strict=True):

271-436: Consider refactoring this complex function.

The default_multimodal_input_loader function is 165 lines long and handles multiple modalities with significant complexity. Consider extracting the modality-specific conversion logic (lines 286-371) into separate helper functions for better maintainability.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 92748c9 and bd21b96.

📒 Files selected for processing (2)
  • components/src/dynamo/trtllm/multimodal_processor.py
  • lib/llm/src/preprocessor.rs
🧰 Additional context used
🧠 Learnings (5)
📓 Common learnings
Learnt from: tanmayv25
Repo: ai-dynamo/dynamo PR: 5143
File: examples/multimodal/components/processor.py:207-215
Timestamp: 2026-01-03T01:27:16.084Z
Learning: In examples/multimodal/components/processor.py, the Dynamo multimodal processor intentionally supports flexible content ordering in multimodal requests, allowing audio, image, or video content to appear before text content. This design aligns with the OpenAI multimodal API specification which permits content items in any order, treating the text-first recommendation as advisory rather than mandatory.
Learnt from: KrishnanPrash
Repo: ai-dynamo/dynamo PR: 3067
File: lib/llm/src/preprocessor/prompt/template/oai.rs:87-134
Timestamp: 2025-09-16T19:47:30.312Z
Learning: In Dynamo, multimodal requests (containing image_url or other non-text content) are processed through a completely different workflow than text-only requests, so the may_be_fix_msg_content function in lib/llm/src/preprocessor/prompt/template/oai.rs will only encounter text-only content arrays.
Learnt from: tanmayv25
Repo: ai-dynamo/dynamo PR: 5143
File: examples/multimodal/components/processor.py:207-214
Timestamp: 2026-01-03T01:29:50.237Z
Learning: In examples/multimodal/components/processor.py, the processor example is designed to support only a single text field across all messages in a request. Supporting multiple text fields is out of scope for the current implementation.
📚 Learning: 2026-01-03T01:27:16.084Z
Learnt from: tanmayv25
Repo: ai-dynamo/dynamo PR: 5143
File: examples/multimodal/components/processor.py:207-215
Timestamp: 2026-01-03T01:27:16.084Z
Learning: In examples/multimodal/components/processor.py, the Dynamo multimodal processor intentionally supports flexible content ordering in multimodal requests, allowing audio, image, or video content to appear before text content. This design aligns with the OpenAI multimodal API specification which permits content items in any order, treating the text-first recommendation as advisory rather than mandatory.

Applied to files:

  • components/src/dynamo/trtllm/multimodal_processor.py
📚 Learning: 2026-01-03T01:29:50.237Z
Learnt from: tanmayv25
Repo: ai-dynamo/dynamo PR: 5143
File: examples/multimodal/components/processor.py:207-214
Timestamp: 2026-01-03T01:29:50.237Z
Learning: In examples/multimodal/components/processor.py, the processor example is designed to support only a single text field across all messages in a request. Supporting multiple text fields is out of scope for the current implementation.

Applied to files:

  • components/src/dynamo/trtllm/multimodal_processor.py
📚 Learning: 2025-10-28T04:09:48.264Z
Learnt from: ayushag-nv
Repo: ai-dynamo/dynamo PR: 3634
File: components/src/dynamo/vllm/multimodal_handlers/processor_handler.py:66-72
Timestamp: 2025-10-28T04:09:48.264Z
Learning: In components/src/dynamo/vllm/multimodal_handlers/processor_handler.py, the AutoTokenizer.from_pretrained call with trust_remote_code=True is intentional and expected for the vLLM multimodal handler implementation.

Applied to files:

  • components/src/dynamo/trtllm/multimodal_processor.py
📚 Learning: 2025-09-16T19:47:30.312Z
Learnt from: KrishnanPrash
Repo: ai-dynamo/dynamo PR: 3067
File: lib/llm/src/preprocessor/prompt/template/oai.rs:87-134
Timestamp: 2025-09-16T19:47:30.312Z
Learning: In Dynamo, multimodal requests (containing image_url or other non-text content) are processed through a completely different workflow than text-only requests, so the may_be_fix_msg_content function in lib/llm/src/preprocessor/prompt/template/oai.rs will only encounter text-only content arrays.

Applied to files:

  • components/src/dynamo/trtllm/multimodal_processor.py
  • lib/llm/src/preprocessor.rs
🧬 Code graph analysis (2)
components/src/dynamo/trtllm/multimodal_processor.py (1)
examples/multimodal/utils/audio_loader.py (1)
  • load_audio (61-80)
lib/llm/src/preprocessor.rs (2)
lib/llm/src/preprocessor/prompt/template/oai.rs (2)
  • messages (211-214)
  • messages (268-279)
lib/llm/src/preprocessor/prompt.rs (1)
  • messages (54-54)
🪛 GitHub Actions: Pre Merge Validation of (ai-dynamo/dynamo/refs/pull/5271/merge) by shpgy-shpgy.
components/src/dynamo/trtllm/multimodal_processor.py

[error] pre-commit: isort hook failed; files were modified by this hook. The hook will fail CI until changes are committed.


[error] pre-commit: black hook failed; files were reformatted by this hook. The hook will fail CI until changes are committed.


[error] pre-commit hooks made changes. Re-run pre-commit with 'pre-commit run --all-files'.

🪛 Ruff (0.14.10)
components/src/dynamo/trtllm/multimodal_processor.py

314-316: Avoid specifying long messages outside the exception class

(TRY003)


326-328: Avoid specifying long messages outside the exception class

(TRY003)


335-337: Avoid specifying long messages outside the exception class

(TRY003)


349-350: try-except-pass detected, consider logging the exception

(S110)


349-349: Do not catch blind exception: Exception

(BLE001)


355-356: try-except-pass detected, consider logging the exception

(S110)


355-355: Do not catch blind exception: Exception

(BLE001)


358-358: Avoid specifying long messages outside the exception class

(TRY003)


370-370: Avoid specifying long messages outside the exception class

(TRY003)


396-396: Loop control variable prompt_idx not used within loop body

Rename unused prompt_idx to _prompt_idx

(B007)


397-397: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: Build and Test - dynamo
  • GitHub Check: clippy (lib/runtime/examples)
  • GitHub Check: tests (launch/dynamo-run)
  • GitHub Check: tests (.)
  • GitHub Check: clippy (launch/dynamo-run)
  • GitHub Check: clippy (.)
  • GitHub Check: tests (lib/runtime/examples)
  • GitHub Check: clippy (lib/bindings/python)
  • GitHub Check: tests (lib/bindings/python)
🔇 Additional comments (6)
lib/llm/src/preprocessor.rs (1)

373-379: LGTM! Consistent handling of original messages.

Moving the extra_args preparation outside the multimodal-only conditional ensures that downstream processors (like TRT-LLM multimodal) can always access the original messages, regardless of whether multimodal data was detected. This aligns well with the PR objective to support plain text inputs in multimodal scenarios.

components/src/dynamo/trtllm/multimodal_processor.py (5)

20-33: LGTM! Necessary imports for multimodal support.

The new imports from transformers and tensorrt_llm are required for the multimodal processing functionality being added.


40-50: LGTM! Good use of Protocol for type safety.

The TokenizerProtocol provides clean type hinting for tokenizers, resolving mypy errors related to the decode method.
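
For readers unfamiliar with the pattern, a minimal sketch of a structural tokenizer type follows. The actual definition of TokenizerProtocol is not shown in this conversation, so `DecodingTokenizer`, `ToyTokenizer`, and `detokenize` are hypothetical names, and the single `decode` method is an assumption based on the comment above.

```python
from typing import List, Protocol, runtime_checkable


@runtime_checkable
class DecodingTokenizer(Protocol):
    """Hypothetical protocol: anything with a compatible decode() qualifies."""

    def decode(self, token_ids: List[int]) -> str:
        ...


class ToyTokenizer:
    """Does not inherit from the protocol; it matches structurally."""

    _vocab = {0: "hello", 1: "world"}

    def decode(self, token_ids: List[int]) -> str:
        return " ".join(self._vocab[i] for i in token_ids)


def detokenize(tokenizer: DecodingTokenizer, token_ids: List[int]) -> str:
    # A type checker accepts any object whose decode() matches the protocol.
    return tokenizer.decode(token_ids)


text = detokenize(ToyTokenizer(), [0, 1])
```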


76-80: LGTM! Processor initialization supports request reuse.

The processor is now initialized once at startup and reused across requests, eliminating repeated initialization overhead as intended by the PR objectives. The use of trust_remote_code=True is consistent with established patterns in the Dynamo multimodal implementation.

Based on learnings, trust_remote_code=True is intentional for multimodal handlers.


191-228: LGTM! Refactored to support both pure-text and multimodal inputs.

The branching logic correctly handles pure-text requests (using get_multimodal_inputs) and multimodal requests (using default_multimodal_input_loader), while reusing the pre-initialized processor and tokenizer to eliminate per-request overhead.


439-462: LGTM! Clean async implementation for multimodal input assembly.

The function properly orchestrates the parsing of chat messages, application of chat templates, and asynchronous retrieval of multimodal data.
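
A rough sketch of this kind of orchestration, assuming a simplified message shape. `fetch_media`, `parse_messages`, `build_inputs`, the `<image>` placeholder, and the output keys are all hypothetical; the real get_multimodal_inputs and parse_chat_messages_coroutines are not reproduced here.

```python
import asyncio
from typing import Any, Awaitable, Dict, List, Tuple


async def fetch_media(url: str) -> str:
    """Hypothetical async media fetch; yields to the loop, then returns a marker."""
    await asyncio.sleep(0)
    return f"loaded:{url}"


def parse_messages(messages: List[Dict[str, Any]]) -> Tuple[str, List[Awaitable[str]]]:
    """Collect text, insert a placeholder per media item, and queue the fetches."""
    text_parts: List[str] = []
    pending: List[Awaitable[str]] = []
    for message in messages:
        for part in message["content"]:
            if part["type"] == "text":
                text_parts.append(part["text"])
            else:
                text_parts.append("<image>")
                pending.append(fetch_media(part["url"]))
    return " ".join(text_parts), pending


async def build_inputs(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    prompt, pending = parse_messages(messages)
    media = await asyncio.gather(*pending)  # resolves to [] for text-only input
    inputs: Dict[str, Any] = {"prompt": prompt}
    if media:
        inputs["multi_modal_data"] = {"image": list(media)}
    return inputs


messages = [{"role": "user", "content": [
    {"type": "image_url", "url": "cat.png"},
    {"type": "text", "text": "What is this?"},
]}]
inputs = asyncio.run(build_inputs(messages))
```

A text-only request takes the same path and simply produces no media entry, which is one way a single function can serve both plain-text and multimodal inputs.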

@rmccorm4
Contributor

rmccorm4 commented Jan 8, 2026

Hi @shpgy-shpgy, thanks for raising this - @indrajit96 @KrishnanPrash please review

@rmccorm4
Contributor

rmccorm4 commented Jan 8, 2026

/ok to test 5cf3f29

Signed-off-by: shpgy-shpgy <875664365@qq.com>
Signed-off-by: shpgy-shpgy <875664365@qq.com>
@shpgy-shpgy shpgy-shpgy force-pushed the fix_reinit_tokenizer_and_processor branch from 5cf3f29 to e7a8b75 Compare January 9, 2026 02:30

Labels

external-contribution (Pull request is from an external contributor), fix, size/L
