From a71c0633eff4791eb98362f72f198a2d1ec3d8f9 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Mon, 12 Apr 2021 13:12:19 +0100 Subject: IVGCVSW-5826 Remove cross-wireing in depthwise * The permutation of the tensor info is now completely handled in the armnnUtils::Permuted function. That includes quantization informations too !armnn:5411 Signed-off-by: Jan Eilers Change-Id: I40410141303d950be7888f9e491133251b6f69d8 --- Utils.cpp | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) (limited to 'Utils.cpp') diff --git a/Utils.cpp b/Utils.cpp index 18842814..6fd1a785 100644 --- a/Utils.cpp +++ b/Utils.cpp @@ -31,25 +31,12 @@ namespace armnn_driver { const armnn::PermutationVector g_DontPermute{}; -namespace -{ - -void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorShape& inTensorShape, const void* input, - void* output, size_t dataTypeSize, const armnn::PermutationVector& mappings) -{ - assert(inTensorShape.GetNumDimensions() == 4U); - - armnnUtils::Permute(armnnUtils::Permuted(inTensorShape, mappings), mappings, input, output, dataTypeSize); -} - -} // anonymous namespace - -void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output, +void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensorInfo, const void* input, void* output, const armnn::PermutationVector& mappings) { - assert(tensor.GetNumDimensions() == 4U); + assert(tensorInfo.GetNumDimensions() == 4U); - armnn::DataType dataType = tensor.GetDataType(); + armnn::DataType dataType = tensorInfo.GetDataType(); switch (dataType) { case armnn::DataType::Float16: @@ -57,7 +44,10 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void case armnn::DataType::QAsymmU8: case armnn::DataType::QSymmS8: case armnn::DataType::QAsymmS8: - SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings); + // First swizzle tensor info + tensorInfo = armnnUtils::Permuted(tensorInfo, mappings); + // Then swizzle tensor data + armnnUtils::Permute(tensorInfo.GetShape(), mappings, input, output, armnn::GetDataTypeSize(dataType)); break; default: ALOGW("Unknown armnn::DataType for swizzling"); -- cgit v1.2.1