Skip to content
8 changes: 8 additions & 0 deletions include/core/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ class Tensor {
*/
int64_t shape(int d) const &;

/**
 * @brief Retrieves a specific dimension size from the tensor shape using a single-character dimension name.
 *
 * @param[in] dimension The dimension to get the size of, given as a single-character string (e.g. "H" or "W").
 * @return The size of the specified dimension.
*/
int64_t shape(std::string_view dimension) const &;

/**
* @brief Returns the data type of the tensor
*
Expand Down
118 changes: 54 additions & 64 deletions include/core/tensor_layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,6 @@ THE SOFTWARE.
#define ROCCV_TENSOR_MAX_RANK (15)

namespace roccv {
/**
* @brief Descriptors used to specify features of a specific tensor layout type.
*
*/
struct TensorLayoutDesc {
int32_t rank;
int32_t batch_index;
int32_t width_index;
int32_t height_index;
int32_t channel_index;
int32_t max_features_index;
int32_t sift_features_index;
int32_t sift_octave_layer_index;
};

/**
* @brief TensorLayout class.
*
Expand All @@ -62,99 +47,104 @@ class TensorLayout {
* @param[in] layout The desired layout of the TensorLayout object. See
* eTensorLayout for information on supported layouts.
*/
explicit TensorLayout(eTensorLayout layout) {
if (TensorLayout::layoutDescriptorTable.count(layout) == 0) {
throw Exception("Invalid TensorLayout type", eStatusType::INVALID_VALUE);
}
explicit TensorLayout(eTensorLayout layout);

// clang-format off
inline static const std::unordered_map<eTensorLayout, std::string> layoutStringTable = {
{TENSOR_LAYOUT_HWC, "HWC"},
{TENSOR_LAYOUT_NC, "NC"},
{TENSOR_LAYOUT_NW, "NW"},
{TENSOR_LAYOUT_NHWC, "NHWC"},
{TENSOR_LAYOUT_NMC, "NMC"},
{TENSOR_LAYOUT_NMD, "NMD"},
{TENSOR_LAYOUT_LNHWC, "LNHWC"},
{TENSOR_LAYOUT_NCHW, "NCHW"},
{TENSOR_LAYOUT_N, "N"},
{TENSOR_LAYOUT_NWC, "NWC"},
};
// clang-format on

layout_ = layout;
layout_desc_ = TensorLayout::layoutDescriptorTable.at(layout);
}
/**
* @brief Returns the index of the given dimension in the layout.
*
* @param[in] dimension The dimension to get the index of.
* @return The index of the dimension, or -1 if the dimension is not found in the layout.
*/
int32_t indexOf(std::string_view dim) const;

/**
* @brief Provides descriptors for each feature of a specified layout type.
* @brief Returns the dimension at the given index in the layout.
*
* @param[in] index The index of the dimension to get.
* @return The dimension at the given index.
*/
std::string_view dimAt(int32_t index) const;

/**
* @brief Returns the layout string representing the layout.
*
* @return The layout string.
*/
inline const std::string &string() const { return m_layoutString; }

/**
* @brief Returns true if the layout contains the given dimension, false otherwise.
*
* @param[in] dim The dimension to check for.
* @return True if the layout contains the dimension, false otherwise.
*/
inline static const std::unordered_map<eTensorLayout, TensorLayoutDesc> layoutDescriptorTable = {
{TENSOR_LAYOUT_HWC, {3, -1, 1, 0, 2, -1, -1, -1}}, {TENSOR_LAYOUT_NC, {2, 0, -1, -1, 1, -1, -1, -1}},
{TENSOR_LAYOUT_NW, {2, 0, 1, -1, -1, -1, -1, -1}}, {TENSOR_LAYOUT_NHWC, {4, 0, 2, 1, 3, -1, -1, -1}},
{TENSOR_LAYOUT_NMC, {3, 0, -1, -1, -1, 1, 2, -1}}, {TENSOR_LAYOUT_NMD, {3, 0, -1, -1, -1, 1, 2, -1}},
{TENSOR_LAYOUT_LNHWC, {5, 1, 3, 2, 4, -1, -1, 0}}, {TENSOR_LAYOUT_NCHW, {4, 0, 3, 2, 1, -1, -1, -1}},
{TENSOR_LAYOUT_N, {1, 0, -1, -1, -1, -1, -1, -1}}, {TENSOR_LAYOUT_NWC, {3, 0, 1, -1, 2, -1, -1, -1}}};
inline bool containsDim(std::string_view dim) const { return indexOf(dim) != -1; }

/**
* @brief Returns the layout enum stored in the TensorLayout object.
*
* @return eTensorLayout
*/
eTensorLayout elayout() const { return layout_; }

bool operator==(const eTensorLayout &rhs) const { return this->layout_ == rhs; }
eTensorLayout elayout() const { return m_layout; }

bool operator==(const eTensorLayout &rhs) const { return this->m_layout == rhs; }
bool operator!=(const eTensorLayout &rhs) const { return !operator==(rhs); }

bool operator==(const TensorLayout &rhs) const { return this->layout_ == rhs.layout_; }

bool operator==(const TensorLayout &rhs) const { return this->m_layout == rhs.m_layout; }
bool operator!=(const TensorLayout &rhs) const { return !operator==(rhs); }

/**
* @brief Returns the rank of the Tensor Layout object.
*
* @return int32_t
*/
int32_t rank() const { return layout_desc_.rank; }
int32_t rank() const { return m_rank; }

/**
* @brief Index of the batch dimension specified by layout. E.g. returns 0
* for TENSOR_LAYOUT_NHWC.
* @return Index or -1 if the layout does not have a batch dimension.
*/
int32_t batch_index() const { return layout_desc_.batch_index; }
int32_t batch_index() const { return indexOf("N"); }

/**
* @brief Index of the height dimension specified by layout. E.g. returns 1
* for TENSOR_LAYOUT_NHWC.
* @return Index of the height dimension.
*/
int32_t height_index() const { return layout_desc_.height_index; }
int32_t height_index() const { return indexOf("H"); }

/**
* @brief Index of the width dimension specified by layout. E.g. returns 2
* for TENSOR_LAYOUT_NHWC.
* @return Index of the width dimension.
*/
int32_t width_index() const { return layout_desc_.width_index; }
int32_t width_index() const { return indexOf("W"); }

/**
* @brief Index of the channels dimension specified by layout. E.g. returns
* 3 for TENSOR_LAYOUT_NHWC.
* @return Index of the channels dimension.
*/
int32_t channels_index() const { return layout_desc_.channel_index; }

/**
* @brief Index of the max features dimension specified by layout
*
* @return Index of the max features dimension or -1 if the layout does not
* contain it.
*/
int32_t max_features_index() const { return layout_desc_.max_features_index; }

/**
* @brief Index of the sift features dimension specified by layout
*
* @return int32_t
*/
int32_t sift_features_index() const { return layout_desc_.sift_features_index; }

/**
* @brief Index of the sift octave layer dimension specified by layout
*
* @return int32_t
*/
int32_t sift_octave_layer_index() const { return layout_desc_.sift_octave_layer_index; }
int32_t channels_index() const { return indexOf("C"); }

private:
eTensorLayout layout_;
TensorLayoutDesc layout_desc_;
eTensorLayout m_layout;
std::string m_layoutString;
int32_t m_rank;
};
} // namespace roccv
21 changes: 21 additions & 0 deletions include/core/tensor_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,29 @@ class TensorShape {
*/
const std::array<int64_t, ROCCV_TENSOR_MAX_RANK> &shape() const;

/**
* @brief Permutes the tensor shape to the given layout.
*
* @note This operation requires that the set of dimensions in the new layout are a subset of the dimensions in the
* current layout. For example, a tensor shape with layout HWC cannot be permuted to layout NCHW because NCHW has
* dimension N that is not present in HWC.

* @param[in] layout The layout to permute the tensor shape to.
* @return The permuted tensor shape.
*/
TensorShape permute(const TensorLayout &layout) const;

/**
* @brief Returns true if the tensor shape contains the given dimension, false otherwise.
*
* @param[in] dim The dimension to check for.
* @return True if the tensor shape contains the dimension, false otherwise.
*/
inline bool containsDim(std::string_view dim) const { return m_layout.containsDim(dim); }

// Operators
int64_t operator[](int32_t i) const;
int64_t operator[](std::string_view dimension) const;
TensorShape &operator=(const TensorShape &other);
bool operator==(const TensorShape &rhs) const;
bool operator!=(const TensorShape &rhs) const;
Expand Down
2 changes: 2 additions & 0 deletions src/core/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ TensorShape Tensor::shape() const { return TensorShape(m_requirements.shape, m_r

int64_t Tensor::shape(int d) const& { return shape()[d]; }

// Returns the size of the named dimension (e.g. "H") by delegating to TensorShape's
// string_view index operator, which throws eStatusType::OUT_OF_BOUNDS when the
// dimension is not present in this tensor's layout.
int64_t Tensor::shape(std::string_view dimension) const& { return shape()[dimension]; }

DataType Tensor::dtype() const { return DataType(m_requirements.dtype); }

TensorLayout Tensor::layout() const { return TensorLayout(m_requirements.layout); }
Expand Down
59 changes: 59 additions & 0 deletions src/core/tensor_layout.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/**
Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

#include "core/tensor_layout.hpp"

namespace roccv {

/**
 * @brief Constructs a TensorLayout from a layout enum value.
 *
 * Looks up the layout's string form (e.g. "NHWC") and derives the rank from its length.
 *
 * @param[in] layout The desired layout. Must be a key of layoutStringTable.
 * @throws Exception with eStatusType::INVALID_VALUE if the layout is unknown.
 */
TensorLayout::TensorLayout(eTensorLayout layout) : m_layout(layout) {
    // Single hash lookup instead of count() followed by at().
    const auto entry = layoutStringTable.find(m_layout);
    if (entry == layoutStringTable.end()) {
        throw Exception("Invalid TensorLayout type", eStatusType::INVALID_VALUE);
    }

    m_layoutString = entry->second;
    // Explicit cast: string length is size_t, rank is stored as int32_t.
    m_rank = static_cast<int32_t>(m_layoutString.size());
}

/**
 * @brief Returns the index of the given dimension within the layout string.
 *
 * @param[in] dim The dimension name; must be exactly one character (e.g. "N", "H").
 * @return Zero-based index of the dimension, or -1 if the layout does not contain it.
 * @throws Exception with eStatusType::INVALID_VALUE if dim is not a single character.
 */
int32_t TensorLayout::indexOf(std::string_view dim) const {
    if (dim.size() != 1) {
        throw Exception("Dimension must be a single character", eStatusType::INVALID_VALUE);
    }

    // Search for the single character directly (cheaper than a substring search).
    const auto index = m_layoutString.find(dim.front());
    if (index == std::string::npos) {
        return -1;
    }

    // Explicit cast: find() returns size_t but the API contract is int32_t.
    return static_cast<int32_t>(index);
}

/**
 * @brief Returns a one-character view of the dimension at the given layout index.
 *
 * @param[in] index Zero-based position within the layout string; must satisfy 0 <= index < rank.
 * @return A std::string_view of length 1 referencing this object's layout string.
 * @throws Exception with eStatusType::OUT_OF_BOUNDS if the index is outside the layout.
 */
std::string_view TensorLayout::dimAt(int32_t index) const {
    const bool withinBounds = index >= 0 && index < m_rank;
    if (!withinBounds) {
        throw Exception(
            "Invalid index: " + std::to_string(index) + ". Index must be >= 0 and < " + std::to_string(m_rank),
            eStatusType::OUT_OF_BOUNDS);
    }

    // View over the member string; valid for the lifetime of this TensorLayout.
    return std::string_view(m_layoutString).substr(index, 1);
}

} // namespace roccv
22 changes: 21 additions & 1 deletion src/core/tensor_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ THE SOFTWARE.
#include "core/tensor_shape.hpp"

#include <algorithm>
#include <vector>

#include "core/exception.hpp"
#include "core/status_type.h"
Expand Down Expand Up @@ -108,11 +109,22 @@ TensorShape &TensorShape::operator=(const TensorShape &other) {

int64_t TensorShape::operator[](int32_t i) const {
if (i < 0 || i >= this->m_layout.rank()) {
throw Exception("Invalid parameter: Index must be >= 0 and < rank.", eStatusType::OUT_OF_BOUNDS);
throw Exception("TensorShape index out of bounds: " + std::to_string(i) + ". Dimension must be >= 0 and < " +
std::to_string(this->m_layout.rank()),
eStatusType::OUT_OF_BOUNDS);
}
return m_shape[i];
}

/**
 * @brief Returns the size of the dimension identified by its single-character name.
 *
 * @param[in] dimension Dimension name (e.g. "H"); must exist in this shape's layout.
 * @return The extent of that dimension.
 * @throws Exception with eStatusType::OUT_OF_BOUNDS if the layout lacks the dimension.
 */
int64_t TensorShape::operator[](std::string_view dimension) const {
    const int32_t dimIndex = m_layout.indexOf(dimension);
    if (dimIndex < 0) {
        throw Exception("Invalid dimension: " + std::string(dimension) + ". Dimension must be in the layout.",
                        eStatusType::OUT_OF_BOUNDS);
    }
    // Delegate to the integer index operator for the actual lookup.
    return (*this)[dimIndex];
}

bool TensorShape::operator==(const TensorShape &rhs) const {
if (this->m_layout != rhs.m_layout) {
return false;
Expand All @@ -139,4 +151,12 @@ const TensorLayout &TensorShape::layout() const { return m_layout; }

const std::array<int64_t, ROCCV_TENSOR_MAX_RANK> &TensorShape::shape() const { return m_shape; }

/**
 * @brief Produces a new TensorShape whose extents are reordered to match the given layout.
 *
 * Each dimension of the target layout is looked up by name in the current shape, so the
 * target layout's dimensions must be a subset of this shape's dimensions.
 *
 * @param[in] layout The layout to permute to.
 * @return A TensorShape with the target layout and correspondingly reordered extents.
 */
TensorShape TensorShape::permute(const TensorLayout &layout) const {
    const int32_t targetRank = layout.rank();
    std::vector<int64_t> reordered;
    reordered.reserve(targetRank);
    for (int32_t d = 0; d < targetRank; ++d) {
        // Fetch each extent by the dimension's name in the target layout.
        reordered.push_back((*this)[layout.dimAt(d)]);
    }
    return TensorShape(layout, reordered);
}

} // namespace roccv
36 changes: 36 additions & 0 deletions tests/roccv/cpp/src/tests/core/tensor/test_tensor_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ void TestNegativeTensorShape() {
{
EXPECT_EXCEPTION(TensorShape shape({1, 2, 3}, "NWCX"), eStatusType::INVALID_VALUE);
}

// Test permute operation with invalid layout
{
TensorShape shape({1, 2, 3}, "HWC");
EXPECT_EXCEPTION(shape.permute(TensorLayout(TENSOR_LAYOUT_NCHW)), eStatusType::OUT_OF_BOUNDS);
}
}

/**
Expand Down Expand Up @@ -87,6 +93,36 @@ void TestTensorShapeCorrectness() {
shape2 = shape1;
EXPECT_TRUE(shape1 == shape2);
}

// Test TensorShape index operator
{
TensorShape shape({1, 2, 3, 4}, "NHWC");
EXPECT_EQ(shape["N"], 1);
EXPECT_EQ(shape["H"], 2);
EXPECT_EQ(shape["W"], 3);
EXPECT_EQ(shape["C"], 4);
EXPECT_EXCEPTION(shape["X"], eStatusType::OUT_OF_BOUNDS);
}

// Test TensorShape permute operator
{
TensorShape shape({1, 2, 3, 4}, "NHWC");
TensorShape permutedShape = shape.permute(TensorLayout(TENSOR_LAYOUT_NCHW));
EXPECT_TRUE(permutedShape.layout() == eTensorLayout::TENSOR_LAYOUT_NCHW);
EXPECT_EQ(permutedShape["N"], 1);
EXPECT_EQ(permutedShape["C"], 4);
EXPECT_EQ(permutedShape["H"], 2);
EXPECT_EQ(permutedShape["W"], 3);
}

// Test TensorShape containsDim operator
{
TensorShape shape({1, 2, 3}, "HWC");
EXPECT_TRUE(shape.containsDim("H"));
EXPECT_TRUE(shape.containsDim("W"));
EXPECT_TRUE(shape.containsDim("C"));
EXPECT_FALSE(shape.containsDim("N"));
}
}
} // namespace

Expand Down