Training crashes with ptxas exited with error code 139 during GNN layer fusion #11

@lacie-life

First of all, thank you for sharing this excellent codebase. I'm currently encountering a crash when training the gcbf+ algorithm in the DoubleIntegrator environment. The error seems related to a GPU kernel fusion failure in ptxas, leading to an INTERNAL XLA error and process abort. Full context and logs are provided below.

Environment

OS: Ubuntu 22.04

Python: 3.10

JAX: jax==0.6.0, jaxlib==0.6.0+cuda12

GPU: NVIDIA RTX 4070 12GB

Command

python train.py --algo gcbf+ --env DoubleIntegrator -n 5 --area-size 4 --loss-action-coef 1e-4 \
--n-env-train 8 --lr-actor 1e-5 --lr-cbf 1e-5 --horizon 16 

Log

F external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1099] Non-OK-status: executable.status()
Status: INTERNAL: ptxas exited with non-zero error code 139, output:
 - Failure occurred when compiling fusion gemm_fusion_dot.610 with config '{block_m:128,block_n:32,block_k:32,...}'
Fused HLO computation:
%gemm_fusion_dot.610_computation (parameter_0.51: f32[21888,256], parameter_1.51: f32[128,171,128]) -> f32[256,128] {
...
dot_general source_file="/.../gcbfplus/nn/gnn.py" source_line=62
}
Aborted (core dumped)
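
A note on reading the log: assuming the reported code follows the common shell convention of 128 plus the signal number, 139 corresponds to signal 11, which is SIGSEGV on Linux. That would mean ptxas itself crashed rather than rejecting the kernel with a normal compile error. The helper below is only an illustrative sketch of that decoding; `describe_exit_code` is a hypothetical name and not part of gcbfplus, JAX, or XLA.

```python
import signal


def describe_exit_code(code: int) -> str:
    """Interpret an exit code using the shell convention 128 + signal number.

    This is an assumption about how the code in the log was produced;
    it is not something XLA documents for this error message.
    """
    if code > 128:
        signum = code - 128
        try:
            # Map the signal number to its symbolic name on this platform.
            return signal.Signals(signum).name
        except ValueError:
            return f"unknown signal {signum}"
    return f"exit status {code}"


# 139 - 128 = 11, which is SIGSEGV on Linux.
print(describe_exit_code(139))
```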

The crash does not occur when I run on the CPU with JAX_PLATFORM_NAME=cpu, but training is then very slow.
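
Since the failure is raised from the GEMM fusion autotuner, one possible workaround to try is turning off XLA's Triton GEMM fusion through the `XLA_FLAGS` environment variable, so that the matrix multiply is compiled through a different path. The sketch below is untested against this crash. It assumes that the `--xla_gpu_enable_triton_gemm` flag is honored by the installed jaxlib and that the snippet runs before `jax` is first imported (for example at the top of the entry script). `add_xla_flags` is a hypothetical helper written for illustration.

```python
import os


def add_xla_flags(existing: str, *flags: str) -> str:
    """Merge new XLA flags into an existing XLA_FLAGS string.

    A flag that is already present with a different value is replaced,
    so the newly supplied value wins.
    """
    parts = existing.split()
    for flag in flags:
        name = flag.split("=", 1)[0]
        # Drop any earlier setting of the same flag, then append the new one.
        parts = [p for p in parts if p.split("=", 1)[0] != name]
        parts.append(flag)
    return " ".join(parts)


# Assumption: disabling Triton GEMM fusion avoids the failing autotuner path.
os.environ["XLA_FLAGS"] = add_xla_flags(
    os.environ.get("XLA_FLAGS", ""),
    "--xla_gpu_enable_triton_gemm=false",
)

# Import jax only after this point so the flag is picked up at backend startup.
```

The same flag can instead be set in the shell by prefixing the training command with `XLA_FLAGS=--xla_gpu_enable_triton_gemm=false`. If the crash persists, it may point to a problem in the installed CUDA toolchain's ptxas rather than in the fusion itself.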

How can I deal with this issue?

Looking forward to your answer and thank you very much!
