32 changes: 31 additions & 1 deletion AGENTS.md
@@ -21,6 +21,10 @@ src/torchada/
├── _patch.py # All patching logic (~1100 lines)
├── _platform.py # Platform detection utilities
├── _mapping.py # CUDA→MUSA symbol mappings for C++ extensions
├── _cpp_ops.py # C++ operator overrides infrastructure
├── csrc/ # C++ source files for operator overrides
│   ├── ops.h                # Header with utilities and examples
│   ├── ops.cpp              # Main C++ source with Python bindings
│   └── musa_ops.mu          # MUSA kernel implementations
├── cuda/ # CUDA module compatibility
└── utils/cpp_extension.py # CUDAExtension wrapper
tests/
@@ -68,7 +72,7 @@ pytest tests/ --tb=short

**Test Markers** (defined in `conftest.py`):
- `@pytest.mark.musa` - Requires MUSA platform
- `@pytest.mark.cuda` - Requires CUDA platform
- `@pytest.mark.gpu` - Requires any GPU
- `@pytest.mark.slow` - Slow tests

@@ -138,6 +142,32 @@ import uuid
lib_name = f"test_lib_{uuid.uuid4().hex[:8]}"
```

## C++ Operator Overrides

torchada supports overriding ATen operators at the C++ level for better performance.

**See [docs/custom_musa_ops.md](docs/custom_musa_ops.md) for detailed documentation.**

**Quick start**:
```bash
export TORCHADA_ENABLE_CPP_OPS=1
```

**Adding a new operator override**:

1. Edit `src/torchada/csrc/musa_ops.mu` for MUSA kernels (or `ops.cpp` for pure C++)

2. Register using `TORCH_LIBRARY_IMPL(aten, PrivateUse1, m)` (see the sketch after this list)

3. The extension is JIT-compiled on first use
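
A minimal sketch of steps 2-3 wired together (`my_op` is a placeholder name; the full kernel example lives in [docs/custom_musa_ops.md](docs/custom_musa_ops.md)):

```cpp
// Sketch only: "my_op" is a hypothetical operator name.
#include "ops.h"

at::Tensor my_op_impl(const at::Tensor& self) {
  // ... launch a MUSA kernel or call lower-level primitives ...
  return at::empty_like(self);
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  if (torchada::is_override_enabled("my_op")) {
    m.impl("my_op", my_op_impl);
  }
}
```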

**Environment variables**:
- `TORCHADA_ENABLE_CPP_OPS=1` - Enable C++ operator overrides
- `TORCHADA_CPP_OPS_VERBOSE=1` - Show compilation output
- `TORCHADA_DEBUG_CPP_OPS=1` - Log operator calls
- `TORCHADA_DISABLE_OP_OVERRIDE_<OP_NAME>=1` - Disable specific operator override
- `MTGPU_TARGET=mp_XX` - Override GPU architecture (auto-detected via `musaInfo`)

## Security Considerations

- All patches are applied at import time via `apply_patches()`
236 changes: 236 additions & 0 deletions docs/custom_musa_ops.md
@@ -0,0 +1,236 @@
# Writing Custom MUSA Operators in torchada

This guide explains how to write custom MUSA C++ operators that override torch_musa's default ATen implementations.

## Overview

torchada allows you to override ATen operators at the C++ level for the `PrivateUse1` (MUSA) dispatch key. This is useful when you need:

- Better performance than the default torch_musa implementation
- Custom behavior for specific operators
- Workarounds for torch_musa bugs

## Quick Start

### 1. Enable C++ Ops

```bash
export TORCHADA_ENABLE_CPP_OPS=1
```

### 2. Write Your Kernel

Edit `src/torchada/csrc/musa_ops.mu`:

```cpp
#include "ops.h"
#include <ATen/musa/MUSAContext.h>

namespace torchada {

template <typename scalar_t>
__global__ void my_kernel(scalar_t* output, const scalar_t* input, int64_t n) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
        output[idx] = -input[idx];  // example: negation (replace with your computation)
}
}

at::Tensor my_op_impl(const at::Tensor& self) {
log_op_call("my_op");

auto input = self.contiguous();
auto output = at::empty_like(input);
if (input.numel() == 0) return output;

musaStream_t stream = at::musa::getCurrentMUSAStream();
const int64_t n = input.numel();
const int threads = 256;
const int blocks = (n + threads - 1) / threads;

AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "my_op", [&] {
my_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
output.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
n);
});

return output;
}

} // namespace torchada

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
// Check env var at registration time - allows disabling via
// TORCHADA_DISABLE_OP_OVERRIDE_my_op=1
if (torchada::is_override_enabled("my_op")) {
m.impl("my_op", torchada::my_op_impl);
}
}
```

### 3. Test Your Kernel

```bash
TORCHADA_ENABLE_CPP_OPS=1 TORCHADA_DEBUG_CPP_OPS=1 python -c "
import torch
import torchada

x = torch.randn(1000, device='cuda')
y = torch.neg(x) # Should print '[torchada] neg called'
print('Result:', y.cpu()[:5])
"
```

## File Structure

| File | Purpose |
|------|---------|
| `src/torchada/csrc/ops.h` | Header with utilities (`log_op_call`, `is_override_enabled`) |
| `src/torchada/csrc/ops.cpp` | Python bindings and C++-only operator overrides |
| `src/torchada/csrc/musa_ops.mu` | MUSA kernel implementations |

## Environment Variables

| Variable | Description |
|----------|-------------|
| `TORCHADA_ENABLE_CPP_OPS=1` | Enable C++ operator overrides |
| `TORCHADA_CPP_OPS_VERBOSE=1` | Show compilation output |
| `TORCHADA_DEBUG_CPP_OPS=1` | Log operator calls to stdout |
| `TORCHADA_DISABLE_OP_OVERRIDE_<NAME>=1` | Disable specific operator override |
| `MTGPU_TARGET=mp_XX` | Override GPU architecture detection |

### Disabling Specific Operators

To disable a specific operator override at runtime, set the environment variable before importing torchada:

```bash
# Disable the 'neg' operator override, use torch_musa's default instead
TORCHADA_ENABLE_CPP_OPS=1 TORCHADA_DISABLE_OP_OVERRIDE_neg=1 python my_script.py
```

**Important**: The operator name in the environment variable should match the name passed to `is_override_enabled()` in the C++ code. For example, if the code uses `is_override_enabled("neg")`, set `TORCHADA_DISABLE_OP_OVERRIDE_neg=1`.

This check happens at **registration time** (when the extension is loaded), not on every operator call. Once the extension is loaded, the operator registrations are fixed.
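
For illustration, the registration-time check could look roughly like this (a sketch; the actual helper lives in `src/torchada/csrc/ops.h`):

```cpp
// Sketch: how is_override_enabled() might consult the environment.
#include <cstdlib>
#include <string>

namespace torchada {

inline bool is_override_enabled(const std::string& op_name) {
  // TORCHADA_DISABLE_OP_OVERRIDE_<op_name>=1 disables the override.
  const std::string var = "TORCHADA_DISABLE_OP_OVERRIDE_" + op_name;
  const char* val = std::getenv(var.c_str());
  return !(val && std::string(val) == "1");
}

}  // namespace torchada
```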

## GPU Architecture

torchada auto-detects the GPU architecture using `musaInfo`:

| GPU | Compute Capability | Architecture |
|-----|-------------------|--------------|
| MTT S80 | 2.1 | mp_21 |
| MTT S4000 | 2.2 | mp_22 |
| MTT S5000 | 3.1 | mp_31 |

Override with: `export MTGPU_TARGET=mp_22`

## Best Practices

### Avoid Infinite Recursion

When overriding an operator, don't call the same operator:

```cpp
// BAD - causes infinite recursion
at::Tensor bad_neg_impl(const at::Tensor& self) {
return -self; // Calls aten::neg again!
}

// GOOD - use lower-level primitives
at::Tensor good_neg_impl(const at::Tensor& self) {
auto output = at::empty_like(self);
// Launch custom kernel or use in-place ops
return output;
}
```

### Handle Edge Cases

```cpp
at::Tensor my_impl(const at::Tensor& self) {
auto input = self.contiguous(); // Ensure contiguous
if (input.numel() == 0) {
return at::empty_like(input); // Handle empty tensors
}
// ... kernel launch
}
```

### Check for Errors

```cpp
musaError_t err = musaGetLastError();
if (err != musaSuccess) {
TORCH_CHECK(false, "MUSA kernel failed: ", musaGetErrorString(err));
}
```

### Use Type Dispatching

```cpp
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
input.scalar_type(), "my_kernel", [&] {
my_kernel<scalar_t><<<blocks, threads, 0, stream>>>(...);
});
```

## Overridable ATen Operators

Any ATen operator registered for the `PrivateUse1` dispatch key can be overridden. Common categories:

### Unary Operations
`abs`, `neg`, `exp`, `exp2`, `log`, `log2`, `log10`, `sqrt`, `rsqrt`, `ceil`, `floor`, `round`, `trunc`, `sign`, `sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `sinh`, `cosh`, `tanh`, `sigmoid`, `erf`, `erfc`, `reciprocal`, `bitwise_not`

### Binary Operations
`add`, `sub`, `mul`, `div`, `pow`, `fmod`, `remainder`, `maximum`, `minimum`, `atan2`, `bitwise_and`, `bitwise_or`, `bitwise_xor`, `logical_and`, `logical_or`, `logical_xor`

### Reduction Operations
`sum`, `prod`, `mean`, `std`, `var`, `max`, `min`, `argmax`, `argmin`, `all`, `any`, `norm`, `logsumexp`

### Matrix Operations
`mm`, `bmm`, `addmm`, `addmv`, `addr`, `matmul`, `dot`, `mv`, `ger`, `linear`

### Activation Functions
`relu`, `relu_`, `leaky_relu`, `gelu`, `silu`, `mish`, `hardswish`, `hardsigmoid`, `softplus`, `softshrink`, `threshold`

### Normalization
`batch_norm`, `layer_norm`, `group_norm`, `instance_norm`, `local_response_norm`

### Pooling
`max_pool1d`, `max_pool2d`, `max_pool3d`, `avg_pool1d`, `avg_pool2d`, `avg_pool3d`, `adaptive_max_pool2d`, `adaptive_avg_pool2d`

### Convolution
`conv1d`, `conv2d`, `conv3d`, `conv_transpose1d`, `conv_transpose2d`, `conv_transpose3d`

### Memory Operations
`copy_`, `clone`, `contiguous`, `fill_`, `zero_`, `ones_like`, `zeros_like`, `empty_like`

### Indexing
`index`, `index_put_`, `gather`, `scatter`, `scatter_add`, `masked_fill`, `masked_select`, `where`

### Shape Operations
`view`, `reshape`, `transpose`, `permute`, `squeeze`, `unsqueeze`, `expand`, `repeat`, `cat`, `stack`, `split`, `chunk`

To find the exact operator signature, use:

```python
import torch
# Search for specific operator:
for s in torch._C._jit_get_all_schemas():
if 'neg' in str(s):
print(s)
```
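
For example, among the matches you should see `aten::neg(Tensor self) -> Tensor`, plus variants such as `aten::neg_` and `aten::neg.out`.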

## Complete Example

See `src/torchada/csrc/musa_ops.mu` for a complete working example that overrides `aten::neg`.

## Debugging

1. **Verify kernel is called**: Set `TORCHADA_DEBUG_CPP_OPS=1`
2. **Check compilation**: Set `TORCHADA_CPP_OPS_VERBOSE=1`
3. **Clear cache**: `rm -rf ~/.cache/torch_extensions/*/torchada_cpp_ops`
4. **Check architecture**: Run `musaInfo | grep "compute capability"`

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -41,6 +41,9 @@ Repository = "https://github.com/MooreThreads/torchada"
[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
torchada = ["csrc/*.h", "csrc/*.cpp", "csrc/*.cu", "csrc/*.mu", "csrc/*.cuh", "csrc/*.muh"]

[tool.black]
line-length = 100
target-version = ["py38", "py39", "py310", "py311", "py312"]
5 changes: 5 additions & 0 deletions src/torchada/__init__.py
@@ -43,6 +43,11 @@
# Automatically apply patches on import
apply_patches()

# Load C++ operator overrides if enabled via TORCHADA_ENABLE_CPP_OPS=1
from ._cpp_ops import load_cpp_ops

load_cpp_ops()


def get_version() -> str:
"""Return the version of torchada."""