From 7612bd6cc385dfbf54f831a6349f3a9363c6d0a2 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Tue, 6 Apr 2021 17:29:03 +0100 Subject: IVGCVSW-5842 Remove cross-wiring in depthwise * Reading tensor infos won't allow a permutation vector anymore. The permutation only changed the quantization dimension not the shape and was therefore misleading * The permutation of the full tensor info is now performed in armnnUtils::Permuted * Changed TfLite Parser depthwise parsing function * Added unit tests to TfLite Parser with more random data * Changed TfLite Delegate depthwise parsing function * Added unit test to the delegate with per channel quantization !android-nn-driver:5412 Signed-off-by: Jan Eilers Change-Id: I1f985ee69547bcaf16a72201e00a6b6fe1ef9a97 --- delegate/src/DelegateUtils.hpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'delegate/src/DelegateUtils.hpp') diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp index deed61dc5f..76d21f6332 100644 --- a/delegate/src/DelegateUtils.hpp +++ b/delegate/src/DelegateUtils.hpp @@ -398,8 +398,7 @@ armnn::DataType GetDataType(const TfLiteTensor& tfLiteTensor) } } -armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor, - const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3}) +armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor) { armnn::DataType type = GetDataType(tfLiteTensor); armnn::TensorInfo ret; @@ -453,8 +452,7 @@ armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor, quantizationScales.push_back(affineQuantization->scale->data[i]); } ret.SetQuantizationScales(quantizationScales); - ret.SetQuantizationDim(dimensionMappings[armnn::numeric_cast( - affineQuantization->quantized_dimension)]); + ret.SetQuantizationDim(armnn::numeric_cast(affineQuantization->quantized_dimension)); } else { @@ -485,13 +483,16 @@ armnn::ConstTensor CreateConstTensor(const TfLiteTensor* tfLiteTensor, if (permutationVector.has_value() && permutationVector.value().GetSize() > 0 && permutationData != nullptr) { - armnnUtils::Permute(armnnUtils::Permuted(tensorInfo.GetShape(), permutationVector.value()), + // Permute tensor info + tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value()); + // then permute data using the shape from permuted tensor info + armnnUtils::Permute(tensorInfo.GetShape(), permutationVector.value(), tfLiteTensor->data.data, permutationData, armnn::GetDataTypeSize(tensorInfo.GetDataType())); - return armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, permutationVector.value()), permutationData); + return armnn::ConstTensor(tensorInfo, permutationData); } else { -- cgit v1.2.1