aboutsummaryrefslogtreecommitdiff
path: root/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Utils.cpp')
-rw-r--r--Utils.cpp24
1 files changed, 7 insertions, 17 deletions
diff --git a/Utils.cpp b/Utils.cpp
index 18842814..6fd1a785 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -31,25 +31,12 @@ namespace armnn_driver
{
const armnn::PermutationVector g_DontPermute{};
-namespace
-{
-
-void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorShape& inTensorShape, const void* input,
- void* output, size_t dataTypeSize, const armnn::PermutationVector& mappings)
-{
- assert(inTensorShape.GetNumDimensions() == 4U);
-
- armnnUtils::Permute(armnnUtils::Permuted(inTensorShape, mappings), mappings, input, output, dataTypeSize);
-}
-
-} // anonymous namespace
-
-void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output,
+void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensorInfo, const void* input, void* output,
const armnn::PermutationVector& mappings)
{
- assert(tensor.GetNumDimensions() == 4U);
+ assert(tensorInfo.GetNumDimensions() == 4U);
- armnn::DataType dataType = tensor.GetDataType();
+ armnn::DataType dataType = tensorInfo.GetDataType();
switch (dataType)
{
case armnn::DataType::Float16:
@@ -57,7 +44,10 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void
case armnn::DataType::QAsymmU8:
case armnn::DataType::QSymmS8:
case armnn::DataType::QAsymmS8:
- SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings);
+ // First swizzle tensor info
+ tensorInfo = armnnUtils::Permuted(tensorInfo, mappings);
+ // Then swizzle tensor data
+ armnnUtils::Permute(tensorInfo.GetShape(), mappings, input, output, armnn::GetDataTypeSize(dataType));
break;
default:
ALOGW("Unknown armnn::DataType for swizzling");