
RuntimeError: mat1 and mat2 shapes cannot be multiplied (69x64 and 32x4096) in inference_server.py #33

@ahmetkca

Traceback (most recent call last):
  File "/root/hertz-dev/inference_server.py", line 168, in <module>
    audio_processor = AudioProcessor(model=model, prompt_path=args.prompt_path)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/inference_server.py", line 59, in __init__
    self.initialize_state(prompt_path)
  File "/root/hertz-dev/inference_server.py", line 80, in initialize_state
    self.next_model_audio = self.model.next_audio_from_audio(self.loaded_audio.unsqueeze(0), temps=TEMPS)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/model.py", line 323, in next_audio_from_audio
    next_latents = self.next_latent(latents_in, temps)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/model.py", line 341, in next_latent
    logits = self.forward(model_input)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/model.py", line 310, in forward
    x = self.input(data)
        ^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (69x64 and 32x4096)
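Reading the message: `F.linear` computes `input @ weight.T`, so `mat1` is the input with its leading dimensions folded into 69 rows of 64 features, and `mat2` is the transposed weight of the `nn.Linear` behind `self.input` in `model.py`, i.e. a layer with `in_features=32` and `out_features=4096`. The latents reaching `self.input` therefore carry 64 features where the layer expects 32. This pure-Python sketch reproduces only that shape rule (no torch needed). The helper name `linear_output_shape` and the `(1, 69, 64)` input shape are illustrative assumptions; the traceback only shows the folded 69x64.

```python
from math import prod


def linear_output_shape(input_shape, in_features, out_features):
    """Return the output shape of y = x @ W.T + b for a Linear(in_features, out_features).

    Sketch of the shape rule behind F.linear: all leading dims of the input are
    folded into rows ("mat1"), and the weight, stored as (out_features, in_features),
    is used transposed ("mat2" = in_features x out_features). The last input dim
    must equal in_features, otherwise the matrix product is undefined.
    """
    *leading, features = input_shape
    rows = prod(leading)  # prod(()) == 1, so a 1-D input counts as a single row
    if features != in_features:
        # Same wording as the PyTorch error, raised here as a ValueError.
        raise ValueError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{features} and {in_features}x{out_features})"
        )
    return (*leading, out_features)


# Hypothetical shapes: a batch of 1 with 69 rows of 64-dim latents, fed to the
# 32 -> 4096 input projection implied by the error message.
try:
    linear_output_shape((1, 69, 64), in_features=32, out_features=4096)
except ValueError as err:
    print(err)  # prints: mat1 and mat2 shapes cannot be multiplied (69x64 and 32x4096)
```

With 32-dim latents, the same call returns `(1, 69, 4096)`, so the fix is on the input side: whatever produces `model_input` in `next_latent` has to emit 32 features per row for this model.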
(ai) root@87a45c3f6c90:~/hertz-dev# nvidia-smi
Wed Dec  4 01:55:13 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0 Off |                  Off |
|  0%   26C    P8             21W /  450W |       2MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
(ai) root@87a45c3f6c90:~/hertz-dev# pip freeze
annotated-types==0.7.0
anyio==4.6.2.post1
asttokens==3.0.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.4.0
click==8.1.7
contourpy==1.3.1
cycler==0.12.1
decorator==5.1.1
einops==0.8.0
executing==2.1.0
fastapi==0.115.4
filelock==3.16.1
fonttools==4.55.1
fsspec==2024.10.0
h11==0.14.0
hf_transfer==0.1.8
huggingface-hub==0.26.2
idna==3.10
IProgress==0.4
ipython==8.18.1
jedi==0.19.2
Jinja2==3.1.4
kiwisolver==1.4.7
MarkupSafe==3.0.2
matplotlib==3.9.2
matplotlib-inline==0.1.7
mpmath==1.3.0
networkx==3.4.2
numpy==1.26.3
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
packaging==24.2
parso==0.8.4
pexpect==4.9.0
pillow==11.0.0
prompt_toolkit==3.0.48
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.3
pydantic_core==2.27.1
Pygments==2.18.0
pyparsing==3.2.0
python-dateutil==2.9.0.post0
PyYAML==6.0.2
requests==2.32.3
six==1.16.0
sniffio==1.3.1
sounddevice==0.5.1
soundfile==0.12.1
stack-data==0.6.3
starlette==0.41.3
sympy==1.13.1
torch==2.5.1
torchaudio==2.5.1
tqdm==4.66.6
traitlets==5.14.3
triton==3.1.0
typing_extensions==4.12.2
urllib3==2.2.3
uvicorn==0.32.0
wcwidth==0.2.13
websockets==13.1
(ai) root@87a45c3f6c90:~/hertz-dev# conda info

     active environment : ai
    active env location : /root/miniconda3/envs/ai
            shell level : 2
       user config file : /root/.condarc
 populated config files : /root/miniconda3/.condarc
          conda version : 24.9.2
    conda-build version : not installed
         python version : 3.12.7.final.0
                 solver : libmamba (default)
       virtual packages : __archspec=1=zen3
                          __conda=24.9.2=0
                          __cuda=12.4=0
                          __glibc=2.35=0
                          __linux=6.5.0=0
                          __unix=0=0
       base environment : /root/miniconda3  (writable)
      conda av data dir : /root/miniconda3/etc/conda
  conda av metadata url : None
           channel URLs : https://repo.anaconda.com/pkgs/main/linux-64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/linux-64
                          https://repo.anaconda.com/pkgs/r/noarch
          package cache : /root/miniconda3/pkgs
                          /root/.conda/pkgs
       envs directories : /root/miniconda3/envs
                          /root/.conda/envs
               platform : linux-64
             user-agent : conda/24.9.2 requests/2.32.3 CPython/3.12.7 Linux/6.5.0-45-generic ubuntu/22.04.5 glibc/2.35 solver/libmamba conda-libmamba-solver/24.9.0 libmambapy/1.5.8 aau/0.4.4 c/. s/. e/.
                UID:GID : 0:0
             netrc file : None
           offline mode : False
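One cause worth ruling out, which the traceback neither confirms nor excludes: 64 is exactly 2 x 32, so the prompt loaded from `args.prompt_path` may have a different number of audio channels than the loaded model is configured for. A minimal stdlib-only check, assuming the prompt is a WAV file; `wav_channel_count` and `write_silent_wav` are hypothetical helpers, and the demo writes its own temporary file rather than touching the real prompt.

```python
import os
import tempfile
import wave


def wav_channel_count(path):
    """Return the number of audio channels stored in a WAV file's header."""
    with wave.open(path, "rb") as wav:
        return wav.getnchannels()


def write_silent_wav(path, channels, sample_rate=16000, frames=1600):
    """Write a short silent 16-bit PCM WAV; used here only to demo the check."""
    with wave.open(path, "wb") as wav:
        wav.setnchannels(channels)
        wav.setsampwidth(2)  # 2 bytes = 16-bit samples
        wav.setframerate(sample_rate)
        wav.writeframes(b"\x00\x00" * channels * frames)


# Demo on a generated stereo file; point wav_channel_count at the real prompt instead.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "stereo_demo.wav")
    write_silent_wav(demo_path, channels=2)
    print(wav_channel_count(demo_path))  # prints: 2
```

Which channel count `inference_server.py` expects for the loaded model is not stated in this report, so the result only helps if it can be compared against the model configuration actually in use.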

