From 9104cd559222b98f2b21f14d4fd561ed4a4e9bc2 Mon Sep 17 00:00:00 2001 From: Adnan AlSinan Date: Wed, 6 Apr 2022 16:19:31 +0100 Subject: Add support for int8 CpuPool3d - Add implementation for the CPU pooling 3d layer. - NDHWC data layout support. - Support QASYMM8/QASYMM8_SIGNED. - Add Pooling helper file for Pool3d/2d common functions. Resolves COMPMID-4668 Change-Id: Iadf042036b076099c2353d6e2fe9fc623bc263d8 Signed-off-by: Adnan AlSinan Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7387 Reviewed-by: Gunes Bayir Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- tests/validation/fixtures/Pooling3dLayerFixture.h | 37 +++++++++++++++-------- 1 file changed, 25 insertions(+), 12 deletions(-) (limited to 'tests/validation/fixtures/Pooling3dLayerFixture.h') diff --git a/tests/validation/fixtures/Pooling3dLayerFixture.h b/tests/validation/fixtures/Pooling3dLayerFixture.h index c1b3519e80..563f1dcced 100644 --- a/tests/validation/fixtures/Pooling3dLayerFixture.h +++ b/tests/validation/fixtures/Pooling3dLayerFixture.h @@ -46,10 +46,10 @@ class Pooling3dLayerValidationGenericFixture : public framework::Fixture { public: template - void setup(TensorShape shape, Pooling3dLayerInfo pool_info, DataType data_type) + void setup(TensorShape shape, Pooling3dLayerInfo pool_info, DataType data_type, QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo()) { - _target = compute_target(shape, pool_info, data_type); - _reference = compute_reference(shape, pool_info, data_type); + _target = compute_target(shape, pool_info, data_type, input_qinfo, output_qinfo); + _reference = compute_reference(shape, pool_info, data_type, input_qinfo, output_qinfo); } protected: @@ -68,17 +68,17 @@ protected: } else // data type is quantized_asymmetric { - ARM_COMPUTE_ERROR("Passed Type Not Supported"); + library->fill_tensor_uniform(tensor, 0); } } TensorType compute_target(TensorShape shape, Pooling3dLayerInfo info, - DataType data_type) + DataType data_type, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo) { // Create tensors - TensorType src = create_tensor(shape, data_type, 1, QuantizationInfo(), DataLayout::NDHWC); + TensorType src = create_tensor(shape, data_type, 1, input_qinfo, DataLayout::NDHWC); const TensorShape dst_shape = misc::shape_calculator::compute_pool3d_shape((src.info()->tensor_shape()), info); - TensorType dst = create_tensor(dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NDHWC); + TensorType dst = create_tensor(dst_shape, data_type, 1, output_qinfo, DataLayout::NDHWC); // Create and configure function FunctionType pool_layer; @@ -103,17 +103,17 @@ protected: return dst; } - SimpleTensor compute_reference(TensorShape shape, Pooling3dLayerInfo info, DataType data_type) + SimpleTensor compute_reference(TensorShape shape, Pooling3dLayerInfo info, DataType data_type, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo) { // Create reference - SimpleTensor src(shape, data_type, 1, QuantizationInfo(), DataLayout::NDHWC); + SimpleTensor src(shape, data_type, 1, input_qinfo, DataLayout::NDHWC); // Fill reference fill(src); - return reference::pooling_3d_layer(src, info); + return reference::pooling_3d_layer(src, info, output_qinfo); } - TensorType _target{}; - SimpleTensor _reference{}; + TensorType _target{}; + SimpleTensor _reference{}; }; template @@ -128,6 +128,19 @@ public: } }; +template +class Pooling3dLayerValidationQuantizedFixture : public Pooling3dLayerValidationGenericFixture +{ +public: + template + void setup(TensorShape shape, PoolingType pool_type, Size3D pool_size, Size3D stride, Padding3D padding, bool exclude_padding, DataType data_type, + QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo()) + { + Pooling3dLayerValidationGenericFixture::setup(shape, Pooling3dLayerInfo(pool_type, pool_size, stride, padding, exclude_padding), + data_type, input_qinfo, output_qinfo); + } +}; + template class Pooling3dLayerGlobalValidationFixture : public Pooling3dLayerValidationGenericFixture { -- cgit v1.2.1