path: root/src/armnn/backends/ArmComputeTensorUtils.cpp
author    telsoa01 <telmo.soares@arm.com>  2018-08-31 09:22:23 +0100
committer telsoa01 <telmo.soares@arm.com>  2018-08-31 09:22:23 +0100
commit    c577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
tree      bd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/backends/ArmComputeTensorUtils.cpp
parent    4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
download  armnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'src/armnn/backends/ArmComputeTensorUtils.cpp')
-rw-r--r--  src/armnn/backends/ArmComputeTensorUtils.cpp | 29
1 file changed, 15 insertions(+), 14 deletions(-)
diff --git a/src/armnn/backends/ArmComputeTensorUtils.cpp b/src/armnn/backends/ArmComputeTensorUtils.cpp
index f88ed2b4c3..8e4abaf67a 100644
--- a/src/armnn/backends/ArmComputeTensorUtils.cpp
+++ b/src/armnn/backends/ArmComputeTensorUtils.cpp
@@ -16,23 +16,17 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
{
switch(dataType)
{
+ case armnn::DataType::Float16:
+ return arm_compute::DataType::F16;
case armnn::DataType::Float32:
- {
return arm_compute::DataType::F32;
- }
case armnn::DataType::QuantisedAsymm8:
- {
return arm_compute::DataType::QASYMM8;
- }
case armnn::DataType::Signed32:
- {
return arm_compute::DataType::S32;
- }
default:
- {
BOOST_ASSERT_MSG(false, "Unknown data type");
return arm_compute::DataType::UNKNOWN;
- }
}
}
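
For reference, a minimal usage sketch of the extended mapping in this hunk. The armcomputetensorutils namespace and include path are assumptions based on this file's location, not something the diff confirms.

// Hypothetical caller exercising the new Float16 -> F16 mapping (18.08).
// Namespace and include path are assumptions; adjust to the actual tree layout.
#include "backends/ArmComputeTensorUtils.hpp"
#include <armnn/Types.hpp>
#include <arm_compute/core/Types.h>
#include <cassert>

int main()
{
    const arm_compute::DataType acl =
        armnn::armcomputetensorutils::GetArmComputeDataType(armnn::DataType::Float16);
    assert(acl == arm_compute::DataType::F16); // case added in this release
    return 0;
}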
@@ -40,15 +34,15 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te
{
arm_compute::TensorShape shape;
- // armnn tensors are (batch, channels, height, width)
- // arm_compute tensors are (width, height, channels, batch)
+ // armnn tensors are (batch, channels, height, width).
+ // arm_compute tensors are (width, height, channels, batch).
for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
{
- // note that our dimensions are stored in the opposite order to ACL's
+ // Note that our dimensions are stored in the opposite order to ACL's.
shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
// TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
- // arm_compute tensors expect this
+ // arm_compute tensors expect this.
}
// prevent arm_compute issue where tensor is flattened to nothing
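
To illustrate the dimension reversal performed by the loop above, a small sketch follows. The initializer-list TensorShape constructor and the armcomputetensorutils namespace are assumptions about the surrounding API.

// Hypothetical sketch: an armnn NCHW shape becomes an ACL WHCN shape.
#include "backends/ArmComputeTensorUtils.hpp"
#include <armnn/Tensor.hpp>
#include <cassert>

int main()
{
    const armnn::TensorShape nchw({2, 3, 4, 5}); // batch, channels, height, width
    const arm_compute::TensorShape acl =
        armnn::armcomputetensorutils::BuildArmComputeTensorShape(nchw);

    // arm_compute stores dimensions as (width, height, channels, batch).
    assert(acl[0] == 5); // width
    assert(acl[1] == 4); // height
    assert(acl[2] == 3); // channels
    assert(acl[3] == 2); // batch
    return 0;
}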
@@ -80,11 +74,18 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes
using arm_compute::PoolingLayerInfo;
using arm_compute::Size2D;
- // Resolve ARM Compute layer parameters
+ // Resolve ARM Compute layer parameters.
const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
+
+ bool isGlobalPooling = (descriptor.m_StrideX == 0 && descriptor.m_StrideY == 0);
+ // Use the global-pooling constructor when both strides are zero.
+ if (isGlobalPooling)
+ {
+ return arm_compute::PoolingLayerInfo(poolingType);
+ }
+
const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
descriptor.m_OutputShapeRounding);
-
const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
descriptor.m_StrideY,
descriptor.m_PadLeft,
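
The early return added in this hunk means a Pooling2dDescriptor with both strides set to zero maps to ACL's global-pooling PoolingLayerInfo constructor. A hedged sketch of that behaviour; the namespace, include paths, and the is_global_pooling() accessor are assumed from the ACL and ArmNN versions in use.

// Hypothetical sketch of the global-pooling shortcut.
#include "backends/ArmComputeTensorUtils.hpp"
#include <armnn/Descriptors.hpp>
#include <cassert>

int main()
{
    armnn::Pooling2dDescriptor desc;
    desc.m_PoolType = armnn::PoolingAlgorithm::Average;
    desc.m_StrideX = 0; // stride 0 in both dimensions selects the
    desc.m_StrideY = 0; // PoolingLayerInfo(PoolingType) constructor

    const arm_compute::PoolingLayerInfo info =
        armnn::armcomputetensorutils::BuildArmComputePoolingLayerInfo(desc);
    assert(info.is_global_pooling());
    return 0;
}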