diff options
Diffstat (limited to 'src/armnnUtils/Permute.cpp')
-rw-r--r-- | src/armnnUtils/Permute.cpp | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/src/armnnUtils/Permute.cpp b/src/armnnUtils/Permute.cpp index 486aac00c1..377046367c 100644 --- a/src/armnnUtils/Permute.cpp +++ b/src/armnnUtils/Permute.cpp @@ -95,7 +95,8 @@ private: namespace armnnUtils { -armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings) +armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, + const armnn::PermutationVector& mappings) { assert(srcShape.GetNumDimensions() == mappings.GetSize()); @@ -111,10 +112,19 @@ armnn::TensorShape Permuted(const armnn::TensorShape& srcShape, const armnn::Per return permutedShape; } -armnn::TensorInfo Permuted(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings) +armnn::TensorInfo Permuted(const armnn::TensorInfo& info, + const armnn::PermutationVector& mappings, + bool perChannelPermute) { armnn::TensorInfo outInfo(info); outInfo.SetShape(Permuted(info.GetShape(), mappings)); + + // If TensorInfo has Per-Axis Quantization then permute QuantizationDim to mapping + if (info.HasPerAxisQuantization() && perChannelPermute) + { + outInfo.SetQuantizationDim(mappings[info.GetQuantizationDim().value()]); + } + return outInfo; } |