nemo_deploy/service/fastapi_interface_to_pytriton.py (25 changes: 9 additions & 16 deletions)
@@ -9,6 +9,7 @@
 # limitations under the License.
 
 import json
+import logging
 import os
 
 import numpy as np
@@ -19,12 +20,7 @@
 
 from nemo_deploy.llm import NemoQueryLLMPyTorch
 
-try:
-    from nemo.utils import logging
-except (ImportError, ModuleNotFoundError):
-    import logging
-
-logging = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
 
 
 class TritonSettings(BaseSettings):
@@ -39,10 +35,7 @@ def __init__(self):
             self._triton_service_port = int(os.environ.get("TRITON_PORT", 8000))
             self._triton_service_ip = os.environ.get("TRITON_HTTP_ADDRESS", "0.0.0.0")
         except Exception as error:
-            logging.error(
-                "An exception occurred trying to retrieve set args in TritonSettings class. Error:",
-                error,
-            )
+            logger.error(f"An exception occurred trying to retrieve set args in TritonSettings class. Error: {error}")
             return
 
     @property
@@ -81,7 +74,7 @@ class BaseRequest(BaseModel):
     def set_greedy_params(self):
         """Validate parameters for greedy decoding."""
         if self.temperature == 0 and self.top_p == 0:
-            logging.warning("Both temperature and top_p are 0. Setting top_k to 1 to ensure greedy sampling.")
+            logger.warning("Both temperature and top_p are 0. Setting top_k to 1 to ensure greedy sampling.")
             self.top_k = 1
         return self
 
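For context, `set_greedy_params` is a Pydantic model validator on the request model. A minimal self-contained sketch of the same pattern, assuming pydantic v2's `@model_validator(mode="after")` decorator and illustrative default field values (both sit above the excerpted lines, so they are assumptions):

from pydantic import BaseModel, model_validator


class BaseRequest(BaseModel):
    # Default values are illustrative assumptions, not taken from the diff.
    temperature: float = 1.0
    top_p: float = 0.0
    top_k: int = 0

    @model_validator(mode="after")
    def set_greedy_params(self):
        """Validate parameters for greedy decoding."""
        if self.temperature == 0 and self.top_p == 0:
            # With both samplers disabled, force top_k=1 so decoding is greedy.
            self.top_k = 1
        return self


assert BaseRequest(temperature=0, top_p=0).top_k == 1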
@@ -134,7 +127,7 @@ async def check_triton_health():
     triton_url = (
         f"http://{triton_settings.triton_service_ip}:{str(triton_settings.triton_service_port)}/v2/health/ready"
     )
-    logging.info(f"Attempting to connect to Triton server at: {triton_url}")
+    logger.info(f"Attempting to connect to Triton server at: {triton_url}")
     try:
         response = requests.get(triton_url, timeout=5)
         if response.status_code == 200:
@@ -233,7 +226,7 @@ async def query_llm_async(
 async def completions_v1(request: CompletionRequest):
     """Defines the completions endpoint and queries the model deployed on PyTriton server."""
     url = f"http://{triton_settings.triton_service_ip}:{triton_settings.triton_service_port}"
-    logging.info(f"Request: {request}")
+    logger.info(f"Request: {request}")
     prompts = request.prompt
     if not isinstance(request.prompt, list):
         prompts = [request.prompt]
@@ -266,7 +259,7 @@ async def completions_v1(request: CompletionRequest):
                 output_serializable["choices"][0]["logprobs"]["token_logprobs"].insert(0, None)
             else:
                 output_serializable["choices"][0]["logprobs"] = None
-    logging.info(f"Output: {output_serializable}")
+    logger.info(f"Output: {output_serializable}")
     return output_serializable
 
 
@@ -279,7 +272,7 @@ def dict_to_str(messages):
 async def chat_completions_v1(request: ChatCompletionRequest):
     """Defines the chat completions endpoint and queries the model deployed on PyTriton server."""
     url = f"http://{triton_settings.triton_service_ip}:{triton_settings.triton_service_port}"
-    logging.info(f"Request: {request}")
+    logger.info(f"Request: {request}")
     prompts = request.messages
     if not isinstance(request.messages, list):
         prompts = [request.messages]
@@ -315,5 +308,5 @@ async def chat_completions_v1(request: ChatCompletionRequest):
                 0
             ][0]
 
-    logging.info(f"Output: {output_serializable}")
+    logger.info(f"Output: {output_serializable}")
     return output_serializable
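Taken together, the hunks above standardize this module on a stdlib logger. Two quiet fixes ride along: the old code rebound the name `logging` to a logger instance, shadowing the module it had just imported, and the old `logging.error("...Error:", error)` call passed `error` as a %-style formatting argument that the message string never referenced, so the handler would hit a string-formatting error instead of logging the exception. A minimal sketch of the convention the diff adopts (the `basicConfig` harness is an assumption for demonstration, not part of the service):

import logging

logger = logging.getLogger(__name__)  # one module-level logger, named after the module


def read_port() -> int:
    """Stand-in for the TritonSettings environment lookup."""
    try:
        return int("not-a-port")  # deliberately fails for the demo
    except Exception as error:
        # f-string interpolation, matching the style the diff adopts
        logger.error(f"An exception occurred trying to retrieve set args. Error: {error}")
        return 8000


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    logger.info(f"Triton port: {read_port()}")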
nemo_export/multimodal/build.py (35 changes: 10 additions & 25 deletions)
@@ -17,7 +17,6 @@
 import shutil
 import tarfile
 import tempfile
-from pathlib import Path
 from time import time
 from types import SimpleNamespace
 from typing import List
@@ -26,11 +25,8 @@
 import yaml
 from packaging import version
 
-from nemo_export.tensorrt_llm import TensorRTLLM
-from nemo_export.trt_llm.nemo_ckpt_loader.nemo_file import load_nemo_model
 from nemo_export_deploy_common.import_utils import (
     MISSING_NEMO_MSG,
-    MISSING_TENSORRT_LLM_MSG,
     MISSING_TENSORRT_MSG,
     MISSING_TRANSFORMERS_MSG,
     UnavailableError,
@@ -108,24 +104,12 @@ def build_trtllm_engine(
     max_lora_rank: int = 64,
     lora_ckpt_list: List[str] = None,
 ):
-    """Build TRTLLM engine by nemo export."""
-    if not HAVE_TRT_LLM:
-        raise UnavailableError(MISSING_TENSORRT_LLM_MSG)
-
-    trt_llm_exporter = TensorRTLLM(model_dir=model_dir, lora_ckpt_list=lora_ckpt_list, load_model=False)
-    trt_llm_exporter.export(
-        nemo_checkpoint_path=visual_checkpoint_path if llm_checkpoint_path is None else llm_checkpoint_path,
-        model_type=llm_model_type,
-        tensor_parallelism_size=tensor_parallelism_size,
-        max_input_len=max_input_len,
-        max_output_len=max_output_len,
-        max_seq_len=max_input_len + max_output_len,
-        max_batch_size=max_batch_size,
-        dtype=dtype,
-        load_model=False,
-        use_lora_plugin=use_lora_plugin,
-        lora_target_modules=lora_target_modules,
-        max_lora_rank=max_lora_rank,
+    """Build TRTLLM engine by nemo export.
+
+    Note: TensorRT-LLM export support has been removed.
+    """
+    raise NotImplementedError(
+        "TensorRT-LLM export support has been removed from this codebase. This function is no longer available."
     )
 
 
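`build_trtllm_engine` keeps its signature but now raises unconditionally, so any remaining call sites should treat it as unavailable. A minimal sketch of a defensive caller (the parameter names come from the removed body; the argument values are illustrative assumptions):

from nemo_export.multimodal.build import build_trtllm_engine

try:
    build_trtllm_engine(
        model_dir="/tmp/multimodal_engine",   # illustrative path
        visual_checkpoint_path="model.nemo",  # illustrative checkpoint
    )
except NotImplementedError as err:
    # TensorRT-LLM export was removed from the codebase; fail gracefully.
    print(f"Engine build unavailable: {err}")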
@@ -350,9 +334,10 @@ def build_neva_engine(
         mp0_weights = torch.load(weights_path, map_location=device)
     else:
         # extract NeMo checkpoint
-        with tempfile.TemporaryDirectory() as temp:
-            temp_path = Path(temp)
-            mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp_path)
+        raise NotImplementedError(
+            "Loading NeMo checkpoints via trt_llm utilities has been removed. "
+            "Please extract the checkpoint manually or use an earlier version."
+        )
 
     vision_config = nemo_config["mm_cfg"]["vision_encoder"]
 
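The new error message points users at extracting the checkpoint manually. For reference, a minimal sketch of what that could look like, assuming the .nemo file is a plain tar archive whose members include a model_config.yaml and a model_weights.ckpt (those member names are illustrative assumptions, not taken from this diff):

import tarfile
import tempfile

import torch
import yaml


def extract_nemo_checkpoint(nemo_path: str):
    """Unpack a .nemo tar archive and load its config and weights."""
    with tempfile.TemporaryDirectory() as temp:
        with tarfile.open(nemo_path) as archive:
            archive.extractall(path=temp)  # only extract archives you trust
        with open(f"{temp}/model_config.yaml") as config_file:
            nemo_config = yaml.safe_load(config_file)
        mp0_weights = torch.load(f"{temp}/model_weights.ckpt", map_location="cpu")
    return mp0_weights, nemo_config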