diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
commit | c577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch) | |
tree | bd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/backends/ArmComputeTensorUtils.cpp | |
parent | 4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff) | |
download | armnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz |
Release 18.08
Diffstat (limited to 'src/armnn/backends/ArmComputeTensorUtils.cpp')
-rw-r--r-- | src/armnn/backends/ArmComputeTensorUtils.cpp | 29 |
1 files changed, 15 insertions, 14 deletions
diff --git a/src/armnn/backends/ArmComputeTensorUtils.cpp b/src/armnn/backends/ArmComputeTensorUtils.cpp index f88ed2b4c3..8e4abaf67a 100644 --- a/src/armnn/backends/ArmComputeTensorUtils.cpp +++ b/src/armnn/backends/ArmComputeTensorUtils.cpp @@ -16,23 +16,17 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType) { switch(dataType) { + case armnn::DataType::Float16: + return arm_compute::DataType::F16; case armnn::DataType::Float32: - { return arm_compute::DataType::F32; - } case armnn::DataType::QuantisedAsymm8: - { return arm_compute::DataType::QASYMM8; - } case armnn::DataType::Signed32: - { return arm_compute::DataType::S32; - } default: - { BOOST_ASSERT_MSG(false, "Unknown data type"); return arm_compute::DataType::UNKNOWN; - } } } @@ -40,15 +34,15 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te { arm_compute::TensorShape shape; - // armnn tensors are (batch, channels, height, width) - // arm_compute tensors are (width, height, channels, batch) + // armnn tensors are (batch, channels, height, width). + // arm_compute tensors are (width, height, channels, batch). for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++) { - // note that our dimensions are stored in the opposite order to ACL's + // Note that our dimensions are stored in the opposite order to ACL's. shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]); // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen. - // arm_compute tensors expect this + // arm_compute tensors expect this. } // prevent arm_compute issue where tensor is flattened to nothing @@ -80,11 +74,18 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes using arm_compute::PoolingLayerInfo; using arm_compute::Size2D; - // Resolve ARM Compute layer parameters + // Resolve ARM Compute layer parameters. const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType); + + bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0); + //use specific constructor if global pooling + if(isGlobalPooling) + { + return arm_compute::PoolingLayerInfo(poolingType); + } + const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType( descriptor.m_OutputShapeRounding); - const PadStrideInfo padStrideInfo(descriptor.m_StrideX, descriptor.m_StrideY, descriptor.m_PadLeft, |