aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/UtilsTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/UtilsTests.cpp')
-rw-r--r--src/armnn/test/UtilsTests.cpp16
1 files changed, 9 insertions, 7 deletions
diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp
index f0198cb9d4..a813feaf7f 100644
--- a/src/armnn/test/UtilsTests.cpp
+++ b/src/armnn/test/UtilsTests.cpp
@@ -249,22 +249,24 @@ BOOST_AUTO_TEST_CASE(CyclicalGraphTopologicalSortTest)
BOOST_AUTO_TEST_CASE(PermuteQuantizationDim)
{
- std::vector<float> scales;
+ std::vector<float> scales {1.0f, 1.0f};
// Set QuantizationDim to be index 1
- const armnn::TensorInfo info({ 1, 2, 3, 4 }, armnn::DataType::Float32, scales, 1U);
- BOOST_CHECK(info.GetQuantizationDim().value() == 1U);
+ const armnn::TensorInfo perChannelInfo({ 1, 2, 3, 4 }, armnn::DataType::Float32, scales, 1U);
+ BOOST_CHECK(perChannelInfo.GetQuantizationDim().value() == 1U);
// Permute so that index 1 moves to final index i.e. index 3
armnn::PermutationVector mappings({ 0, 3, 2, 1 });
- auto permutedPerChannel = armnnUtils::Permuted(info, mappings, true);
- auto permuted = armnnUtils::Permuted(info, mappings);
+ auto permutedPerChannel = armnnUtils::Permuted(perChannelInfo, mappings);
// Check that QuantizationDim is in index 3
BOOST_CHECK(permutedPerChannel.GetQuantizationDim().value() == 3U);
- // Check previous implementation unchanged
- BOOST_CHECK(permuted.GetQuantizationDim().value() == 1U);
+ // Even if there is only a single scale the quantization dim still exists and needs to be permuted
+ std::vector<float> scale {1.0f};
+ const armnn::TensorInfo perChannelInfo1({ 1, 2, 3, 4 }, armnn::DataType::Float32, scale, 1U);
+ auto permuted = armnnUtils::Permuted(perChannelInfo1, mappings);
+ BOOST_CHECK(permuted.GetQuantizationDim().value() == 3U);
}
#if defined(ARMNNREF_ENABLED)