diff options
Diffstat (limited to 'src/backends/aclCommon')
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.cpp | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp index a21de809f7..d2bb6df625 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp @@ -154,15 +154,18 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes using arm_compute::PadStrideInfo; using arm_compute::PoolingLayerInfo; using arm_compute::Size2D; + using arm_compute::DataLayout; // Resolve ARM Compute layer parameters. const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType); + const DataLayout dataLayout = ConvertDataLayout(descriptor.m_DataLayout); + bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0); //use specific constructor if global pooling if(isGlobalPooling) { - return arm_compute::PoolingLayerInfo(poolingType); + return arm_compute::PoolingLayerInfo(poolingType, dataLayout); } const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType( @@ -179,7 +182,8 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight); - return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding, fpMixedPrecision); + return arm_compute::PoolingLayerInfo(poolingType, poolSize, dataLayout, padStrideInfo, excludePadding, + fpMixedPrecision); } arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor) |