diff options
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.cpp | 69 |
1 files changed, 68 insertions, 1 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp index 38c7f70da5..e6c5a9b41c 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include <aclCommon/ArmComputeTensorUtils.hpp> @@ -146,6 +146,51 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te return shape; } +std::vector<unsigned int> ReduceDimsForACL(const armnn::TensorShape tensorShape, unsigned int dimensions) +{ + std::vector<unsigned int> newShape; + + unsigned int dimsToSkip = 0; + + if (tensorShape.GetNumDimensions() > dimensions) + { + dimsToSkip = tensorShape.GetNumDimensions() - dimensions; + } + unsigned int dimsSkipped = 0; + bool insertRemainder = false; + + for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i) + { + if (tensorShape[i] == 1 && dimsSkipped < dimsToSkip && !insertRemainder) + { + ++dimsSkipped; + continue; + } + newShape.insert(newShape.begin(), tensorShape[i]); + // Once we insert the first dimension we can't skip any more + insertRemainder = true; + } + return newShape; +} + +arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape, unsigned int dimensions) +{ + arm_compute::TensorShape shape; + std::vector<unsigned int> strippedShape = ReduceDimsForACL(tensorShape, dimensions); + + for (unsigned int i = 0; i < strippedShape.size(); i++) + { + shape.set(i, strippedShape[i], false); + } + + // prevent arm_compute issue where tensor is flattened to nothing + if (shape.num_dimensions() == 0) + { + shape.set_num_dimensions(1); + } + return shape; +} + // Utility function used to build a TensorInfo object, that can be used to initialise // ARM Compute Tensor and CLTensor allocators. // Note: this utility ignores the value of armnn::TensorInfo.IsConstant(). ACL tensors @@ -174,6 +219,28 @@ arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tenso return aclTensorInfo; } +arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo, unsigned int dimensions) +{ + bool multiScales = tensorInfo.HasMultipleQuantizationScales(); + const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape(), dimensions); + const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType(), multiScales); + + const arm_compute::QuantizationInfo aclQuantizationInfo = multiScales ? + arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) : + arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset()); + + return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo); +} +arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo, + armnn::DataLayout dataLayout, unsigned int dimensions) +{ + arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo, dimensions); + aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout)); + + return aclTensorInfo; +} + + arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout) { switch(dataLayout) |