Description
First of all, thank you for sharing this excellent codebase. I am running into a crash when training the gcbf+ algorithm in the DoubleIntegrator environment. The failure happens while XLA autotunes a fused GEMM kernel on the GPU. ptxas crashes (exit code 139, i.e. a segmentation fault) while compiling one of the candidate tiling configurations, XLA reports an INTERNAL error, and the process aborts. Full context and logs are provided below.
Environment
OS: Ubuntu 22.04
Python: 3.10
JAX: jax==0.6.0, jaxlib==0.6.0+cuda12
GPU: NVIDIA RTX 4070 12GB
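The versions and device listed above can be double-checked from Python. This is only a sketch for gathering the same information with standard JAX calls (jax.__version__, jaxlib.__version__, jax.devices() and jax.print_environment_info()), not something from the repository:

```python
import jax
import jaxlib

# Library versions and the devices JAX can see (the GPU should appear here).
print("jax", jax.__version__)
print("jaxlib", jaxlib.__version__)
print("devices", jax.devices())

# JAX's built-in summary: Python/NumPy/JAX versions, device count and,
# when nvidia-smi is available, the driver and CUDA information.
jax.print_environment_info()
```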
Command
python train.py --algo gcbf+ --env DoubleIntegrator -n 5 --area-size 4 --loss-action-coef 1e-4 \
--n-env-train 8 --lr-actor 1e-5 --lr-cbf 1e-5 --horizon 16
Log
F external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1099] Non-OK-status: executable.status()
Status: INTERNAL: ptxas exited with non-zero error code 139, output:
- Failure occurred when compiling fusion gemm_fusion_dot.610 with config '{block_m:128,block_n:32,block_k:32,...}'
Fused HLO computation:
%gemm_fusion_dot.610_computation (parameter_0.51: f32[21888,256], parameter_1.51: f32[128,171,128]) -> f32[256,128] {
...
dot_general source_file="/.../gcbfplus/nn/gnn.py" source_line=62
}
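To check whether the failure depends on the training script at all, a minimal standalone sketch that builds a dot with the same operand shapes as gemm_fusion_dot.610 might help. The fused HLO body is elided in the log, so how the operands are combined here is an assumption: since 128 * 171 = 21888, I assume the two leading dimensions of parameter_1 are flattened and both operands are contracted over that 21888-long axis, which gives the f32[256,128] result. The names fused_dot, A_SHAPE and B_SHAPE are hypothetical.

```python
import jax
import jax.numpy as jnp

# Operand shapes copied from the failing fusion in the log. How the fusion
# combines them is an assumption, because the HLO body is elided ("...").
A_SHAPE = (21888, 256)
B_SHAPE = (128, 171, 128)  # 128 * 171 == 21888


@jax.jit
def fused_dot(a, b):
    # Assumed: flatten b's two leading dims so they line up with a's rows,
    # then contract over that shared axis: (21888, 256)^T @ (21888, 128).
    b2 = b.reshape(-1, b.shape[-1])
    return jax.lax.dot_general(a, b2, dimension_numbers=(((0,), (0,)), ((), ())))


def main():
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, A_SHAPE, dtype=jnp.float32)
    b = jax.random.normal(key_b, B_SHAPE, dtype=jnp.float32)
    out = fused_dot(a, b).block_until_ready()
    # Expected result shape is (256, 128), matching f32[256,128] in the log.
    print(jax.default_backend(), out.shape, out.dtype)


if __name__ == "__main__":
    main()
```

It can be run once on the GPU and once with JAX_PLATFORM_NAME=cpu. A standalone dot is not guaranteed to be fused and autotuned exactly as in the training step (where it may sit inside a gradient computation), so a clean GPU run here would not rule out the bug.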
Aborted (core dumped)
The crash does not happen when I set JAX_PLATFORM_NAME=cpu, but training on the CPU is very slow.
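A possible middle ground between the slow CPU fallback and the crashing GPU path, not verified on this setup, is to stay on the GPU but steer XLA away from the Triton GEMM fusion that ptxas failed to compile. The sketch below assumes the XLA debug options --xla_gpu_enable_triton_gemm and --xla_gpu_autotune_level are honoured by this jaxlib build; they have to be in XLA_FLAGS before JAX initialises its backend, so the environment variable is set before importing jax.

```python
import os

# Assumed workaround (untested here). These XLA debug options must be set
# before JAX initialises its GPU backend, i.e. before `import jax`.
#   --xla_gpu_enable_triton_gemm=false : don't emit Triton GEMM fusions
#                                        (the gemm_fusion_* kind that failed)
#   --xla_gpu_autotune_level=0         : skip GEMM autotuning entirely
EXTRA_XLA_FLAGS = "--xla_gpu_enable_triton_gemm=false --xla_gpu_autotune_level=0"
os.environ["XLA_FLAGS"] = (os.environ.get("XLA_FLAGS", "") + " " + EXTRA_XLA_FLAGS).strip()

import jax  # noqa: E402  (must come after XLA_FLAGS is set)

# Initialising the backend here makes XLA parse the flags above.
print(jax.default_backend(), os.environ["XLA_FLAGS"])
```

Exporting the same flags in the shell before running train.py would have the same effect. Disabling autotuning may cost some GPU speed, and if the crash has a different root cause these flags may not help.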
How can I work around this issue?
Looking forward to your answer, and thank you very much!