diff options
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r-- | delegate/src/DelegateUtils.hpp | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp index deed61dc5f..76d21f6332 100644 --- a/delegate/src/DelegateUtils.hpp +++ b/delegate/src/DelegateUtils.hpp @@ -398,8 +398,7 @@ armnn::DataType GetDataType(const TfLiteTensor& tfLiteTensor) } } -armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor, - const armnn::PermutationVector& dimensionMappings = {0, 1, 2, 3}) +armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor) { armnn::DataType type = GetDataType(tfLiteTensor); armnn::TensorInfo ret; @@ -453,8 +452,7 @@ armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor, quantizationScales.push_back(affineQuantization->scale->data[i]); } ret.SetQuantizationScales(quantizationScales); - ret.SetQuantizationDim(dimensionMappings[armnn::numeric_cast<unsigned int>( - affineQuantization->quantized_dimension)]); + ret.SetQuantizationDim(armnn::numeric_cast<unsigned int>(affineQuantization->quantized_dimension)); } else { @@ -485,13 +483,16 @@ armnn::ConstTensor CreateConstTensor(const TfLiteTensor* tfLiteTensor, if (permutationVector.has_value() && permutationVector.value().GetSize() > 0 && permutationData != nullptr) { - armnnUtils::Permute(armnnUtils::Permuted(tensorInfo.GetShape(), permutationVector.value()), + // Permute tensor info + tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value()); + // then permute data using the shape from permuted tensor info + armnnUtils::Permute(tensorInfo.GetShape(), permutationVector.value(), tfLiteTensor->data.data, permutationData, armnn::GetDataTypeSize(tensorInfo.GetDataType())); - return armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, permutationVector.value()), permutationData); + return armnn::ConstTensor(tensorInfo, permutationData); } else { |