High-performance tropical matrix multiplication in Rust with SIMD and CUDA backends. Inspired by CuTropicalGEMM.jl.
- Multiple Semirings: MaxPlus, MinPlus, MaxMul
- SIMD Acceleration: AVX-512, AVX2, SSE4.1, NEON auto-detection
- CUDA Backend: GPU-accelerated kernels via NVRTC
- Argmax Tracking: For backpropagation in tropical neural networks
- Python Bindings: NumPy and PyTorch integration
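
The argmax tracking listed above can be illustrated with a minimal pure-Python sketch. It is not the crate's implementation: the function names `maxplus_matmul_with_argmax` and `maxplus_backward` are hypothetical, and the tie-breaking rule (first maximal `k` wins) is an assumption. The idea is that each output `C[i,j]` depends on exactly one winning index `k`, so the backward pass routes the upstream gradient to `A[i,k]` and `B[k,j]` only.

```python
def maxplus_matmul_with_argmax(a, b):
    """Naive max-plus product that also records the winning k per output.

    Hypothetical helper for illustration; ties keep the first maximal k.
    """
    m, k_dim, n = len(a), len(b), len(b[0])
    c = [[float("-inf")] * n for _ in range(m)]
    argmax = [[0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            for k in range(k_dim):
                v = a[i][k] + b[k][j]
                if v > c[i][j]:
                    c[i][j] = v
                    argmax[i][j] = k
    return c, argmax


def maxplus_backward(grad_c, argmax, m, k_dim, n):
    """Send each upstream gradient to the single (i, k) and (k, j) that won."""
    grad_a = [[0.0] * k_dim for _ in range(m)]
    grad_b = [[0.0] * n for _ in range(k_dim)]
    for i in range(m):
        for j in range(n):
            k = argmax[i][j]
            grad_a[i][k] += grad_c[i][j]
            grad_b[k][j] += grad_c[i][j]
    return grad_a, grad_b


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]      # 2 x 3
b = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]    # 3 x 2
c, argmax = maxplus_matmul_with_argmax(a, b)

# Upstream gradient of loss = c.sum() is all ones.
grad_a, grad_b = maxplus_backward([[1.0, 1.0], [1.0, 1.0]], argmax, 2, 3, 2)
```

In this small example every output is won by the last index, so all gradient mass lands in the last column of `a` and the last row of `b`.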
[dependencies]
tropical-gemm = "0.2"
tropical-gemm-cuda = "0.2" # Optional GPU support

use tropical_gemm::{Mat, MaxPlus};
let a = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
let b = Mat::<MaxPlus<f32>>::from_row_major(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2);
// C[i,j] = max_k(A[i,k] + B[k,j])
let c = a.matmul(&b);
assert_eq!(c.get_value(0, 0), 8.0); // max(1+1, 2+3, 3+5) = 8

pip install tropical-gemm[torch]

import torch
from tropical_gemm.pytorch import (
    tropical_maxplus_matmul,
    tropical_minplus_matmul,
    tropical_maxmul_matmul,
    tropical_maxplus_matmul_gpu,  # GPU (requires CUDA)
    GPU_AVAILABLE,
)
# Create tensors with gradients
a = torch.randn(100, 50, requires_grad=True)
b = torch.randn(50, 80, requires_grad=True)
# Forward pass
c = tropical_maxplus_matmul(a, b)
# Backward pass - gradients flow automatically
loss = c.sum()
loss.backward()
print(a.grad.shape) # (100, 50)
print(b.grad.shape) # (50, 80)

📖 User Guide - Installation, tutorials, examples
📚 API Reference - Rust API documentation
| Type | ⊕ | ⊗ | Use Case |
|---|---|---|---|
| `MaxPlus<T>` | max | + | Longest path, Viterbi |
| `MinPlus<T>` | min | + | Shortest path |
| `MaxMul<T>` | max | × | Max probability |
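
To make the ⊕ and ⊗ columns concrete, here is a naive pure-Python reference for the three semirings. It is a sketch for checking small results by hand, not how the crate computes them. The `SEMIRINGS` table and the `tropical_matmul` name are hypothetical, and using `0.0` as the additive identity for MaxMul assumes non-negative entries.

```python
import operator

# (⊕, ⊗, additive identity) for each semiring.
# MaxMul's identity of 0.0 assumes non-negative inputs (e.g. probabilities).
SEMIRINGS = {
    "MaxPlus": (max, operator.add, float("-inf")),
    "MinPlus": (min, operator.add, float("inf")),
    "MaxMul": (max, operator.mul, 0.0),
}


def tropical_matmul(a, b, semiring):
    """C[i][j] = ⊕_k (A[i][k] ⊗ B[k][j]), computed with three nested loops."""
    oplus, otimes, zero = SEMIRINGS[semiring]
    m, k_dim, n = len(a), len(b), len(b[0])
    c = [[zero] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = zero
            for k in range(k_dim):
                acc = oplus(acc, otimes(a[i][k], b[k][j]))
            c[i][j] = acc
    return c


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]      # 2 x 3
b = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]    # 3 x 2

c_maxplus = tropical_matmul(a, b, "MaxPlus")
c_minplus = tropical_matmul(a, b, "MinPlus")
c_maxmul = tropical_matmul(a, b, "MaxMul")

# Shortest-path use of MinPlus: squaring a distance matrix (0 on the
# diagonal, inf for missing edges) yields shortest paths of at most 2 edges.
inf = float("inf")
d = [[0.0, 1.0, inf], [inf, 0.0, 2.0], [inf, inf, 0.0]]
d2 = tropical_matmul(d, d, "MinPlus")
```

Here `d2[0][2]` is the cost of the two-edge route 0 → 1 → 2, which no single edge provides.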
| Size | CPU (ms) | GPU (ms) | Speedup |
|---|---|---|---|
| 256 | 4.1 | 0.03 | 128x |
| 1024 | 262 | 0.36 | 728x |
| 2048 | 2092 | 2.5 | 837x |
MIT