From a71c0633eff4791eb98362f72f198a2d1ec3d8f9 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Mon, 12 Apr 2021 13:12:19 +0100 Subject: IVGCVSW-5826 Remove cross-wireing in depthwise * The permutation of the tensor info is now completely handled in the armnnUtils::Permuted function. That includes quantization informations too !armnn:5411 Signed-off-by: Jan Eilers Change-Id: I40410141303d950be7888f9e491133251b6f69d8 --- ConversionUtils.cpp | 4 ++-- ConversionUtils.hpp | 7 +------ Utils.cpp | 24 +++++++----------------- Utils.hpp | 2 +- 4 files changed, 11 insertions(+), 26 deletions(-) diff --git a/ConversionUtils.cpp b/ConversionUtils.cpp index 9cc6e286..4cea7276 100644 --- a/ConversionUtils.cpp +++ b/ConversionUtils.cpp @@ -56,7 +56,7 @@ ConstTensorPin::ConstTensorPin(bool optional) : m_Optional(optional) {} -ConstTensorPin::ConstTensorPin(const armnn::TensorInfo& tensorInfo, +ConstTensorPin::ConstTensorPin(armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes, const armnn::PermutationVector& mappings) @@ -73,7 +73,7 @@ ConstTensorPin::ConstTensorPin(const armnn::TensorInfo& tensorInfo, m_SwizzledTensorData.resize(tensorInfo.GetNumBytes()); SwizzleAndroidNn4dTensorToArmNn(tensorInfo, valueStart, m_SwizzledTensorData.data(), mappings); - m_ConstTensor = armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, mappings), m_SwizzledTensorData.data()); + m_ConstTensor = armnn::ConstTensor(tensorInfo, m_SwizzledTensorData.data()); } else { diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 439d4a4a..473f6d78 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -90,7 +90,7 @@ public: // @param valueStart Start address of tensor data. Belongs to one of the memory pools associated with // the model being converted. // @param numBytes Number of bytes for the tensor data. - ConstTensorPin(const armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes, + ConstTensorPin(armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes, const armnn::PermutationVector& mappings); ConstTensorPin(const ConstTensorPin& other) = delete; @@ -843,11 +843,6 @@ ConstTensorPin ConvertOperandToConstTensorPin(const HalOperand& operand, } armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(operand); - // Android datalayout might be different than armnn datalayout, e.g. the kernel for the depthwise convolution. - if (tensorInfo.HasPerAxisQuantization()) - { - tensorInfo.SetQuantizationDim(dimensionMappings[tensorInfo.GetQuantizationDim().value()]); - } if (overrideTensorShape != nullptr) { diff --git a/Utils.cpp b/Utils.cpp index 18842814..6fd1a785 100644 --- a/Utils.cpp +++ b/Utils.cpp @@ -31,25 +31,12 @@ namespace armnn_driver { const armnn::PermutationVector g_DontPermute{}; -namespace -{ - -void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorShape& inTensorShape, const void* input, - void* output, size_t dataTypeSize, const armnn::PermutationVector& mappings) -{ - assert(inTensorShape.GetNumDimensions() == 4U); - - armnnUtils::Permute(armnnUtils::Permuted(inTensorShape, mappings), mappings, input, output, dataTypeSize); -} - -} // anonymous namespace - -void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output, +void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensorInfo, const void* input, void* output, const armnn::PermutationVector& mappings) { - assert(tensor.GetNumDimensions() == 4U); + assert(tensorInfo.GetNumDimensions() == 4U); - armnn::DataType dataType = tensor.GetDataType(); + armnn::DataType dataType = tensorInfo.GetDataType(); switch (dataType) { case armnn::DataType::Float16: @@ -57,7 +44,10 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void case armnn::DataType::QAsymmU8: case armnn::DataType::QSymmS8: case armnn::DataType::QAsymmS8: - SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings); + // First swizzle tensor info + tensorInfo = armnnUtils::Permuted(tensorInfo, mappings); + // Then swizzle tensor data + armnnUtils::Permute(tensorInfo.GetShape(), mappings, input, output, armnn::GetDataTypeSize(dataType)); break; default: ALOGW("Unknown armnn::DataType for swizzling"); diff --git a/Utils.hpp b/Utils.hpp index f68747b0..893c4a08 100644 --- a/Utils.hpp +++ b/Utils.hpp @@ -60,7 +60,7 @@ public: }; /// Swizzles tensor data in @a input according to the dimension mappings. -void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output, +void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensor, const void* input, void* output, const armnn::PermutationVector& mappings); /// Returns a pointer to a specific location in a pool -- cgit v1.2.1