diff options
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.cpp | 56 |
1 files changed, 44 insertions, 12 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp index 2dc6d2a2b2..e476eb38a1 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp @@ -187,17 +187,10 @@ arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout) arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor, bool fpMixedPrecision) { - using arm_compute::PoolingType; - using arm_compute::DimensionRoundingType; - using arm_compute::PadStrideInfo; - using arm_compute::PoolingLayerInfo; - using arm_compute::Size2D; - using arm_compute::DataLayout; - // Resolve ARM Compute layer parameters. - const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType); + const arm_compute::PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType); - const DataLayout dataLayout = ConvertDataLayout(descriptor.m_DataLayout); + const arm_compute::DataLayout dataLayout = ConvertDataLayout(descriptor.m_DataLayout); bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0); //use specific constructor if global pooling @@ -206,9 +199,9 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes return arm_compute::PoolingLayerInfo(poolingType, dataLayout); } - const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType( + const arm_compute::DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType( descriptor.m_OutputShapeRounding); - const PadStrideInfo padStrideInfo(descriptor.m_StrideX, + const arm_compute::PadStrideInfo padStrideInfo(descriptor.m_StrideX, descriptor.m_StrideY, descriptor.m_PadLeft, descriptor.m_PadRight, @@ -218,12 +211,51 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude); - const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight); + const arm_compute::Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight); return arm_compute::PoolingLayerInfo(poolingType, poolSize, dataLayout, padStrideInfo, excludePadding, fpMixedPrecision); } +arm_compute::Pooling3dLayerInfo BuildArmComputePooling3dLayerInfo(const Pooling3dDescriptor& descriptor, + bool fpMixedPrecision) +{ + const arm_compute::PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType); + + bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0 && descriptor.m_StrideZ==0); + //use specific constructor if global pooling + if(isGlobalPooling) + { + return arm_compute::Pooling3dLayerInfo(poolingType); + } + + const arm_compute::Size3D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight, descriptor.m_PoolDepth); + + const arm_compute::Size3D stride(descriptor.m_StrideX, + descriptor.m_StrideY, + descriptor.m_StrideZ); + + const arm_compute::Padding3D padding(descriptor.m_PadLeft, + descriptor.m_PadRight, + descriptor.m_PadTop, + descriptor.m_PadBottom, + descriptor.m_PadFront, + descriptor.m_PadBack); + + const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude); + + const arm_compute::DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType( + descriptor.m_OutputShapeRounding); + + return arm_compute::Pooling3dLayerInfo(poolingType, + poolSize, + stride, + padding, + excludePadding, + fpMixedPrecision, + rounding); +} + arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor) { const arm_compute::NormType normType = |