diff options
Diffstat (limited to 'Utils.cpp')
-rw-r--r-- | Utils.cpp | 24 |
1 files changed, 7 insertions, 17 deletions
@@ -31,25 +31,12 @@ namespace armnn_driver { const armnn::PermutationVector g_DontPermute{}; -namespace -{ - -void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorShape& inTensorShape, const void* input, - void* output, size_t dataTypeSize, const armnn::PermutationVector& mappings) -{ - assert(inTensorShape.GetNumDimensions() == 4U); - - armnnUtils::Permute(armnnUtils::Permuted(inTensorShape, mappings), mappings, input, output, dataTypeSize); -} - -} // anonymous namespace - -void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output, +void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensorInfo, const void* input, void* output, const armnn::PermutationVector& mappings) { - assert(tensor.GetNumDimensions() == 4U); + assert(tensorInfo.GetNumDimensions() == 4U); - armnn::DataType dataType = tensor.GetDataType(); + armnn::DataType dataType = tensorInfo.GetDataType(); switch (dataType) { case armnn::DataType::Float16: @@ -57,7 +44,10 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void case armnn::DataType::QAsymmU8: case armnn::DataType::QSymmS8: case armnn::DataType::QAsymmS8: - SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings); + // First swizzle tensor info + tensorInfo = armnnUtils::Permuted(tensorInfo, mappings); + // Then swizzle tensor data + armnnUtils::Permute(tensorInfo.GetShape(), mappings, input, output, armnn::GetDataTypeSize(dataType)); break; default: ALOGW("Unknown armnn::DataType for swizzling"); |