aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/PoolingLayerFixture.h
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2020-10-27 10:56:31 +0000
committerMichele Di Giorgio <michele.digiorgio@arm.com>2021-01-20 16:39:53 +0000
commitd556d7bafe6ad943f4aca0f5285ada7b8ce497f7 (patch)
tree11c7077daf97b46c47a4eac821830b37a7ce9e76 /tests/validation/fixtures/PoolingLayerFixture.h
parent7d61ff041826782d14e67b7f5b7a2864905ff38b (diff)
downloadComputeLibrary-d556d7bafe6ad943f4aca0f5285ada7b8ce497f7.tar.gz
Integrate improved pooling layer on NEON
Resolves COMPMID-4035 Change-Id: I559f8c4208fba9193dfe5012f03ddaf26c746215 Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4855 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/PoolingLayerFixture.h')
-rw-r--r--tests/validation/fixtures/PoolingLayerFixture.h16
1 files changed, 5 insertions, 11 deletions
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index 09054aa845..af078d4ce3 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -46,16 +46,9 @@ class PoolingLayerValidationGenericFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout, bool indices = false)
+ void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout, bool indices = false,
+ QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo())
{
- std::mt19937 gen(library->seed());
- std::uniform_int_distribution<> offset_dis(0, 20);
- const float scale = data_type == DataType::QASYMM8_SIGNED ? 1.f / 127.f : 1.f / 255.f;
- const int scale_in = data_type == DataType::QASYMM8_SIGNED ? -offset_dis(gen) : offset_dis(gen);
- const int scale_out = data_type == DataType::QASYMM8_SIGNED ? -offset_dis(gen) : offset_dis(gen);
- const QuantizationInfo input_qinfo(scale, scale_in);
- const QuantizationInfo output_qinfo(scale, scale_out);
-
_pool_info = pool_info;
_target = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices);
_reference = compute_reference(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices);
@@ -180,10 +173,11 @@ class PoolingLayerValidationQuantizedFixture : public PoolingLayerValidationGene
{
public:
template <typename...>
- void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout = DataLayout::NCHW)
+ void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout = DataLayout::NCHW,
+ QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo())
{
PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding),
- data_type, data_layout);
+ data_type, data_layout, false, input_qinfo, output_qinfo);
}
};