diff --git a/include/data/tensor.hpp b/include/data/tensor.hpp index 7973144..f82e9a8 100644 --- a/include/data/tensor.hpp +++ b/include/data/tensor.hpp @@ -1,4 +1,4 @@ -// + // // Created by fss on 22-12-16. // @@ -25,7 +25,7 @@ class Tensor { public: explicit Tensor() = default; - explicit Tensor(uint32_t channels, uint32_t rows, uint32_t cols); + explicit Tensor(uint32_t rows, uint32_t cols, uint32_t channels); Tensor(const Tensor &tensor); diff --git a/source/data/tensor.cpp b/source/data/tensor.cpp index da29302..9a3c81d 100644 --- a/source/data/tensor.cpp +++ b/source/data/tensor.cpp @@ -7,7 +7,7 @@ namespace kuiper_infer { -Tensor::Tensor(uint32_t channels, uint32_t rows, uint32_t cols) { +Tensor::Tensor(uint32_t rows, uint32_t cols, uint32_t channels) { data_ = arma::fcube(rows, cols, channels); } diff --git a/test/test_tensor.cpp b/test/test_tensor.cpp index 070f77d..70d7626 100644 --- a/test/test_tensor.cpp +++ b/test/test_tensor.cpp @@ -10,6 +10,7 @@ TEST(test_tensor, create) { using namespace kuiper_infer; Tensor tensor(3, 32, 32); + ASSERT_EQ(tensor.empty(), false); ASSERT_EQ(tensor.channels(), 3); ASSERT_EQ(tensor.rows(), 32); ASSERT_EQ(tensor.cols(), 32); @@ -18,6 +19,7 @@ TEST(test_tensor, create) { TEST(test_tensor, fill) { using namespace kuiper_infer; Tensor tensor(3, 3, 3); + ASSERT_EQ(tensor.empty(), false); ASSERT_EQ(tensor.channels(), 3); ASSERT_EQ(tensor.rows(), 3); ASSERT_EQ(tensor.cols(), 3); @@ -44,6 +46,7 @@ TEST(test_tensor, fill) { TEST(test_tensor, padding1) { using namespace kuiper_infer; Tensor tensor(3, 3, 3); + ASSERT_EQ(tensor.empty(), false); ASSERT_EQ(tensor.channels(), 3); ASSERT_EQ(tensor.rows(), 3); ASSERT_EQ(tensor.cols(), 3);