diff options
-rw-r--r-- | include/armnn/Tensor.hpp | 11 | ||||
-rw-r--r-- | src/armnn/test/TensorTest.cpp | 63 |
2 files changed, 74 insertions, 0 deletions
diff --git a/include/armnn/Tensor.hpp b/include/armnn/Tensor.hpp index 69ffbd9fcc..8814d89174 100644 --- a/include/armnn/Tensor.hpp +++ b/include/armnn/Tensor.hpp @@ -235,6 +235,17 @@ private: (m_QuantizationDim == other.m_QuantizationDim)); } + Quantization& operator=(const Quantization& other) + { + if(!(*this == other)) + { + m_Scales = other.m_Scales; + m_Offset = other.m_Offset; + m_QuantizationDim = other.m_QuantizationDim; + } + return *this; + } + std::vector<float> m_Scales; Optional<int32_t> m_Offset; Optional<unsigned int> m_QuantizationDim; diff --git a/src/armnn/test/TensorTest.cpp b/src/armnn/test/TensorTest.cpp index ed3925539b..a0b68acdd2 100644 --- a/src/armnn/test/TensorTest.cpp +++ b/src/armnn/test/TensorTest.cpp @@ -99,6 +99,69 @@ BOOST_FIXTURE_TEST_CASE(TensorInfoAssignmentOperator, TensorInfoFixture) BOOST_TEST(copy == m_TensorInfo); } +BOOST_AUTO_TEST_CASE(CopyNoQuantizationTensorInfo) +{ + TensorInfo infoA; + infoA.SetShape({ 5, 6, 7, 8 }); + infoA.SetDataType(DataType::QAsymmU8); + + TensorInfo infoB; + infoB.SetShape({ 5, 6, 7, 8 }); + infoB.SetDataType(DataType::QAsymmU8); + infoB.SetQuantizationScale(10.0f); + infoB.SetQuantizationOffset(5); + infoB.SetQuantizationDim(Optional<unsigned int>(1)); + + BOOST_TEST((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 }))); + BOOST_TEST((infoA.GetDataType() == DataType::QAsymmU8)); + BOOST_TEST(infoA.GetQuantizationScale() == 1); + BOOST_TEST(infoA.GetQuantizationOffset() == 0); + BOOST_CHECK(!infoA.GetQuantizationDim().has_value()); + + BOOST_TEST(infoA != infoB); + infoA = infoB; + BOOST_TEST(infoA == infoB); + + BOOST_TEST((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 }))); + BOOST_TEST((infoA.GetDataType() == DataType::QAsymmU8)); + BOOST_TEST(infoA.GetQuantizationScale() == 10.0f); + BOOST_TEST(infoA.GetQuantizationOffset() == 5); + BOOST_CHECK(infoA.GetQuantizationDim().value() == 1); +} + +BOOST_AUTO_TEST_CASE(CopyDifferentQuantizationTensorInfo) +{ + TensorInfo infoA; + infoA.SetShape({ 5, 6, 7, 8 }); + infoA.SetDataType(DataType::QAsymmU8); + infoA.SetQuantizationScale(10.0f); + infoA.SetQuantizationOffset(5); + infoA.SetQuantizationDim(Optional<unsigned int>(1)); + + TensorInfo infoB; + infoB.SetShape({ 5, 6, 7, 8 }); + infoB.SetDataType(DataType::QAsymmU8); + infoB.SetQuantizationScale(11.0f); + infoB.SetQuantizationOffset(6); + infoB.SetQuantizationDim(Optional<unsigned int>(2)); + + BOOST_TEST((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 }))); + BOOST_TEST((infoA.GetDataType() == DataType::QAsymmU8)); + BOOST_TEST(infoA.GetQuantizationScale() == 10.0f); + BOOST_TEST(infoA.GetQuantizationOffset() == 5); + BOOST_CHECK(infoA.GetQuantizationDim().value() == 1); + + BOOST_TEST(infoA != infoB); + infoA = infoB; + BOOST_TEST(infoA == infoB); + + BOOST_TEST((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 }))); + BOOST_TEST((infoA.GetDataType() == DataType::QAsymmU8)); + BOOST_TEST(infoA.GetQuantizationScale() == 11.0f); + BOOST_TEST(infoA.GetQuantizationOffset() == 6); + BOOST_CHECK(infoA.GetQuantizationDim().value() == 2); +} + void CheckTensor(const ConstTensor& t) { t.GetInfo(); |