diff options
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.hpp')
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.hpp | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.hpp b/src/backends/aclCommon/ArmComputeTensorUtils.hpp index fab643ec1f..f5ae770d6b 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.hpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.hpp @@ -123,10 +123,22 @@ arm_compute::PadStrideInfo BuildArmComputePadStrideInfo(const Descriptor& descri /// Utility function used to setup an arm_compute::CropInfo object from an ArmNN layer descriptor. template <typename Descriptor> -arm_compute::CropInfo BuildArmComputeCropInfo(const Descriptor& descriptor) +arm_compute::CropInfo BuildArmComputeCropInfo(const Descriptor& descriptor, const unsigned int rank = 4) { - return arm_compute::CropInfo(descriptor.m_Crops[1].first, descriptor.m_Crops[1].second, - descriptor.m_Crops[0].first, descriptor.m_Crops[0].second); + if (rank == 3) + { + return arm_compute::CropInfo(0, 0, + descriptor.m_Crops[0].first, descriptor.m_Crops[0].second); + } + else if (rank == 4) + { + return arm_compute::CropInfo(descriptor.m_Crops[1].first, descriptor.m_Crops[1].second, + descriptor.m_Crops[0].first, descriptor.m_Crops[0].second); + } + else + { + throw InvalidArgumentException("Tensor rank must be either 3 or 4", CHECK_LOCATION()); + } } /// Sets up the given ArmCompute tensor's dimensions based on the given ArmNN tensor. |