From 186c21c9a598cbe2e81ad53e5b5fd96d75f981f5 Mon Sep 17 00:00:00 2001 From: Francis Murtagh Date: Thu, 20 Aug 2020 15:38:29 +0100 Subject: Bugfix: Allow permutation of QuantizationDim Signed-off-by: Francis Murtagh Change-Id: Ib98ec471e6fdd47600b7c62d0b4d19dd36e20cbd --- include/armnnUtils/Permute.hpp | 7 +++++-- src/armnn/test/UtilsTests.cpp | 21 +++++++++++++++++++++ src/armnnUtils/Permute.cpp | 14 ++++++++++++-- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/include/armnnUtils/Permute.hpp b/include/armnnUtils/Permute.hpp index 4e9bfc0823..d719f4a623 100644 --- a/include/armnnUtils/Permute.hpp +++ b/include/armnnUtils/Permute.hpp @@ -11,9 +11,12 @@ namespace armnnUtils { -armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings); +armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, + const armnn::PermutationVector& mappings); -armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings); +armnn::TensorInfo Permuted(const armnn::TensorInfo& info, + const armnn::PermutationVector& mappings, + bool perChannelPermute = false); void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings, const void* src, void* dst, size_t dataTypeSize); diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp index 0bae756a71..fb078de32f 100644 --- a/src/armnn/test/UtilsTests.cpp +++ b/src/armnn/test/UtilsTests.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -245,4 +246,24 @@ BOOST_AUTO_TEST_CASE(CyclicalGraphTopologicalSortTest) BOOST_TEST(!sortCompleted); } +BOOST_AUTO_TEST_CASE(PermuteQuantizationDim) +{ + std::vector scales; + + // Set QuantizationDim to be index 1 + const armnn::TensorInfo info({ 1, 2, 3, 4 }, armnn::DataType::Float32, scales, 1U); + BOOST_CHECK(info.GetQuantizationDim().value() == 1U); + + // Permute so that index 1 moves to final index i.e. index 3 + armnn::PermutationVector mappings({ 0, 3, 2, 1 }); + auto permutedPerChannel = armnnUtils::Permuted(info, mappings, true); + auto permuted = armnnUtils::Permuted(info, mappings); + + // Check that QuantizationDim is in index 3 + BOOST_CHECK(permutedPerChannel.GetQuantizationDim().value() == 3U); + + // Check previous implementation unchanged + BOOST_CHECK(permuted.GetQuantizationDim().value() == 1U); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnUtils/Permute.cpp b/src/armnnUtils/Permute.cpp index 486aac00c1..377046367c 100644 --- a/src/armnnUtils/Permute.cpp +++ b/src/armnnUtils/Permute.cpp @@ -95,7 +95,8 @@ private: namespace armnnUtils { -armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings) +armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, + const armnn::PermutationVector& mappings) { assert(srcShape.GetNumDimensions() == mappings.GetSize()); @@ -111,10 +112,19 @@ armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::Per return permutedShape; } -armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings) +armnn::TensorInfo Permuted(const armnn::TensorInfo& info, + const armnn::PermutationVector& mappings, + bool perChannelPermute) { armnn::TensorInfo outInfo(info); outInfo.SetShape(Permuted(info.GetShape(), mappings)); + + // If TensorInfo has Per-Axis Quantization then permute QuantizationDim to mapping + if (info.HasPerAxisQuantization() && perChannelPermute) + { + outInfo.SetQuantizationDim(mappings[info.GetQuantizationDim().value()]); + } + return outInfo; } -- cgit v1.2.1