Description
Summary
The AWS Neuron compiler (torch-neuronx) produces incorrect outputs when compiling a transformer-based object detection decoder with discrete attention mechanism. The compiled model produces vastly different logits and bounding boxes compared to PyTorch, resulting in zero detections on real images.
Environment
# System info
OS: Ubuntu (Linux 6.8.0-1040-aws)
Instance: inf2.xlarge (tested on CPU instance for compilation)
# Python packages (confirmed versions)
Python: 3.10.12
torch: 2.8.0
torch-neuronx: 2.8.0.2.10.16998+e9bf8a50
torch-xla: 2.8.1
neuronx-cc: 2.21.33363.0+82129205
torchvision: 0.23.0
Problem Description
Expected Behavior
When compiling a transformer decoder with discrete attention to Neuron format, the compiled model should produce outputs numerically close to PyTorch (within typical BF16 precision tolerance of ~1e-2).
Actual Behavior
The Neuron-compiled decoder produces completely incorrect outputs:
- Logits max difference: 7.84 (compared to ~1e-2 expected)
- Logits mean difference: 1.25 (compared to ~1e-4 expected)
- Relative difference: 160,087% (!)
- Detection count: 0 detections (PyTorch: 60 detections)
- Top-5 confidence scores: [0.25, 0.24, 0.23, 0.21, 0.21] (PyTorch: [0.93, 0.90, 0.89, 0.84, 0.84])
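The metrics above can be reproduced with a small helper (a sketch; `compare_outputs` is a hypothetical name, not part of the repro scripts):

```python
import torch

def compare_outputs(ref: torch.Tensor, test: torch.Tensor) -> dict:
    """Max/mean/relative difference between PyTorch and Neuron outputs."""
    diff = (ref - test).abs()
    return {
        "max_diff": diff.max().item(),
        "mean_diff": diff.mean().item(),
        # relative difference as a percentage of the reference magnitude
        "rel_diff_pct": 100.0 * diff.mean().item()
                        / (ref.abs().mean().item() + 1e-12),
    }

# Within BF16 tolerance we would expect max_diff on the order of 1e-2;
# the compiled decoder instead shows max_diff ~ 7.84.
```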
Components Affected
- Backbone (HGNetv2): ✓ Works correctly (max diff ~0.7)
- Encoder (HybridEncoder): ✓ Works correctly (max diff ~0.18)
- Decoder (DEIMTransformer with discrete attention): ✗ Completely broken
Reproduction
1. Model Architecture
The decoder is a DETR-style transformer decoder with:
- Multi-scale deformable attention using discrete indexing (not F.grid_sample)
- 3 decoder layers
- 300 object queries
- 80 classes (COCO)
- 2 feature levels
Key detail: The model uses cross_attn_method: 'discrete' which implements deformable attention via discrete tensor indexing instead of F.grid_sample() (which is not supported by Neuron).
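For context, a minimal sketch of what discrete sampling for one feature level looks like (a hypothetical helper, not the model's actual code; the real implementation also handles multiple levels, heads, and attention weights):

```python
import torch

def discrete_sample(value: torch.Tensor, loc: torch.Tensor) -> torch.Tensor:
    """Sketch of 'discrete' deformable-attention sampling.

    value: (B, C, H, W) feature map for one level
    loc:   (B, N, 2) sampling locations normalized to [0, 1]

    Instead of bilinear F.grid_sample (unsupported by Neuron), round each
    location to the nearest pixel and gather it with integer indexing.
    """
    B, C, H, W = value.shape
    # map normalized coords to pixel indices and clamp to the feature map
    x = (loc[..., 0] * W).round().long().clamp(0, W - 1)  # (B, N)
    y = (loc[..., 1] * H).round().long().clamp(0, H - 1)  # (B, N)
    flat = value.flatten(2)                               # (B, C, H*W)
    idx = (y * W + x).unsqueeze(1).expand(-1, C, -1)      # (B, C, N)
    return flat.gather(2, idx)                            # (B, C, N)
```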
3. Full Reproduction
Complete reproduction code available at: /home/ubuntu/sh_deimv2/
Key files:
- convert_components_neuronx.py: conversion script
- verify_neuronx_on_image.py: verification on a real image
- compare_all_methods.py: comprehensive comparison
To reproduce:
# Convert decoder to Neuron
python convert_components_neuronx.py \
    --checkpoint models/deimv2_hgnetv2_n_coco.pth \
    --config configs/deimv2/deimv2_hgnetv2_n_coco.yml \
    --component all
# Verify (shows the bug)
python verify_neuronx_on_image.py --image example.jpg
Investigation Results
What We've Tried
- ✗ Different compiler optimizations: tested -O0, -O1, -O2
- ✗ FP32 precision: --fp32-cast=all, --fp32-cast=matmult
- ✗ Disabled auto-casting: --auto-cast=none
- ✗ Model type hint: --model-type=transformer
- ✗ Various combinations: see try_neuronx_compiler_settings.py
None of the above produced correct outputs.
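The search over settings amounts to enumerating these flag combinations and passing each list via `compiler_args` to `torch_neuronx.trace`. A sketch of that enumeration (the actual try_neuronx_compiler_settings.py may differ):

```python
from itertools import product

# Flag combinations exercised (a sketch; the real script may differ).
# Each resulting list would be passed as:
#   torch_neuronx.trace(model, example_inputs, compiler_args=args)
OPT_LEVELS = ["-O0", "-O1", "-O2"]
CAST_FLAGS = [None, "--fp32-cast=all", "--fp32-cast=matmult", "--auto-cast=none"]
MODEL_TYPE = [None, "--model-type=transformer"]

def arg_combinations():
    for opt, cast, mtype in product(OPT_LEVELS, CAST_FLAGS, MODEL_TYPE):
        yield [a for a in (opt, cast, mtype) if a is not None]
```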
Key Observations
- Backbone and encoder compile correctly - The issue is isolated to the decoder
- PyTorch discrete attention works perfectly - The implementation is correct in PyTorch
- Error is not precision-related - Differences are 1000x larger than BF16 tolerance
- Consistent failure - Fails on both random inputs and real images
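The last observation can be checked with a small harness that runs both models on several inputs (random tensors and a preprocessed real image) and flags each per-input failure; a hypothetical sketch:

```python
import torch

def failing_inputs(ref_model, compiled_model, inputs, atol=1e-2):
    """Return a per-input list of booleans: True where the compiled model
    diverges from the PyTorch reference beyond the given tolerance."""
    flags = []
    with torch.no_grad():
        for x in inputs:
            diff = (ref_model(x) - compiled_model(x)).abs().max().item()
            flags.append(diff > atol)
    return flags
```

A consistent failure shows up as all-True for the decoder regardless of input, while backbone and encoder stay all-False.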