From c577f2c6a3b4ddb6ba87a882723c53a248afbeba Mon Sep 17 00:00:00 2001 From: telsoa01 Date: Fri, 31 Aug 2018 09:22:23 +0100 Subject: Release 18.08 --- src/armnn/backends/ArmComputeTensorUtils.cpp | 29 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) (limited to 'src/armnn/backends/ArmComputeTensorUtils.cpp') diff --git a/src/armnn/backends/ArmComputeTensorUtils.cpp b/src/armnn/backends/ArmComputeTensorUtils.cpp index f88ed2b4c3..8e4abaf67a 100644 --- a/src/armnn/backends/ArmComputeTensorUtils.cpp +++ b/src/armnn/backends/ArmComputeTensorUtils.cpp @@ -16,23 +16,17 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType) { switch(dataType) { + case armnn::DataType::Float16: + return arm_compute::DataType::F16; case armnn::DataType::Float32: - { return arm_compute::DataType::F32; - } case armnn::DataType::QuantisedAsymm8: - { return arm_compute::DataType::QASYMM8; - } case armnn::DataType::Signed32: - { return arm_compute::DataType::S32; - } default: - { BOOST_ASSERT_MSG(false, "Unknown data type"); return arm_compute::DataType::UNKNOWN; - } } } @@ -40,15 +34,15 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te { arm_compute::TensorShape shape; - // armnn tensors are (batch, channels, height, width) - // arm_compute tensors are (width, height, channels, batch) + // armnn tensors are (batch, channels, height, width). + // arm_compute tensors are (width, height, channels, batch). for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++) { - // note that our dimensions are stored in the opposite order to ACL's + // Note that our dimensions are stored in the opposite order to ACL's. shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]); // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen. - // arm_compute tensors expect this + // arm_compute tensors expect this. } // prevent arm_compute issue where tensor is flattened to nothing @@ -80,11 +74,18 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes using arm_compute::PoolingLayerInfo; using arm_compute::Size2D; - // Resolve ARM Compute layer parameters + // Resolve ARM Compute layer parameters. const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType); + + bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0); + //use specific constructor if global pooling + if(isGlobalPooling) + { + return arm_compute::PoolingLayerInfo(poolingType); + } + const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType( descriptor.m_OutputShapeRounding); - const PadStrideInfo padStrideInfo(descriptor.m_StrideX, descriptor.m_StrideY, descriptor.m_PadLeft, -- cgit v1.2.1