aboutsummaryrefslogtreecommitdiff
path: root/src/backends/aclCommon/ArmComputeTensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.cpp69
1 file changed, 68 insertions(+), 1 deletion(-)
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index 38c7f70da5..e6c5a9b41c 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <aclCommon/ArmComputeTensorUtils.hpp>
@@ -146,6 +146,51 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te
return shape;
}
+// Strips leading size-1 dimensions from a shape so it fits into at most
+// 'dimensions' dims, then returns the surviving dims in reversed order
+// (the ordering arm_compute expects). If the leading run of 1s is shorter
+// than the excess, the result may still exceed 'dimensions'.
+std::vector<unsigned int> ReduceDimsForACL(const armnn::TensorShape tensorShape, unsigned int dimensions)
+{
+    const unsigned int numDims = tensorShape.GetNumDimensions();
+
+    // How many dimensions we are allowed to drop to reach the target rank.
+    unsigned int removable = (numDims > dimensions) ? numDims - dimensions : 0;
+
+    // Advance past leading 1s, but only while we still have dims to drop.
+    // Stopping at the first non-1 dim matches the original behaviour of
+    // never skipping once a dimension has been kept.
+    unsigned int first = 0;
+    while (first < numDims && removable > 0 && tensorShape[first] == 1)
+    {
+        ++first;
+        --removable;
+    }
+
+    // Emit the remaining dims back-to-front to produce the reversed shape.
+    std::vector<unsigned int> reducedShape;
+    reducedShape.reserve(numDims - first);
+    for (unsigned int i = numDims; i > first; --i)
+    {
+        reducedShape.push_back(tensorShape[i - 1]);
+    }
+    return reducedShape;
+}
+
+// Builds an arm_compute::TensorShape from an armnn::TensorShape, reducing
+// leading size-1 dimensions where possible so the result fits within
+// 'dimensions' dims.
+arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape, unsigned int dimensions)
+{
+    const std::vector<unsigned int> reducedDims = ReduceDimsForACL(tensorShape, dimensions);
+
+    arm_compute::TensorShape aclShape;
+    unsigned int idx = 0;
+    for (unsigned int dim : reducedDims)
+    {
+        aclShape.set(idx++, dim, false);
+    }
+
+    // arm_compute can collapse a fully-reduced shape to zero dimensions;
+    // keep at least one so the tensor is not flattened to nothing.
+    if (aclShape.num_dimensions() == 0)
+    {
+        aclShape.set_num_dimensions(1);
+    }
+    return aclShape;
+}
+
// Utility function used to build a TensorInfo object, that can be used to initialise
// ARM Compute Tensor and CLTensor allocators.
// Note: this utility ignores the value of armnn::TensorInfo.IsConstant(). ACL tensors
@@ -174,6 +219,28 @@ arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tenso
return aclTensorInfo;
}
+// Builds an arm_compute::TensorInfo whose shape has been reduced to fit
+// within 'dimensions' dims, carrying over data type and quantization
+// (per-axis scales when present, otherwise scalar scale/offset).
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo, unsigned int dimensions)
+{
+    const bool hasPerAxisQuantization = tensorInfo.HasMultipleQuantizationScales();
+
+    arm_compute::QuantizationInfo aclQuantizationInfo;
+    if (hasPerAxisQuantization)
+    {
+        aclQuantizationInfo = arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales());
+    }
+    else
+    {
+        aclQuantizationInfo = arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(),
+                                                            tensorInfo.GetQuantizationOffset());
+    }
+
+    return arm_compute::TensorInfo(BuildArmComputeTensorShape(tensorInfo.GetShape(), dimensions),
+                                   1,
+                                   GetArmComputeDataType(tensorInfo.GetDataType(), hasPerAxisQuantization),
+                                   aclQuantizationInfo);
+}
+// As above, but additionally stamps the requested data layout onto the
+// resulting TensorInfo.
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
+                                                  armnn::DataLayout dataLayout, unsigned int dimensions)
+{
+    // Build the layout-agnostic info first, then set the layout on it.
+    arm_compute::TensorInfo info = BuildArmComputeTensorInfo(tensorInfo, dimensions);
+    info.set_data_layout(ConvertDataLayout(dataLayout));
+    return info;
+}
+
+
arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
{
switch(dataLayout)