Description
```
Traceback (most recent call last):
  File "/root/hertz-dev/inference_server.py", line 168, in <module>
    audio_processor = AudioProcessor(model=model, prompt_path=args.prompt_path)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/inference_server.py", line 59, in __init__
    self.initialize_state(prompt_path)
  File "/root/hertz-dev/inference_server.py", line 80, in initialize_state
    self.next_model_audio = self.model.next_audio_from_audio(self.loaded_audio.unsqueeze(0), temps=TEMPS)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/model.py", line 323, in next_audio_from_audio
    next_latents = self.next_latent(latents_in, temps)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/model.py", line 341, in next_latent
    logits = self.forward(model_input)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/hertz-dev/model.py", line 310, in forward
    x = self.input(data)
        ^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/ai/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (69x64 and 32x4096)
```
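The numbers in the final `RuntimeError` follow from how `torch.nn.functional.linear` computes `y = x @ W.T + b`: every leading dimension of the input is folded into the rows of `mat1` (69 here, presumably batch 1 times 69 latent frames, which is an assumption), and `mat2` is the transposed weight of a layer built as `Linear(in_features=32, out_features=4096)`. The product fails because the input's last dimension is 64 while the layer expects 32. The pure-Python sketch below reproduces only that shape rule, without importing torch; the helper names `linear_shape_error` and `linear_output_shape` are hypothetical and are not part of hertz-dev or PyTorch.

```python
from math import prod


def linear_shape_error(input_shape, in_features, out_features):
    """Return the message F.linear would raise for these shapes, or None.

    Shape rule only (no tensors): F.linear computes x @ W.T + b, folding all
    leading dimensions of x into the rows of mat1, while mat2 is W.T with
    shape (in_features, out_features).
    """
    *leading, features = input_shape
    rows = prod(leading)  # e.g. batch * sequence length; prod of [] is 1
    if features != in_features:
        return (
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{features} and {in_features}x{out_features})"
        )
    return None


def linear_output_shape(input_shape, in_features, out_features):
    """Output shape of a Linear(in_features, out_features) layer, or raise."""
    error = linear_shape_error(input_shape, in_features, out_features)
    if error is not None:
        raise RuntimeError(error)
    # Only the last dimension changes; leading dimensions are preserved.
    return (*input_shape[:-1], out_features)


# Hypothetical shapes: 1 x 69 latent frames of width 64 fed into Linear(32, 4096).
print(linear_shape_error((1, 69, 64), 32, 4096))
# -> mat1 and mat2 shapes cannot be multiplied (69x64 and 32x4096)

# A 32-wide input passes through the same layer.
print(linear_output_shape((1, 69, 32), 32, 4096))
# -> (1, 69, 4096)
```

Comparing the last dimension of the tensor reaching `self.input` with that layer's `in_features` is a reasonable first check. A mismatch like this can indicate that the loaded checkpoint and the model configuration disagree about the latent width, but the traceback alone does not show which side is wrong.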
```
(ai) root@87a45c3f6c90:~/hertz-dev# nvidia-smi
Wed Dec  4 01:55:13 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0 Off |                  Off |
|  0%   26C    P8             21W /  450W |       2MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
```
```
(ai) root@87a45c3f6c90:~/hertz-dev# pip freeze
annotated-types==0.7.0
anyio==4.6.2.post1
asttokens==3.0.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.4.0
click==8.1.7
contourpy==1.3.1
cycler==0.12.1
decorator==5.1.1
einops==0.8.0
executing==2.1.0
fastapi==0.115.4
filelock==3.16.1
fonttools==4.55.1
fsspec==2024.10.0
h11==0.14.0
hf_transfer==0.1.8
huggingface-hub==0.26.2
idna==3.10
IProgress==0.4
ipython==8.18.1
jedi==0.19.2
Jinja2==3.1.4
kiwisolver==1.4.7
MarkupSafe==3.0.2
matplotlib==3.9.2
matplotlib-inline==0.1.7
mpmath==1.3.0
networkx==3.4.2
numpy==1.26.3
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
packaging==24.2
parso==0.8.4
pexpect==4.9.0
pillow==11.0.0
prompt_toolkit==3.0.48
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.3
pydantic_core==2.27.1
Pygments==2.18.0
pyparsing==3.2.0
python-dateutil==2.9.0.post0
PyYAML==6.0.2
requests==2.32.3
six==1.16.0
sniffio==1.3.1
sounddevice==0.5.1
soundfile==0.12.1
stack-data==0.6.3
starlette==0.41.3
sympy==1.13.1
torch==2.5.1
torchaudio==2.5.1
tqdm==4.66.6
traitlets==5.14.3
triton==3.1.0
typing_extensions==4.12.2
urllib3==2.2.3
uvicorn==0.32.0
wcwidth==0.2.13
websockets==13.1
```
```
(ai) root@87a45c3f6c90:~/hertz-dev# conda info
     active environment : ai
    active env location : /root/miniconda3/envs/ai
            shell level : 2
       user config file : /root/.condarc
 populated config files : /root/miniconda3/.condarc
          conda version : 24.9.2
    conda-build version : not installed
         python version : 3.12.7.final.0
                 solver : libmamba (default)
       virtual packages : __archspec=1=zen3
                          __conda=24.9.2=0
                          __cuda=12.4=0
                          __glibc=2.35=0
                          __linux=6.5.0=0
                          __unix=0=0
       base environment : /root/miniconda3 (writable)
      conda av data dir : /root/miniconda3/etc/conda
  conda av metadata url : None
           channel URLs : https://repo.anaconda.com/pkgs/main/linux-64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/linux-64
                          https://repo.anaconda.com/pkgs/r/noarch
          package cache : /root/miniconda3/pkgs
                          /root/.conda/pkgs
       envs directories : /root/miniconda3/envs
                          /root/.conda/envs
               platform : linux-64
             user-agent : conda/24.9.2 requests/2.32.3 CPython/3.12.7 Linux/6.5.0-45-generic ubuntu/22.04.5 glibc/2.35 solver/libmamba conda-libmamba-solver/24.9.0 libmambapy/1.5.8 aau/0.4.4 c/. s/. e/.
                UID:GID : 0:0
             netrc file : None
           offline mode : False
```