aboutsummaryrefslogtreecommitdiff
path: root/src/backends/aclCommon/ArmComputeTensorUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.cpp8
1 files changed, 6 insertions, 2 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index a21de809f7..d2bb6df625 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -154,15 +154,18 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes
using arm_compute::PadStrideInfo;
using arm_compute::PoolingLayerInfo;
using arm_compute::Size2D;
+ using arm_compute::DataLayout;
// Resolve ARM Compute layer parameters.
const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
+ const DataLayout dataLayout = ConvertDataLayout(descriptor.m_DataLayout);
+
bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
//use specific constructor if global pooling
if(isGlobalPooling)
{
- return arm_compute::PoolingLayerInfo(poolingType);
+ return arm_compute::PoolingLayerInfo(poolingType, dataLayout);
}
const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
@@ -179,7 +182,8 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes
const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
- return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding, fpMixedPrecision);
+ return arm_compute::PoolingLayerInfo(poolingType, poolSize, dataLayout, padStrideInfo, excludePadding,
+ fpMixedPrecision);
}
arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)