aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/UtilsTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/UtilsTests.cpp')
-rw-r--r--src/armnn/test/UtilsTests.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp
index 0bae756a71..fb078de32f 100644
--- a/src/armnn/test/UtilsTests.cpp
+++ b/src/armnn/test/UtilsTests.cpp
@@ -9,6 +9,7 @@
#include <armnn/Types.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnn/Descriptors.hpp>
+#include <armnnUtils/Permute.hpp>
#include <GraphTopologicalSort.hpp>
#include <Graph.hpp>
#include <ResolveType.hpp>
@@ -245,4 +246,24 @@ BOOST_AUTO_TEST_CASE(CyclicalGraphTopologicalSortTest)
BOOST_TEST(!sortCompleted);
}
+BOOST_AUTO_TEST_CASE(PermuteQuantizationDim)
+{
+ std::vector<float> scales;
+
+ // Set QuantizationDim to be index 1
+ const armnn::TensorInfo info({ 1, 2, 3, 4 }, armnn::DataType::Float32, scales, 1U);
+ BOOST_CHECK(info.GetQuantizationDim().value() == 1U);
+
+ // Permute so that index 1 moves to final index i.e. index 3
+ armnn::PermutationVector mappings({ 0, 3, 2, 1 });
+ auto permutedPerChannel = armnnUtils::Permuted(info, mappings, true);
+ auto permuted = armnnUtils::Permuted(info, mappings);
+
+ // Check that QuantizationDim is in index 3
+ BOOST_CHECK(permutedPerChannel.GetQuantizationDim().value() == 3U);
+
+ // Check previous implementation unchanged
+ BOOST_CHECK(permuted.GetQuantizationDim().value() == 1U);
+}
+
BOOST_AUTO_TEST_SUITE_END()