diff options
Diffstat (limited to 'src/armnn/backends/ArmComputeTensorUtils.cpp')
-rw-r--r-- | src/armnn/backends/ArmComputeTensorUtils.cpp | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/src/armnn/backends/ArmComputeTensorUtils.cpp b/src/armnn/backends/ArmComputeTensorUtils.cpp index 9f21c41a2f..f88ed2b4c3 100644 --- a/src/armnn/backends/ArmComputeTensorUtils.cpp +++ b/src/armnn/backends/ArmComputeTensorUtils.cpp @@ -78,6 +78,7 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes using arm_compute::DimensionRoundingType; using arm_compute::PadStrideInfo; using arm_compute::PoolingLayerInfo; + using arm_compute::Size2D; // Resolve ARM Compute layer parameters const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType); @@ -94,7 +95,9 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude); - return arm_compute::PoolingLayerInfo(poolingType, descriptor.m_PoolWidth, padStrideInfo, excludePadding); + const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight); + + return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding); } arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor) @@ -114,7 +117,7 @@ arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::Per arm_compute::PermutationVector aclPerm; unsigned int start = 0; - while ((start == perm[start]) && (start < perm.GetSize())) + while ((start < perm.GetSize()) && (start == perm[start])) { ++start; } |