From fd9e65d3d3a9e59bbe3329c8107d677de15625ef Mon Sep 17 00:00:00 2001 From: Christian Schmidt Date: Sat, 24 Jan 2026 18:08:19 +0100 Subject: [PATCH] Fix transform_points slicing issue --- utils3d/torch/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils3d/torch/transforms.py b/utils3d/torch/transforms.py index 13d9843..be20419 100644 --- a/utils3d/torch/transforms.py +++ b/utils3d/torch/transforms.py @@ -1374,7 +1374,7 @@ def transform_points(x: Tensor, *Ts: Tensor) -> Tensor: total_numel = sum(t.numel() for t in Ts) + x.numel() if total_numel > 1000: # Only use einsum when the total number of elements is large enough to benefit from optimized contraction path - operands = [*reversed(Ts), x[:, None]] + operands = [*reversed(Ts), x[..., None]] offset = len(operands) + 1 batch_shape = torch.broadcast_shapes(*(m.shape[:-2] for m in operands)) batch_subscripts = tuple(range(offset, offset + len(batch_shape))) @@ -1394,7 +1394,7 @@ def transform_points(x: Tensor, *Ts: Tensor) -> Tensor: ) y = y.squeeze(-1) else: - y = x[:, None] + y = x[..., None] for T in Ts: y = T @ y y = y.squeeze(-1)