aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/TensorTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/TensorTest.cpp')
-rw-r--r--src/armnn/test/TensorTest.cpp34
1 files changed, 34 insertions, 0 deletions
diff --git a/src/armnn/test/TensorTest.cpp b/src/armnn/test/TensorTest.cpp
index a0a6c7e91f..154a0bca04 100644
--- a/src/armnn/test/TensorTest.cpp
+++ b/src/armnn/test/TensorTest.cpp
@@ -143,4 +143,38 @@ BOOST_AUTO_TEST_CASE(TensorShapeOperatorBrackets)
BOOST_TEST(shape[2] == 20);
}
+BOOST_AUTO_TEST_CASE(TensorInfoPerAxisQuantization)
+{
+ // Old constructor
+ TensorInfo tensorInfo0({ 1, 1 }, DataType::Float32, 2.0f, 1);
+ BOOST_CHECK(!tensorInfo0.HasMultipleQuantizationScales());
+ BOOST_CHECK(tensorInfo0.GetQuantizationScale() == 2.0f);
+ BOOST_CHECK(tensorInfo0.GetQuantizationOffset() == 1);
+ BOOST_CHECK(tensorInfo0.GetQuantizationScales()[0] == 2.0f);
+ BOOST_CHECK(!tensorInfo0.GetQuantizationDim().has_value());
+
+ // Set per-axis quantization scales
+ std::vector<float> perAxisScales{ 3.0f, 4.0f };
+ tensorInfo0.SetQuantizationScales(perAxisScales);
+ BOOST_CHECK(tensorInfo0.HasMultipleQuantizationScales());
+ BOOST_CHECK(tensorInfo0.GetQuantizationScales() == perAxisScales);
+
+ // Set per-tensor quantization scale
+ tensorInfo0.SetQuantizationScale(5.0f);
+ BOOST_CHECK(!tensorInfo0.HasMultipleQuantizationScales());
+ BOOST_CHECK(tensorInfo0.GetQuantizationScales()[0] == 5.0f);
+
+ // Set quantization offset
+ tensorInfo0.SetQuantizationDim(Optional<unsigned int>(1));
+ BOOST_CHECK(tensorInfo0.GetQuantizationDim().value() == 1);
+
+ // New constructor
+ perAxisScales = { 6.0f, 7.0f };
+ TensorInfo tensorInfo1({ 1, 1 }, DataType::Float32, perAxisScales, 1);
+ BOOST_CHECK(tensorInfo1.HasMultipleQuantizationScales());
+ BOOST_CHECK(tensorInfo1.GetQuantizationOffset() == 0);
+ BOOST_CHECK(tensorInfo1.GetQuantizationScales() == perAxisScales);
+ BOOST_CHECK(tensorInfo1.GetQuantizationDim().value() == 1);
+}
+
BOOST_AUTO_TEST_SUITE_END()