aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/PoolingLayerFixture.h
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2019-04-01 14:55:18 +0100
committerPablo Marquez <pablo.tello@arm.com>2019-04-04 14:15:25 +0000
commita52e4cf36ec86b63660f5a687073fa0985384dc1 (patch)
treea02b8636c42ce477b3d541f965b7909130a702d3 /tests/validation/fixtures/PoolingLayerFixture.h
parent4fbcac606efe60d0f65b7b2d853435c5a706a8a7 (diff)
downloadComputeLibrary-a52e4cf36ec86b63660f5a687073fa0985384dc1.tar.gz
COMPMID-2060: Support different qinfo in PoolingLayer
CL and Neon back ends now support different qinfos Change-Id: I638d5f258ab2f99b40659601b4c5398d2c34c43b Signed-off-by: Pablo Tello <pablo.tello@arm.com> Reviewed-on: https://review.mlplatform.org/c/927 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests/validation/fixtures/PoolingLayerFixture.h')
-rw-r--r--tests/validation/fixtures/PoolingLayerFixture.h44
1 files changed, 23 insertions, 21 deletions
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index 3e34f98271..1813ef4c84 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,6 +26,7 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/Tensor.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
@@ -47,13 +48,16 @@ class PoolingLayerValidationGenericFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
+ void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout)
{
- _quantization_info = quantization_info;
- _pool_info = pool_info;
-
- _target = compute_target(shape, pool_info, data_type, data_layout, quantization_info);
- _reference = compute_reference(shape, pool_info, data_type, quantization_info);
+ std::mt19937 gen(library->seed());
+ std::uniform_int_distribution<> offset_dis(0, 20);
+ const QuantizationInfo input_qinfo(1.f / 255.f, offset_dis(gen));
+ const QuantizationInfo output_qinfo(1.f / 255.f, offset_dis(gen));
+
+ _pool_info = pool_info;
+ _target = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo);
+ _reference = compute_reference(shape, pool_info, data_type, input_qinfo, output_qinfo);
}
protected:
@@ -72,7 +76,7 @@ protected:
}
TensorType compute_target(TensorShape shape, PoolingLayerInfo info,
- DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
+ DataType data_type, DataLayout data_layout, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo)
{
// Change shape in case of NHWC.
if(data_layout == DataLayout::NHWC)
@@ -81,8 +85,9 @@ protected:
}
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, data_type, 1, quantization_info, data_layout);
- TensorType dst;
+ TensorType src = create_tensor<TensorType>(shape, data_type, 1, input_qinfo, data_layout);
+ const TensorShape dst_shape = misc::shape_calculator::compute_pool_shape(*(src.info()), info);
+ TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, output_qinfo, data_layout);
// Create and configure function
FunctionType pool_layer;
@@ -107,21 +112,19 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, PoolingLayerInfo info,
- DataType data_type, QuantizationInfo quantization_info)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, PoolingLayerInfo info, DataType data_type, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo)
{
// Create reference
- SimpleTensor<T> src{ shape, data_type, 1, quantization_info };
+ SimpleTensor<T> src{ shape, data_type, 1, input_qinfo };
// Fill reference
fill(src);
- return reference::pooling_layer<T>(src, info);
+ return reference::pooling_layer<T>(src, info, output_qinfo);
}
TensorType _target{};
SimpleTensor<T> _reference{};
- QuantizationInfo _quantization_info{};
PoolingLayerInfo _pool_info{};
};
@@ -133,7 +136,7 @@ public:
void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout)
{
PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, pad_stride_info, exclude_padding),
- data_type, data_layout, QuantizationInfo());
+ data_type, data_layout);
}
};
@@ -142,11 +145,10 @@ class PoolingLayerValidationQuantizedFixture : public PoolingLayerValidationGene
{
public:
template <typename...>
- void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type,
- QuantizationInfo quantization_info, DataLayout data_layout = DataLayout::NCHW)
+ void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout = DataLayout::NCHW)
{
PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, pad_stride_info, exclude_padding),
- data_type, data_layout, quantization_info);
+ data_type, data_layout);
}
};
@@ -157,7 +159,7 @@ public:
template <typename...>
void setup(TensorShape src_shape, PoolingLayerInfo pool_info, DataType data_type)
{
- PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW, QuantizationInfo());
+ PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW);
}
};
@@ -168,7 +170,7 @@ public:
template <typename...>
void setup(TensorShape shape, PoolingType pool_type, DataType data_type, DataLayout data_layout = DataLayout::NCHW)
{
- PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type), data_type, DataLayout::NCHW, QuantizationInfo());
+ PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type), data_type, DataLayout::NCHW);
}
};