diff --git a/include/core/tensor.hpp b/include/core/tensor.hpp index f32cb49..3d06c2b 100644 --- a/include/core/tensor.hpp +++ b/include/core/tensor.hpp @@ -1,5 +1,5 @@ /** -Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -119,6 +119,14 @@ class Tensor { */ int64_t shape(int d) const &; + /** + * @brief Retrieves a specific dimension size from the tensor shape using a character representing the dimension. + * + * @param[in] dimension The dimension to get the size of. This is a character representing the dimension. + * @return The size of the specified dimension. + */ + int64_t shape(std::string_view dimension) const &; + /** * @brief Returns the data type of the tensor * diff --git a/include/core/tensor_layout.hpp b/include/core/tensor_layout.hpp index 8ca3fd4..5053073 100644 --- a/include/core/tensor_layout.hpp +++ b/include/core/tensor_layout.hpp @@ -1,5 +1,5 @@ /** -Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -35,21 +35,6 @@ THE SOFTWARE. #define ROCCV_TENSOR_MAX_RANK (15) namespace roccv { -/** - * @brief Descriptors used to specify features of a specific tensor layout type. - * - */ -struct TensorLayoutDesc { - int32_t rank; - int32_t batch_index; - int32_t width_index; - int32_t height_index; - int32_t channel_index; - int32_t max_features_index; - int32_t sift_features_index; - int32_t sift_octave_layer_index; -}; - /** * @brief TensorLayout class. * @@ -62,38 +47,64 @@ class TensorLayout { * @param[in] layout The desired layout of the TensorLayout object. See * eTensorLayout for information on supported layouts. */ - explicit TensorLayout(eTensorLayout layout) { - if (TensorLayout::layoutDescriptorTable.count(layout) == 0) { - throw Exception("Invalid TensorLayout type", eStatusType::INVALID_VALUE); - } + explicit TensorLayout(eTensorLayout layout); + + // clang-format off + inline static const std::unordered_map layoutStringTable = { + {TENSOR_LAYOUT_HWC, "HWC"}, + {TENSOR_LAYOUT_NC, "NC"}, + {TENSOR_LAYOUT_NW, "NW"}, + {TENSOR_LAYOUT_NHWC, "NHWC"}, + {TENSOR_LAYOUT_NMC, "NMC"}, + {TENSOR_LAYOUT_NMD, "NMD"}, + {TENSOR_LAYOUT_LNHWC, "LNHWC"}, + {TENSOR_LAYOUT_NCHW, "NCHW"}, + {TENSOR_LAYOUT_N, "N"}, + {TENSOR_LAYOUT_NWC, "NWC"}, + }; + // clang-format on - layout_ = layout; - layout_desc_ = TensorLayout::layoutDescriptorTable.at(layout); - } + /** + * @brief Returns the index of the given dimension in the layout. + * + * @param[in] dimension The dimension to get the index of. + * @return The index of the dimension, or -1 if the dimension is not found in the layout. + */ + int32_t indexOf(std::string_view dim) const; /** - * @brief Provides descriptors for each feature of a specified layout type. + * @brief Returns the dimension at the given index in the layout. + * + * @param[in] index The index of the dimension to get. + * @return The dimension at the given index. + */ + std::string_view dimAt(int32_t index) const; + + /** + * @brief Returns the layout string representing the layout. + * + * @return The layout string. + */ + inline const std::string &string() const { return m_layoutString; } + + /** + * @brief Returns true if the layout contains the given dimension, false otherwise. + * + * @param[in] dim The dimension to check for. + * @return True if the layout contains the dimension, false otherwise. */ - inline static const std::unordered_map layoutDescriptorTable = { - {TENSOR_LAYOUT_HWC, {3, -1, 1, 0, 2, -1, -1, -1}}, {TENSOR_LAYOUT_NC, {2, 0, -1, -1, 1, -1, -1, -1}}, - {TENSOR_LAYOUT_NW, {2, 0, 1, -1, -1, -1, -1, -1}}, {TENSOR_LAYOUT_NHWC, {4, 0, 2, 1, 3, -1, -1, -1}}, - {TENSOR_LAYOUT_NMC, {3, 0, -1, -1, -1, 1, 2, -1}}, {TENSOR_LAYOUT_NMD, {3, 0, -1, -1, -1, 1, 2, -1}}, - {TENSOR_LAYOUT_LNHWC, {5, 1, 3, 2, 4, -1, -1, 0}}, {TENSOR_LAYOUT_NCHW, {4, 0, 3, 2, 1, -1, -1, -1}}, - {TENSOR_LAYOUT_N, {1, 0, -1, -1, -1, -1, -1, -1}}, {TENSOR_LAYOUT_NWC, {3, 0, 1, -1, 2, -1, -1, -1}}}; + inline bool containsDim(std::string_view dim) const { return indexOf(dim) != -1; } /** * @brief Returns the layout enum stored in the TensorLayout object. * * @return eTensorLayout */ - eTensorLayout elayout() const { return layout_; } - - bool operator==(const eTensorLayout &rhs) const { return this->layout_ == rhs; } + eTensorLayout elayout() const { return m_layout; } + bool operator==(const eTensorLayout &rhs) const { return this->m_layout == rhs; } bool operator!=(const eTensorLayout &rhs) const { return !operator==(rhs); } - - bool operator==(const TensorLayout &rhs) const { return this->layout_ == rhs.layout_; } - + bool operator==(const TensorLayout &rhs) const { return this->m_layout == rhs.m_layout; } bool operator!=(const TensorLayout &rhs) const { return !operator==(rhs); } /** @@ -101,60 +112,39 @@ class TensorLayout { * * @return int32_t */ - int32_t rank() const { return layout_desc_.rank; } + int32_t rank() const { return m_rank; } /** * @brief Index of the batch dimension specified by layout. E.g. returns 0 * for TENSOR_LAYOUT_NHWC. * @return Index or -1 if the layout does not have a batch dimension. */ - int32_t batch_index() const { return layout_desc_.batch_index; } + int32_t batch_index() const { return indexOf("N"); } /** * @brief Index of the height dimension specified by layout. E.g. returns 1 * for TENSOR_LAYOUT_NHWC. * @return Index of the height dimension. */ - int32_t height_index() const { return layout_desc_.height_index; } + int32_t height_index() const { return indexOf("H"); } /** * @brief Index of the width dimension specified by layout. E.g. returns 2 * for TENSOR_LAYOUT_NHWC. * @return Index of the width dimension. */ - int32_t width_index() const { return layout_desc_.width_index; } + int32_t width_index() const { return indexOf("W"); } /** * @brief Index of the channels dimension specified by layout. E.g. returns * 3 for TENSOR_LAYOUT_NHWC. * @return Index of the channels dimension. */ - int32_t channels_index() const { return layout_desc_.channel_index; } - - /** - * @brief Index of the max features dimension specified by layout - * - * @return Index of the max features dimension or -1 if the layout does not - * contain it. - */ - int32_t max_features_index() const { return layout_desc_.max_features_index; } - - /** - * @brief Index of the sift features dimension specified by layout - * - * @return int32_t - */ - int32_t sift_features_index() const { return layout_desc_.sift_features_index; } - - /** - * @brief Index of the sift octave layer dimension specified by layout - * - * @return int32_t - */ - int32_t sift_octave_layer_index() const { return layout_desc_.sift_octave_layer_index; } + int32_t channels_index() const { return indexOf("C"); } private: - eTensorLayout layout_; - TensorLayoutDesc layout_desc_; + eTensorLayout m_layout; + std::string m_layoutString; + int32_t m_rank; }; } // namespace roccv \ No newline at end of file diff --git a/include/core/tensor_shape.hpp b/include/core/tensor_shape.hpp index ab94f51..d17ce60 100644 --- a/include/core/tensor_shape.hpp +++ b/include/core/tensor_shape.hpp @@ -1,5 +1,5 @@ /** -Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -96,8 +96,29 @@ class TensorShape { */ const std::array &shape() const; + /** + * @brief Permutes the tensor shape to the given layout. + * + * @note This operation requires that the set of dimensions in the new layout are a subset of the dimensions in the + * current layout. For example, a tensor shape with layout HWC cannot be permuted to layout NCHW because NCHW has + * dimension N that is not present in HWC. + + * @param[in] layout The layout to permute the tensor shape to. + * @return The permuted tensor shape. + */ + TensorShape permute(const TensorLayout &layout) const; + + /** + * @brief Returns true if the tensor shape contains the given dimension, false otherwise. + * + * @param[in] dim The dimension to check for. + * @return True if the tensor shape contains the dimension, false otherwise. + */ + inline bool containsDim(std::string_view dim) const { return m_layout.containsDim(dim); } + // Operators int64_t operator[](int32_t i) const; + int64_t operator[](std::string_view dimension) const; TensorShape &operator=(const TensorShape &other); bool operator==(const TensorShape &rhs) const; bool operator!=(const TensorShape &rhs) const; diff --git a/src/core/tensor.cpp b/src/core/tensor.cpp index 4eeae37..2e35b05 100644 --- a/src/core/tensor.cpp +++ b/src/core/tensor.cpp @@ -1,5 +1,5 @@ /** -Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -77,6 +77,8 @@ TensorShape Tensor::shape() const { return TensorShape(m_requirements.shape, m_r int64_t Tensor::shape(int d) const& { return shape()[d]; } +int64_t Tensor::shape(std::string_view dimension) const& { return shape()[dimension]; } + DataType Tensor::dtype() const { return DataType(m_requirements.dtype); } TensorLayout Tensor::layout() const { return TensorLayout(m_requirements.layout); } diff --git a/src/core/tensor_layout.cpp b/src/core/tensor_layout.cpp new file mode 100644 index 0000000..20fbc53 --- /dev/null +++ b/src/core/tensor_layout.cpp @@ -0,0 +1,59 @@ +/** +Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#include "core/tensor_layout.hpp" + +namespace roccv { + +TensorLayout::TensorLayout(eTensorLayout layout) : m_layout(layout) { + if (layoutStringTable.count(m_layout) == 0) { + throw Exception("Invalid TensorLayout type", eStatusType::INVALID_VALUE); + } + + m_layoutString = layoutStringTable.at(m_layout); + m_rank = m_layoutString.size(); +} + +int32_t TensorLayout::indexOf(std::string_view dim) const { + if (dim.size() != 1) { + throw Exception("Dimension must be a single character", eStatusType::INVALID_VALUE); + } + + auto index = m_layoutString.find(dim); + if (index == std::string::npos) { + return -1; + } + + return index; +} + +std::string_view TensorLayout::dimAt(int32_t index) const { + if (index < 0 || index >= m_rank) { + throw Exception( + "Invalid index: " + std::to_string(index) + ". Index must be >= 0 and < " + std::to_string(m_rank), + eStatusType::OUT_OF_BOUNDS); + } + + return std::string_view(&m_layoutString[index], 1); +} + +} // namespace roccv \ No newline at end of file diff --git a/src/core/tensor_shape.cpp b/src/core/tensor_shape.cpp index 29bbfde..870ad1b 100644 --- a/src/core/tensor_shape.cpp +++ b/src/core/tensor_shape.cpp @@ -1,5 +1,5 @@ /** -Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -23,6 +23,7 @@ THE SOFTWARE. #include "core/tensor_shape.hpp" #include +#include #include "core/exception.hpp" #include "core/status_type.h" @@ -108,11 +109,22 @@ TensorShape &TensorShape::operator=(const TensorShape &other) { int64_t TensorShape::operator[](int32_t i) const { if (i < 0 || i >= this->m_layout.rank()) { - throw Exception("Invalid parameter: Index must be >= 0 and < rank.", eStatusType::OUT_OF_BOUNDS); + throw Exception("TensorShape index out of bounds: " + std::to_string(i) + ". Dimension must be >= 0 and < " + + std::to_string(this->m_layout.rank()), + eStatusType::OUT_OF_BOUNDS); } return m_shape[i]; } +int64_t TensorShape::operator[](std::string_view dimension) const { + int32_t index = m_layout.indexOf(dimension); + if (index == -1) { + throw Exception("Invalid dimension: " + std::string(dimension) + ". Dimension must be in the layout.", + eStatusType::OUT_OF_BOUNDS); + } + return operator[](index); +} + bool TensorShape::operator==(const TensorShape &rhs) const { if (this->m_layout != rhs.m_layout) { return false; @@ -139,4 +151,12 @@ const TensorLayout &TensorShape::layout() const { return m_layout; } const std::array &TensorShape::shape() const { return m_shape; } +TensorShape TensorShape::permute(const TensorLayout &layout) const { + std::vector permutedShape(layout.rank()); + for (int32_t i = 0; i < layout.rank(); i++) { + permutedShape[i] = operator[](layout.dimAt(i)); + } + return TensorShape(layout, permutedShape); +} + } // namespace roccv \ No newline at end of file diff --git a/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp b/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp index 42e9a70..4dc423c 100644 --- a/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp +++ b/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights @@ -55,6 +55,12 @@ void TestNegativeTensorShape() { { EXPECT_EXCEPTION(TensorShape shape({1, 2, 3}, "NWCX"), eStatusType::INVALID_VALUE); } + + // Test permute operation with invalid layout + { + TensorShape shape({1, 2, 3}, "HWC"); + EXPECT_EXCEPTION(shape.permute(TensorLayout(TENSOR_LAYOUT_NCHW)), eStatusType::OUT_OF_BOUNDS); + } } /** @@ -87,6 +93,36 @@ void TestTensorShapeCorrectness() { shape2 = shape1; EXPECT_TRUE(shape1 == shape2); } + + // Test TensorShape index operator + { + TensorShape shape({1, 2, 3, 4}, "NHWC"); + EXPECT_EQ(shape["N"], 1); + EXPECT_EQ(shape["H"], 2); + EXPECT_EQ(shape["W"], 3); + EXPECT_EQ(shape["C"], 4); + EXPECT_EXCEPTION(shape["X"], eStatusType::OUT_OF_BOUNDS); + } + + // Test TensorShape permute operator + { + TensorShape shape({1, 2, 3, 4}, "NHWC"); + TensorShape permutedShape = shape.permute(TensorLayout(TENSOR_LAYOUT_NCHW)); + EXPECT_TRUE(permutedShape.layout() == eTensorLayout::TENSOR_LAYOUT_NCHW); + EXPECT_EQ(permutedShape["N"], 1); + EXPECT_EQ(permutedShape["C"], 4); + EXPECT_EQ(permutedShape["H"], 2); + EXPECT_EQ(permutedShape["W"], 3); + } + + // Test TensorShape containsDim operator + { + TensorShape shape({1, 2, 3}, "HWC"); + EXPECT_TRUE(shape.containsDim("H")); + EXPECT_TRUE(shape.containsDim("W")); + EXPECT_TRUE(shape.containsDim("C")); + EXPECT_FALSE(shape.containsDim("N")); + } } } // namespace