aboutsummaryrefslogtreecommitdiff
path: root/src/backends/aclCommon
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2023-06-19 12:06:19 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2023-07-10 11:35:02 +0000
commit2ea403d130db0d2853d5c43c29b5112893efc2bf (patch)
treeb2e64805b95825c3cd29f05c5838b9d71124bd4b /src/backends/aclCommon
parent944fb508b1c30415e423b8916849c66a13867ea4 (diff)
downloadarmnn-2ea403d130db0d2853d5c43c29b5112893efc2bf.tar.gz
IVGCVSW-7785 3D tensors in BATCH_TO_SPACE and SPACE_TO_BATCH in CpuAcc & GpuAcc
* Add Reshape layers before and after to extend support for 3D tensors, as ACL only supports 4D tensors for those layers * Add Unit Tests Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I4431185ce3a3b2f595d2a79bdda7095212d1c52d
Diffstat (limited to 'src/backends/aclCommon')
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.hpp18
1 files changed, 15 insertions, 3 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.hpp b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
index fab643ec1f..f5ae770d6b 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
@@ -123,10 +123,22 @@ arm_compute::PadStrideInfo BuildArmComputePadStrideInfo(const Descriptor& descri
/// Utility function used to setup an arm_compute::CropInfo object from an ArmNN layer descriptor.
template <typename Descriptor>
-arm_compute::CropInfo BuildArmComputeCropInfo(const Descriptor& descriptor)
+arm_compute::CropInfo BuildArmComputeCropInfo(const Descriptor& descriptor, const unsigned int rank = 4)
{
- return arm_compute::CropInfo(descriptor.m_Crops[1].first, descriptor.m_Crops[1].second,
- descriptor.m_Crops[0].first, descriptor.m_Crops[0].second);
+ if (rank == 3)
+ {
+ return arm_compute::CropInfo(0, 0,
+ descriptor.m_Crops[0].first, descriptor.m_Crops[0].second);
+ }
+ else if (rank == 4)
+ {
+ return arm_compute::CropInfo(descriptor.m_Crops[1].first, descriptor.m_Crops[1].second,
+ descriptor.m_Crops[0].first, descriptor.m_Crops[0].second);
+ }
+ else
+ {
+ throw InvalidArgumentException("Tensor rank must be either 3 or 4", CHECK_LOCATION());
+ }
}
/// Sets up the given ArmCompute tensor's dimensions based on the given ArmNN tensor.