Describe the issue
Hi, thank you for your excellent work on MInference and for open-sourcing such a well-structured and impactful repository.
I'm trying to reproduce the end-to-end benchmark experiments from this repository, but I'm encountering a CUDA out-of-memory error (on an 80 GB A100 GPU).
For the e2e benchmark at a target length of 1M tokens, the OOM error occurs with the following command:
python experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 1_000_000 --kv_cache_cpu
The log is:
CUDA_VISIBLE_DEVICES=0 python experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 1_000_000 --kv_cache_cpu
/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/transformers/utils/hub.py:111: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
warnings.warn(
INFO 12-02 19:13:31 [__init__.py:243] Automatically detected platform cuda.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████| 4/4 [00:02<00:00, 1.84it/s]
<---- MInference Config Detail ----> attn_type minference, kv_type dense
Traceback (most recent call last):
File "/home/cpwu/working/SPTCAttn/MInference/experiments/benchmarks/benchmark_e2e.py", line 147, in <module>
run_target_length(args.context_window, model, args.attn_type)
File "/home/cpwu/working/SPTCAttn/MInference/experiments/benchmarks/benchmark_e2e.py", line 33, in run_target_length
model(
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 969, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 688, in forward
outputs: BaseModelOutputWithPast = self.model(
^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 969, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 453, in forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/transformers/modeling_layers.py", line 48, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/minference/patch.py", line 527, in forward_llama_decoder_layer
attention_outputs = self.self_attn(
^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/minference/patch.py", line 874, in <lambda>
lambda self, *args, **kwargs: forward(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/minference/modules/forward.py", line 85, in attn_forward
query_states, key_states = apply_rotary_pos_emb(
^^^^^^^^^^^^^^^^^^^^^
File "/home/cpwu/working/SPTCAttn/MInference/.venv/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 145, in apply_rotary_pos_emb
q_embed = (q * cos) + (rotate_half(q) * sin)
~~~~~~~~~~~~~~~^~~~~
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 7.63 GiB. GPU 0 has a total capacity of 79.25 GiB of which 2.23 GiB is free. Including non-PyTorch memory, this process has 77.02 GiB memory in use. Of the allocated memory 65.06 GiB is allocated by PyTorch, and 11.46 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
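If my back-of-the-envelope math is right, the 7.63 GiB allocation matches a single full-length query tensor being materialized in apply_rotary_pos_emb. The shape values below are my assumption (LLaMA-3-8B-style defaults: 32 query heads, head_dim 128, bf16), not numbers read from the benchmark script:

```python
# Rough size of one [batch=1, num_heads, seq_len, head_dim] activation during prefill.
# Assumed config (LLaMA-3-8B-style defaults, not read from the script): 32 query heads, head_dim 128, bf16.
seq_len = 1_000_000       # --context_window 1_000_000
num_heads = 32            # assumption
head_dim = 128            # assumption
bytes_per_elem = 2        # bf16

size_gib = seq_len * num_heads * head_dim * bytes_per_elem / 2**30
print(f"{size_gib:.2f} GiB")  # ~7.63 GiB, matching the failed allocation
```

So even with --kv_cache_cpu, the full-length query/key activations for the prefill still seem to live on the GPU at this step.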
Similarly, for the latency experiments, OOM also occurs at context_window 500k and 1000k (1k~100k succeed), with the command
python experiments/benchmarks/benchmark_e2e.py --run_benchmark
In the NIAH test, OOM also occurs (jobs 0-4 succeed, jobs 4-15 hit OOM).
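In case it helps with triage, here is a minimal diagnostic sketch (a hypothetical helper, not part of the repo) of how I can log peak GPU memory around a single prefill call; model and input_ids stand in for whatever the benchmark script builds:

```python
import torch

def log_peak_gpu_memory(model, input_ids):
    """Run one forward pass and report peak GPU memory on the current device (diagnostic sketch)."""
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        model(input_ids)
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    print(f"peak allocated: {peak_gib:.2f} GiB")
```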