aboutsummaryrefslogtreecommitdiff
path: root/Utils.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2021-04-12 13:12:19 +0100
committerJan Eilers <jan.eilers@arm.com>2021-04-19 09:31:02 +0100
commita71c0633eff4791eb98362f72f198a2d1ec3d8f9 (patch)
tree9976eb70e4c5f3a93b4db0b06e02f8d03057a235 /Utils.cpp
parent32fe97ec627a70b6453375fcfc6665c0e1ad2024 (diff)
downloadandroid-nn-driver-a71c0633eff4791eb98362f72f198a2d1ec3d8f9.tar.gz
IVGCVSW-5826 Remove cross-wireing in depthwise
* The permutation of the tensor info is now completely handled in the armnnUtils::Permuted function. That includes quantization informations too !armnn:5411 Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I40410141303d950be7888f9e491133251b6f69d8
Diffstat (limited to 'Utils.cpp')
-rw-r--r--Utils.cpp24
1 files changed, 7 insertions, 17 deletions
diff --git a/Utils.cpp b/Utils.cpp
index 18842814..6fd1a785 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -31,25 +31,12 @@ namespace armnn_driver
{
const armnn::PermutationVector g_DontPermute{};
-namespace
-{
-
-void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorShape& inTensorShape, const void* input,
- void* output, size_t dataTypeSize, const armnn::PermutationVector& mappings)
-{
- assert(inTensorShape.GetNumDimensions() == 4U);
-
- armnnUtils::Permute(armnnUtils::Permuted(inTensorShape, mappings), mappings, input, output, dataTypeSize);
-}
-
-} // anonymous namespace
-
-void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output,
+void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensorInfo, const void* input, void* output,
const armnn::PermutationVector& mappings)
{
- assert(tensor.GetNumDimensions() == 4U);
+ assert(tensorInfo.GetNumDimensions() == 4U);
- armnn::DataType dataType = tensor.GetDataType();
+ armnn::DataType dataType = tensorInfo.GetDataType();
switch (dataType)
{
case armnn::DataType::Float16:
@@ -57,7 +44,10 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void
case armnn::DataType::QAsymmU8:
case armnn::DataType::QSymmS8:
case armnn::DataType::QAsymmS8:
- SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings);
+ // First swizzle tensor info
+ tensorInfo = armnnUtils::Permuted(tensorInfo, mappings);
+ // Then swizzle tensor data
+ armnnUtils::Permute(tensorInfo.GetShape(), mappings, input, output, armnn::GetDataTypeSize(dataType));
break;
default:
ALOGW("Unknown armnn::DataType for swizzling");