From 3a525c42e5d64f1c09dba601c8a392f9e6d6b6ff Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 16 Jan 2026 12:53:41 -0500 Subject: [PATCH 1/7] Improve TensorLayout dimension querying --- include/core/tensor.hpp | 8 +++ include/core/tensor_layout.hpp | 97 +++++++++++----------------------- include/core/tensor_shape.hpp | 1 + src/core/tensor.cpp | 2 + src/core/tensor_layout.cpp | 49 +++++++++++++++++ src/core/tensor_shape.cpp | 2 + 6 files changed, 94 insertions(+), 65 deletions(-) create mode 100644 src/core/tensor_layout.cpp diff --git a/include/core/tensor.hpp b/include/core/tensor.hpp index f32cb498..8eafc2e6 100644 --- a/include/core/tensor.hpp +++ b/include/core/tensor.hpp @@ -119,6 +119,14 @@ class Tensor { */ int64_t shape(int d) const &; + /** + * @brief Retrieves a specific dimension size from the tensor shape using a character representing the dimension. + * + * @param[in] dimension The dimension to get the size of. This is a character representing the dimension. + * @return The size of the specified dimension. + */ + int64_t shape(std::string_view dimension) const &; + /** * @brief Returns the data type of the tensor * diff --git a/include/core/tensor_layout.hpp b/include/core/tensor_layout.hpp index 8ca3fd45..7df0e349 100644 --- a/include/core/tensor_layout.hpp +++ b/include/core/tensor_layout.hpp @@ -35,21 +35,6 @@ THE SOFTWARE. #define ROCCV_TENSOR_MAX_RANK (15) namespace roccv { -/** - * @brief Descriptors used to specify features of a specific tensor layout type. - * - */ -struct TensorLayoutDesc { - int32_t rank; - int32_t batch_index; - int32_t width_index; - int32_t height_index; - int32_t channel_index; - int32_t max_features_index; - int32_t sift_features_index; - int32_t sift_octave_layer_index; -}; - /** * @brief TensorLayout class. * @@ -62,38 +47,41 @@ class TensorLayout { * @param[in] layout The desired layout of the TensorLayout object. See * eTensorLayout for information on supported layouts. */ - explicit TensorLayout(eTensorLayout layout) { - if (TensorLayout::layoutDescriptorTable.count(layout) == 0) { - throw Exception("Invalid TensorLayout type", eStatusType::INVALID_VALUE); - } - - layout_ = layout; - layout_desc_ = TensorLayout::layoutDescriptorTable.at(layout); - } + explicit TensorLayout(eTensorLayout layout); + + // clang-format off + inline static const std::unordered_map layoutStringTable = { + {TENSOR_LAYOUT_HWC, "HWC"}, + {TENSOR_LAYOUT_NC, "NC"}, + {TENSOR_LAYOUT_NW, "NW"}, + {TENSOR_LAYOUT_NHWC, "NHWC"}, + {TENSOR_LAYOUT_NMC, "NMC"}, + {TENSOR_LAYOUT_NMD, "NMD"}, + {TENSOR_LAYOUT_LNHWC, "LNHWC"}, + {TENSOR_LAYOUT_NCHW, "NCHW"}, + {TENSOR_LAYOUT_N, "N"}, + {TENSOR_LAYOUT_NWC, "NWC"}, + }; + // clang-format on /** - * @brief Provides descriptors for each feature of a specified layout type. + * @brief Returns the index of the given dimension in the layout. + * + * @param[in] dimension The dimension to get the index of. + * @return The index of the dimension, or -1 if the dimension is not found in the layout. */ - inline static const std::unordered_map layoutDescriptorTable = { - {TENSOR_LAYOUT_HWC, {3, -1, 1, 0, 2, -1, -1, -1}}, {TENSOR_LAYOUT_NC, {2, 0, -1, -1, 1, -1, -1, -1}}, - {TENSOR_LAYOUT_NW, {2, 0, 1, -1, -1, -1, -1, -1}}, {TENSOR_LAYOUT_NHWC, {4, 0, 2, 1, 3, -1, -1, -1}}, - {TENSOR_LAYOUT_NMC, {3, 0, -1, -1, -1, 1, 2, -1}}, {TENSOR_LAYOUT_NMD, {3, 0, -1, -1, -1, 1, 2, -1}}, - {TENSOR_LAYOUT_LNHWC, {5, 1, 3, 2, 4, -1, -1, 0}}, {TENSOR_LAYOUT_NCHW, {4, 0, 3, 2, 1, -1, -1, -1}}, - {TENSOR_LAYOUT_N, {1, 0, -1, -1, -1, -1, -1, -1}}, {TENSOR_LAYOUT_NWC, {3, 0, 1, -1, 2, -1, -1, -1}}}; + int32_t indexOf(std::string_view dim) const; /** * @brief Returns the layout enum stored in the TensorLayout object. * * @return eTensorLayout */ - eTensorLayout elayout() const { return layout_; } - - bool operator==(const eTensorLayout &rhs) const { return this->layout_ == rhs; } + eTensorLayout elayout() const { return m_layout; } + bool operator==(const eTensorLayout &rhs) const { return this->m_layout == rhs; } bool operator!=(const eTensorLayout &rhs) const { return !operator==(rhs); } - - bool operator==(const TensorLayout &rhs) const { return this->layout_ == rhs.layout_; } - + bool operator==(const TensorLayout &rhs) const { return this->m_layout == rhs.m_layout; } bool operator!=(const TensorLayout &rhs) const { return !operator==(rhs); } /** @@ -101,60 +89,39 @@ class TensorLayout { * * @return int32_t */ - int32_t rank() const { return layout_desc_.rank; } + int32_t rank() const { return m_rank; } /** * @brief Index of the batch dimension specified by layout. E.g. returns 0 * for TENSOR_LAYOUT_NHWC. * @return Index or -1 if the layout does not have a batch dimension. */ - int32_t batch_index() const { return layout_desc_.batch_index; } + int32_t batch_index() const { return indexOf("N"); } /** * @brief Index of the height dimension specified by layout. E.g. returns 1 * for TENSOR_LAYOUT_NHWC. * @return Index of the height dimension. */ - int32_t height_index() const { return layout_desc_.height_index; } + int32_t height_index() const { return indexOf("H"); } /** * @brief Index of the width dimension specified by layout. E.g. returns 2 * for TENSOR_LAYOUT_NHWC. * @return Index of the width dimension. */ - int32_t width_index() const { return layout_desc_.width_index; } + int32_t width_index() const { return indexOf("W"); } /** * @brief Index of the channels dimension specified by layout. E.g. returns * 3 for TENSOR_LAYOUT_NHWC. * @return Index of the channels dimension. */ - int32_t channels_index() const { return layout_desc_.channel_index; } - - /** - * @brief Index of the max features dimension specified by layout - * - * @return Index of the max features dimension or -1 if the layout does not - * contain it. - */ - int32_t max_features_index() const { return layout_desc_.max_features_index; } - - /** - * @brief Index of the sift features dimension specified by layout - * - * @return int32_t - */ - int32_t sift_features_index() const { return layout_desc_.sift_features_index; } - - /** - * @brief Index of the sift octave layer dimension specified by layout - * - * @return int32_t - */ - int32_t sift_octave_layer_index() const { return layout_desc_.sift_octave_layer_index; } + int32_t channels_index() const { return indexOf("C"); } private: - eTensorLayout layout_; - TensorLayoutDesc layout_desc_; + eTensorLayout m_layout; + std::string m_layoutString; + int m_rank; }; } // namespace roccv \ No newline at end of file diff --git a/include/core/tensor_shape.hpp b/include/core/tensor_shape.hpp index ab94f51f..f6a523b4 100644 --- a/include/core/tensor_shape.hpp +++ b/include/core/tensor_shape.hpp @@ -98,6 +98,7 @@ class TensorShape { // Operators int64_t operator[](int32_t i) const; + int64_t operator[](std::string_view dimension) const; TensorShape &operator=(const TensorShape &other); bool operator==(const TensorShape &rhs) const; bool operator!=(const TensorShape &rhs) const; diff --git a/src/core/tensor.cpp b/src/core/tensor.cpp index 4eeae37b..16d5a21e 100644 --- a/src/core/tensor.cpp +++ b/src/core/tensor.cpp @@ -77,6 +77,8 @@ TensorShape Tensor::shape() const { return TensorShape(m_requirements.shape, m_r int64_t Tensor::shape(int d) const& { return shape()[d]; } +int64_t Tensor::shape(std::string_view dimension) const& { return shape()[dimension]; } + DataType Tensor::dtype() const { return DataType(m_requirements.dtype); } TensorLayout Tensor::layout() const { return TensorLayout(m_requirements.layout); } diff --git a/src/core/tensor_layout.cpp b/src/core/tensor_layout.cpp new file mode 100644 index 00000000..2fdb1fc1 --- /dev/null +++ b/src/core/tensor_layout.cpp @@ -0,0 +1,49 @@ +/** +Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#include "core/tensor_layout.hpp" + +namespace roccv { + +TensorLayout::TensorLayout(eTensorLayout layout) : m_layout(layout) { + if (layoutStringTable.count(m_layout) == 0) { + throw Exception("Invalid TensorLayout type", eStatusType::INVALID_VALUE); + } + + m_layoutString = layoutStringTable.at(m_layout); + m_rank = m_layoutString.size(); +} + +int32_t TensorLayout::indexOf(std::string_view dim) const { + if (dim.size() != 1) { + throw Exception("Dimension must be a single character", eStatusType::INVALID_VALUE); + } + + auto index = m_layoutString.find(dim); + if (index == std::string::npos) { + return -1; + } + + return index; +} + +} // namespace roccv \ No newline at end of file diff --git a/src/core/tensor_shape.cpp b/src/core/tensor_shape.cpp index 29bbfde2..ae81e8d3 100644 --- a/src/core/tensor_shape.cpp +++ b/src/core/tensor_shape.cpp @@ -113,6 +113,8 @@ int64_t TensorShape::operator[](int32_t i) const { return m_shape[i]; } +int64_t TensorShape::operator[](std::string_view dimension) const { return operator[](m_layout.indexOf(dimension)); } + bool TensorShape::operator==(const TensorShape &rhs) const { if (this->m_layout != rhs.m_layout) { return false; From 326d2a64ec4e2a6b4f5e9562d47317d6ed56085c Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 16 Jan 2026 13:05:27 -0500 Subject: [PATCH 2/7] Add TensorShape unit testing for string based indexing --- src/core/tensor_shape.cpp | 13 +++++++++++-- .../cpp/src/tests/core/tensor/test_tensor_shape.cpp | 10 ++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/core/tensor_shape.cpp b/src/core/tensor_shape.cpp index ae81e8d3..24e0ef43 100644 --- a/src/core/tensor_shape.cpp +++ b/src/core/tensor_shape.cpp @@ -108,12 +108,21 @@ TensorShape &TensorShape::operator=(const TensorShape &other) { int64_t TensorShape::operator[](int32_t i) const { if (i < 0 || i >= this->m_layout.rank()) { - throw Exception("Invalid parameter: Index must be >= 0 and < rank.", eStatusType::OUT_OF_BOUNDS); + throw Exception("TensorShape index out of bounds: " + std::to_string(i) + ". Dimension must be >= 0 and < " + + std::to_string(this->m_layout.rank()), + eStatusType::OUT_OF_BOUNDS); } return m_shape[i]; } -int64_t TensorShape::operator[](std::string_view dimension) const { return operator[](m_layout.indexOf(dimension)); } +int64_t TensorShape::operator[](std::string_view dimension) const { + int32_t index = m_layout.indexOf(dimension); + if (index == -1) { + throw Exception("Invalid dimension: " + std::string(dimension) + ". Dimension must be in the layout.", + eStatusType::OUT_OF_BOUNDS); + } + return operator[](index); +} bool TensorShape::operator==(const TensorShape &rhs) const { if (this->m_layout != rhs.m_layout) { diff --git a/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp b/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp index 42e9a709..17b0dc1f 100644 --- a/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp +++ b/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp @@ -87,6 +87,16 @@ void TestTensorShapeCorrectness() { shape2 = shape1; EXPECT_TRUE(shape1 == shape2); } + + // Test TensorShape index operator + { + TensorShape shape({1, 2, 3, 4}, "NHWC"); + EXPECT_EQ(shape["N"], 1); + EXPECT_EQ(shape["H"], 2); + EXPECT_EQ(shape["W"], 3); + EXPECT_EQ(shape["C"], 4); + EXPECT_EXCEPTION(shape["X"], eStatusType::OUT_OF_BOUNDS); + } } } // namespace From 197ddcd92fd407b88fe1e11ba078968d786f9b64 Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 16 Jan 2026 13:41:38 -0500 Subject: [PATCH 3/7] Add TensorShape permute method --- include/core/tensor_layout.hpp | 7 +++++++ include/core/tensor_shape.hpp | 8 ++++++++ src/core/tensor_shape.cpp | 10 ++++++++++ 3 files changed, 25 insertions(+) diff --git a/include/core/tensor_layout.hpp b/include/core/tensor_layout.hpp index 7df0e349..583052f0 100644 --- a/include/core/tensor_layout.hpp +++ b/include/core/tensor_layout.hpp @@ -72,6 +72,13 @@ class TensorLayout { */ int32_t indexOf(std::string_view dim) const; + /** + * @brief Returns the layout string representing the layout. + * + * @return The layout string. + */ + inline const std::string &string() const { return m_layoutString; } + /** * @brief Returns the layout enum stored in the TensorLayout object. * diff --git a/include/core/tensor_shape.hpp b/include/core/tensor_shape.hpp index f6a523b4..270db186 100644 --- a/include/core/tensor_shape.hpp +++ b/include/core/tensor_shape.hpp @@ -96,6 +96,14 @@ class TensorShape { */ const std::array &shape() const; + /** + * @brief Permutes the tensor shape to the given layout. + * + * @param[in] layout The layout to permute the tensor shape to. + * @return The permuted tensor shape. + */ + TensorShape permute(const TensorLayout &layout) const; + // Operators int64_t operator[](int32_t i) const; int64_t operator[](std::string_view dimension) const; diff --git a/src/core/tensor_shape.cpp b/src/core/tensor_shape.cpp index 24e0ef43..bda054c9 100644 --- a/src/core/tensor_shape.cpp +++ b/src/core/tensor_shape.cpp @@ -23,6 +23,7 @@ THE SOFTWARE. #include "core/tensor_shape.hpp" #include +#include #include "core/exception.hpp" #include "core/status_type.h" @@ -150,4 +151,13 @@ const TensorLayout &TensorShape::layout() const { return m_layout; } const std::array &TensorShape::shape() const { return m_shape; } +TensorShape TensorShape::permute(const TensorLayout &layout) const { + const std::string &layoutString = layout.string(); + std::vector permutedShape(layout.rank()); + for (int32_t i = 0; i < layout.rank(); i++) { + permutedShape[i] = operator[](std::to_string(layoutString[i])); + } + return TensorShape(layout, permutedShape); +} + } // namespace roccv \ No newline at end of file From dfaf4c372d729bdaab0d25d6b8f12ec30a81d298 Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 16 Jan 2026 14:03:00 -0500 Subject: [PATCH 4/7] Add unit test for shape permutation and fix method errors --- include/core/tensor_layout.hpp | 8 ++++++++ src/core/tensor_layout.cpp | 2 ++ src/core/tensor_shape.cpp | 3 +-- .../cpp/src/tests/core/tensor/test_tensor_shape.cpp | 11 +++++++++++ 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/include/core/tensor_layout.hpp b/include/core/tensor_layout.hpp index 583052f0..c4854f3a 100644 --- a/include/core/tensor_layout.hpp +++ b/include/core/tensor_layout.hpp @@ -72,6 +72,14 @@ class TensorLayout { */ int32_t indexOf(std::string_view dim) const; + /** + * @brief Returns the dimension at the given index in the layout. + * + * @param[in] index The index of the dimension to get. + * @return The dimension at the given index. + */ + std::string_view dimAt(int32_t index) const; + /** * @brief Returns the layout string representing the layout. * diff --git a/src/core/tensor_layout.cpp b/src/core/tensor_layout.cpp index 2fdb1fc1..35ff0401 100644 --- a/src/core/tensor_layout.cpp +++ b/src/core/tensor_layout.cpp @@ -46,4 +46,6 @@ int32_t TensorLayout::indexOf(std::string_view dim) const { return index; } +std::string_view TensorLayout::dimAt(int32_t index) const { return std::string_view(&m_layoutString[index], 1); } + } // namespace roccv \ No newline at end of file diff --git a/src/core/tensor_shape.cpp b/src/core/tensor_shape.cpp index bda054c9..fa22b340 100644 --- a/src/core/tensor_shape.cpp +++ b/src/core/tensor_shape.cpp @@ -152,10 +152,9 @@ const TensorLayout &TensorShape::layout() const { return m_layout; } const std::array &TensorShape::shape() const { return m_shape; } TensorShape TensorShape::permute(const TensorLayout &layout) const { - const std::string &layoutString = layout.string(); std::vector permutedShape(layout.rank()); for (int32_t i = 0; i < layout.rank(); i++) { - permutedShape[i] = operator[](std::to_string(layoutString[i])); + permutedShape[i] = operator[](layout.dimAt(i)); } return TensorShape(layout, permutedShape); } diff --git a/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp b/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp index 17b0dc1f..c3340172 100644 --- a/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp +++ b/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp @@ -97,6 +97,17 @@ void TestTensorShapeCorrectness() { EXPECT_EQ(shape["C"], 4); EXPECT_EXCEPTION(shape["X"], eStatusType::OUT_OF_BOUNDS); } + + // Test TensorShape permute operator + { + TensorShape shape({1, 2, 3, 4}, "NHWC"); + TensorShape permutedShape = shape.permute(TensorLayout(TENSOR_LAYOUT_NCHW)); + EXPECT_TRUE(permutedShape.layout() == eTensorLayout::TENSOR_LAYOUT_NCHW); + EXPECT_EQ(permutedShape["N"], 1); + EXPECT_EQ(permutedShape["C"], 4); + EXPECT_EQ(permutedShape["H"], 2); + EXPECT_EQ(permutedShape["W"], 3); + } } } // namespace From 4e25005d3955019bdbb1a77e0b528f6476988eec Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 16 Jan 2026 14:17:31 -0500 Subject: [PATCH 5/7] Add note for TensorShape::permute() --- include/core/tensor_shape.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/core/tensor_shape.hpp b/include/core/tensor_shape.hpp index 270db186..1f3ca0e0 100644 --- a/include/core/tensor_shape.hpp +++ b/include/core/tensor_shape.hpp @@ -99,6 +99,10 @@ class TensorShape { /** * @brief Permutes the tensor shape to the given layout. * + * @note This operation requires that the set of dimensions in the new layout are a subset of the dimensions in the + * current layout. For example, a tensor shape with layout HWC cannot be permuted to layout NCHW because NCHW has + * dimension N that is not present in HWC. + * @param[in] layout The layout to permute the tensor shape to. * @return The permuted tensor shape. */ From 4a75cdbc47c677516a2bbfbedcc90f015172dd78 Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 16 Jan 2026 14:35:05 -0500 Subject: [PATCH 6/7] Address comments --- include/core/tensor_layout.hpp | 2 +- src/core/tensor_layout.cpp | 10 +++++++++- .../cpp/src/tests/core/tensor/test_tensor_shape.cpp | 6 ++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/include/core/tensor_layout.hpp b/include/core/tensor_layout.hpp index c4854f3a..63e19be3 100644 --- a/include/core/tensor_layout.hpp +++ b/include/core/tensor_layout.hpp @@ -137,6 +137,6 @@ class TensorLayout { private: eTensorLayout m_layout; std::string m_layoutString; - int m_rank; + int32_t m_rank; }; } // namespace roccv \ No newline at end of file diff --git a/src/core/tensor_layout.cpp b/src/core/tensor_layout.cpp index 35ff0401..03476b33 100644 --- a/src/core/tensor_layout.cpp +++ b/src/core/tensor_layout.cpp @@ -46,6 +46,14 @@ int32_t TensorLayout::indexOf(std::string_view dim) const { return index; } -std::string_view TensorLayout::dimAt(int32_t index) const { return std::string_view(&m_layoutString[index], 1); } +std::string_view TensorLayout::dimAt(int32_t index) const { + if (index < 0 || index >= m_rank) { + throw Exception( + "Invalid index: " + std::to_string(index) + ". Index must be >= 0 and < " + std::to_string(m_rank), + eStatusType::OUT_OF_BOUNDS); + } + + return std::string_view(&m_layoutString[index], 1); +} } // namespace roccv \ No newline at end of file diff --git a/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp b/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp index c3340172..5edd93fd 100644 --- a/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp +++ b/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp @@ -55,6 +55,12 @@ void TestNegativeTensorShape() { { EXPECT_EXCEPTION(TensorShape shape({1, 2, 3}, "NWCX"), eStatusType::INVALID_VALUE); } + + // Test permute operation with invalid layout + { + TensorShape shape({1, 2, 3}, "HWC"); + EXPECT_EXCEPTION(shape.permute(TensorLayout(TENSOR_LAYOUT_NCHW)), eStatusType::OUT_OF_BOUNDS); + } } /** From d51c916b39a3b2382e80f95dd0201ee2152b8434 Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 16 Jan 2026 14:46:57 -0500 Subject: [PATCH 7/7] Add containsDim method for TensorShape/TensorLayout --- include/core/tensor_layout.hpp | 8 ++++++++ include/core/tensor_shape.hpp | 8 ++++++++ .../cpp/src/tests/core/tensor/test_tensor_shape.cpp | 9 +++++++++ 3 files changed, 25 insertions(+) diff --git a/include/core/tensor_layout.hpp b/include/core/tensor_layout.hpp index 63e19be3..7af24fb2 100644 --- a/include/core/tensor_layout.hpp +++ b/include/core/tensor_layout.hpp @@ -87,6 +87,14 @@ class TensorLayout { */ inline const std::string &string() const { return m_layoutString; } + /** + * @brief Returns true if the layout contains the given dimension, false otherwise. + * + * @param[in] dim The dimension to check for. + * @return True if the layout contains the dimension, false otherwise. + */ + inline bool containsDim(std::string_view dim) const { return indexOf(dim) != -1; } + /** * @brief Returns the layout enum stored in the TensorLayout object. * diff --git a/include/core/tensor_shape.hpp b/include/core/tensor_shape.hpp index 1f3ca0e0..f172530a 100644 --- a/include/core/tensor_shape.hpp +++ b/include/core/tensor_shape.hpp @@ -108,6 +108,14 @@ class TensorShape { */ TensorShape permute(const TensorLayout &layout) const; + /** + * @brief Returns true if the tensor shape contains the given dimension, false otherwise. + * + * @param[in] dim The dimension to check for. + * @return True if the tensor shape contains the dimension, false otherwise. + */ + inline bool containsDim(std::string_view dim) const { return m_layout.containsDim(dim); } + // Operators int64_t operator[](int32_t i) const; int64_t operator[](std::string_view dimension) const; diff --git a/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp b/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp index 5edd93fd..6a4d4b21 100644 --- a/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp +++ b/tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp @@ -114,6 +114,15 @@ void TestTensorShapeCorrectness() { EXPECT_EQ(permutedShape["H"], 2); EXPECT_EQ(permutedShape["W"], 3); } + + // Test TensorShape containsDim operator + { + TensorShape shape({1, 2, 3}, "HWC"); + EXPECT_TRUE(shape.containsDim("H")); + EXPECT_TRUE(shape.containsDim("W")); + EXPECT_TRUE(shape.containsDim("C")); + EXPECT_FALSE(shape.containsDim("N")); + } } } // namespace