aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancis Murtagh <francis.murtagh@arm.com>2020-08-20 15:38:29 +0100
committerTeresa Charlin <teresa.charlinreyes@arm.com>2020-08-21 15:56:13 +0100
commit186c21c9a598cbe2e81ad53e5b5fd96d75f981f5 (patch)
tree45cb5522dcf36256f66bbe25eb8b24460cbdb37a
parentc84e45d933a9b45810a3bb88f6873f4eddca0975 (diff)
downloadarmnn-186c21c9a598cbe2e81ad53e5b5fd96d75f981f5.tar.gz
Bugfix: Allow permutation of QuantizationDim
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com> Change-Id: Ib98ec471e6fdd47600b7c62d0b4d19dd36e20cbd
-rw-r--r--include/armnnUtils/Permute.hpp7
-rw-r--r--src/armnn/test/UtilsTests.cpp21
-rw-r--r--src/armnnUtils/Permute.cpp14
3 files changed, 38 insertions, 4 deletions
diff --git a/include/armnnUtils/Permute.hpp b/include/armnnUtils/Permute.hpp
index 4e9bfc0823..d719f4a623 100644
--- a/include/armnnUtils/Permute.hpp
+++ b/include/armnnUtils/Permute.hpp
@@ -11,9 +11,12 @@
namespace armnnUtils
{
-armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings);
+armnn::TensorShape Permuted(const armnn::TensorShape& srcShape,
+ const armnn::PermutationVector& mappings);
-armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings);
+armnn::TensorInfo Permuted(const armnn::TensorInfo& info,
+ const armnn::PermutationVector& mappings,
+ bool perChannelPermute = false);
void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
const void* src, void* dst, size_t dataTypeSize);
diff --git a/src/armnn/test/UtilsTests.cpp b/src/armnn/test/UtilsTests.cpp
index 0bae756a71..fb078de32f 100644
--- a/src/armnn/test/UtilsTests.cpp
+++ b/src/armnn/test/UtilsTests.cpp
@@ -9,6 +9,7 @@
#include <armnn/Types.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnn/Descriptors.hpp>
+#include <armnnUtils/Permute.hpp>
#include <GraphTopologicalSort.hpp>
#include <Graph.hpp>
#include <ResolveType.hpp>
@@ -245,4 +246,24 @@ BOOST_AUTO_TEST_CASE(CyclicalGraphTopologicalSortTest)
BOOST_TEST(!sortCompleted);
}
+BOOST_AUTO_TEST_CASE(PermuteQuantizationDim)
+{
+ std::vector<float> scales;
+
+ // Set QuantizationDim to be index 1
+ const armnn::TensorInfo info({ 1, 2, 3, 4 }, armnn::DataType::Float32, scales, 1U);
+ BOOST_CHECK(info.GetQuantizationDim().value() == 1U);
+
+ // Permute so that index 1 moves to final index i.e. index 3
+ armnn::PermutationVector mappings({ 0, 3, 2, 1 });
+ auto permutedPerChannel = armnnUtils::Permuted(info, mappings, true);
+ auto permuted = armnnUtils::Permuted(info, mappings);
+
+ // Check that QuantizationDim is in index 3
+ BOOST_CHECK(permutedPerChannel.GetQuantizationDim().value() == 3U);
+
+ // Check previous implementation unchanged
+ BOOST_CHECK(permuted.GetQuantizationDim().value() == 1U);
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnUtils/Permute.cpp b/src/armnnUtils/Permute.cpp
index 486aac00c1..377046367c 100644
--- a/src/armnnUtils/Permute.cpp
+++ b/src/armnnUtils/Permute.cpp
@@ -95,7 +95,8 @@ private:
namespace armnnUtils
{
-armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
+armnn::TensorShape Permuted(const armnn::TensorShape& srcShape,
+ const armnn::PermutationVector& mappings)
{
assert(srcShape.GetNumDimensions() == mappings.GetSize());
@@ -111,10 +112,19 @@ armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::Per
return permutedShape;
}
-armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
+armnn::TensorInfo Permuted(const armnn::TensorInfo& info,
+ const armnn::PermutationVector& mappings,
+ bool perChannelPermute)
{
armnn::TensorInfo outInfo(info);
outInfo.SetShape(Permuted(info.GetShape(), mappings));
+
+ // If TensorInfo has Per-Axis Quantization then permute QuantizationDim to mapping
+ if (info.HasPerAxisQuantization() && perChannelPermute)
+ {
+ outInfo.SetQuantizationDim(mappings[info.GetQuantizationDim().value()]);
+ }
+
return outInfo;
}