aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2019-04-01 14:55:18 +0100
committerPablo Marquez <pablo.tello@arm.com>2019-04-04 14:15:25 +0000
commita52e4cf36ec86b63660f5a687073fa0985384dc1 (patch)
treea02b8636c42ce477b3d541f965b7909130a702d3
parent4fbcac606efe60d0f65b7b2d853435c5a706a8a7 (diff)
downloadComputeLibrary-a52e4cf36ec86b63660f5a687073fa0985384dc1.tar.gz
COMPMID-2060: Support different qinfo in PoolingLayer
CL and Neon back ends now support different qinfos Change-Id: I638d5f258ab2f99b40659601b4c5398d2c34c43b Signed-off-by: Pablo Tello <pablo.tello@arm.com> Reviewed-on: https://review.mlplatform.org/c/927 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
-rw-r--r--src/core/CL/cl_kernels/pooling_layer_quantized.cl42
-rw-r--r--src/core/CL/kernels/CLPoolingLayerKernel.cpp15
-rw-r--r--src/core/NEON/kernels/NEPoolingLayerKernel.cpp27
-rw-r--r--tests/validation/CL/PoolingLayer.cpp22
-rw-r--r--tests/validation/NEON/PoolingLayer.cpp22
-rw-r--r--tests/validation/fixtures/PoolingLayerFixture.h44
-rw-r--r--tests/validation/reference/PoolingLayer.cpp15
-rw-r--r--tests/validation/reference/PoolingLayer.h4
8 files changed, 126 insertions, 65 deletions
diff --git a/src/core/CL/cl_kernels/pooling_layer_quantized.cl b/src/core/CL/cl_kernels/pooling_layer_quantized.cl
index 198250bfb3..919b76ed8d 100644
--- a/src/core/CL/cl_kernels/pooling_layer_quantized.cl
+++ b/src/core/CL/cl_kernels/pooling_layer_quantized.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,19 @@
*/
#include "helpers.h"
+#if defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT)
+#define VEC_FLOAT(VEC_SIZE) \
+ VEC_DATA_TYPE(float, VEC_SIZE) \
+#define VEC_INT(VEC_SIZE) VEC_DATA_TYPE(int, VEC_SIZE) #define VEC_UCHAR(VEC_SIZE) VEC_DATA_TYPE(uchar, VEC_SIZE) #define CONVERT_RTE(x, type)(convert_##type##_rte((x)))
+#define CONVERT_DOWN(x, type) CONVERT_RTE(x, type)
+#define REQUANTIZE(VEC_SIZE, input, in_offset, out_offset, in_scale, out_scale, res) \
+ { \
+ const VEC_FLOAT(VEC_SIZE) in_f32 = (CONVERT(input, VEC_FLOAT(VEC_SIZE)) - (VEC_FLOAT(VEC_SIZE))((float)in_offset)) * (VEC_FLOAT(VEC_SIZE))((float)in_scale); \
+ const VEC_FLOAT(VEC_SIZE) out_f32 = in_f32 / ((VEC_FLOAT(VEC_SIZE))(float)out_scale) + ((VEC_FLOAT(VEC_SIZE))((float)out_offset)); \
+ res = CONVERT_SAT(CONVERT_DOWN(out_f32, VEC_INT(VEC_SIZE)), VEC_UCHAR(VEC_SIZE)); \
+ }
+#endif /* defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT) */
+
#if defined(POOL_AVG)
#define POOL_OP(x, y) ((x) + (y))
#else /* defined(POOL_AVG) */
@@ -118,8 +131,22 @@ __kernel void pooling_layer_MxN_quantized_nchw(
res = round(DIV_OP(res, calculate_avg_scale(POOL_SIZE_X, POOL_SIZE_Y, MAX_WIDTH, MAX_HEIGHT, PAD_X, PAD_Y, STRIDE_X, STRIDE_Y)));
#endif /* defined(POOL_AVG) */
- // Store result
- *(__global uchar *)output.ptr = convert_uchar(res);
+ uchar result_u8 = convert_uchar(res);
+
+#if defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT)
+
+ const float result_f32 = convert_float(result_u8);
+ const float input_offset = (float)OFFSET_IN1;
+ const float input_scale = (float)SCALE_IN1;
+ const float scale_out = (float)SCALE_OUT;
+ const float offset_out = (float)OFFSET_OUT;
+ const float in_f32 = (result_f32 - input_offset) * input_scale;
+ const float out_f32 = in_f32 / scale_out + offset_out;
+ result_u8 = convert_uchar_sat(convert_int_rte(out_f32));
+
+#endif /* defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT) */
+
+ *(__global uchar *)output.ptr = result_u8;
}
int calculate_avg_scale_nhwc(const int pool_size_x, const int pool_size_y, int upper_bound_w, int upper_bound_h,
@@ -217,6 +244,11 @@ __kernel void pooling_layer_MxN_quantized_nhwc(
vdata = convert_int8(round(DIV_OP_NHWC(vdata, calculate_avg_scale_nhwc(POOL_SIZE_X, POOL_SIZE_Y, MAX_WIDTH, MAX_HEIGHT, PAD_X, PAD_Y, STRIDE_X, STRIDE_Y))));
#endif /* defined(POOL_AVG) */
+ uchar8 out_u8 = convert_uchar8(vdata);
+#if defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT)
+ REQUANTIZE(8, out_u8, OFFSET_IN1, OFFSET_OUT, SCALE_IN1, SCALE_OUT, out_u8);
+#endif /* defined(OFFSET_IN1) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_OUT) */
+
// Store result
- vstore8(convert_uchar8(vdata), 0, (__global uchar *)output.ptr);
-} \ No newline at end of file
+ vstore8(out_u8, 0, (__global uchar *)output.ptr);
+}
diff --git a/src/core/CL/kernels/CLPoolingLayerKernel.cpp b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
index 7081688bff..7ccbda9be3 100644
--- a/src/core/CL/kernels/CLPoolingLayerKernel.cpp
+++ b/src/core/CL/kernels/CLPoolingLayerKernel.cpp
@@ -78,7 +78,6 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
TensorInfo out_info(TensorInfo(compute_pool_shape(*input, pool_info), 1, output->data_type()));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &out_info);
}
@@ -201,6 +200,17 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
const int pool_pad_top = pad_stride_info.pad_top();
const int pool_pad_left = pad_stride_info.pad_left();
+ // Set build options
+ CLBuildOptions build_opts;
+
+ if(is_data_type_quantized_asymmetric(input->info()->data_type()) && input->info()->quantization_info() != output->info()->quantization_info())
+ {
+ build_opts.add_option("-DOFFSET_IN1=" + float_to_string_with_full_precision(input->info()->quantization_info().offset));
+ build_opts.add_option("-DOFFSET_OUT=" + float_to_string_with_full_precision(output->info()->quantization_info().offset));
+ build_opts.add_option("-DSCALE_IN1=" + float_to_string_with_full_precision(input->info()->quantization_info().scale));
+ build_opts.add_option("-DSCALE_OUT=" + float_to_string_with_full_precision(output->info()->quantization_info().scale));
+ }
+
// Check output dimensions
auto_init(input->info(), output->info(), pool_info);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), pool_info));
@@ -212,8 +222,6 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
const DataType data_type = input->info()->data_type();
- // Set build options
- CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
build_opts.add_option("-DPOOL_" + string_from_pooling_type(pool_type));
build_opts.add_option("-DSTRIDE_X=" + support::cpp11::to_string(pool_stride_x));
@@ -222,6 +230,7 @@ void CLPoolingLayerKernel::configure(const ICLTensor *input, ICLTensor *output,
build_opts.add_option("-DPAD_Y=" + support::cpp11::to_string(pool_pad_top));
build_opts.add_option("-DPOOL_SIZE_X=" + support::cpp11::to_string(pool_size_x));
build_opts.add_option("-DPOOL_SIZE_Y=" + support::cpp11::to_string(pool_size_y));
+
build_opts.add_option_if(data_type == DataType::F16, "-DFP16");
// Create kernel
diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index d00a4af4fe..308fad5ffb 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -138,7 +138,6 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
ARM_COMPUTE_RETURN_ERROR_ON((output->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH)) != pooled_w)
|| (output->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT)) != pooled_h));
}
@@ -640,6 +639,15 @@ void NEPoolingLayerKernel::pooling2_qasymm8_nchw(const Window &window_input, con
}
}
+ const QuantizationInfo &input_qinfo = _input->info()->quantization_info();
+ const QuantizationInfo &output_qinfo = _output->info()->quantization_info();
+ if(input_qinfo != output_qinfo)
+ {
+ const auto requantized_output = vquantize(vdequantize(vcombine_u8(lower_res, upper_res), input_qinfo), output_qinfo);
+ lower_res = vget_low_u8(requantized_output);
+ upper_res = vget_high_u8(requantized_output);
+ }
+
// Store result
if(pool_stride_x == 1)
{
@@ -1641,6 +1649,11 @@ void NEPoolingLayerKernel::poolingMxN_qasymm8_nchw(const Window &window_input, c
}
// Store result
+ const QuantizationInfo &input_qinfo = _input->info()->quantization_info();
+ const QuantizationInfo &output_qinfo = _output->info()->quantization_info();
+ res = (input_qinfo != output_qinfo) ? sqcvt_qasymm8_f32(scvt_f32_qasymm8(res, input_qinfo.scale, input_qinfo.offset), output_qinfo.scale,
+ output_qinfo.offset) :
+ res;
*(reinterpret_cast<uint8_t *>(output.ptr())) = res;
},
input, output);
@@ -1663,7 +1676,9 @@ void NEPoolingLayerKernel::poolingMxN_qasymm8_nhwc(const Window &window_input, c
const int upper_bound_w = _input->info()->dimension(1) + (exclude_padding ? 0 : pool_pad_right);
const int upper_bound_h = _input->info()->dimension(2) + (exclude_padding ? 0 : pool_pad_bottom);
- const float32x4_t half_scale_v = vdupq_n_f32(0.5f);
+ const float32x4_t half_scale_v = vdupq_n_f32(0.5f);
+ const QuantizationInfo &input_qinfo = _input->info()->quantization_info();
+ const QuantizationInfo &output_qinfo = _output->info()->quantization_info();
execute_window_loop(window, [&](const Coordinates & id)
{
@@ -1713,6 +1728,12 @@ void NEPoolingLayerKernel::poolingMxN_qasymm8_nhwc(const Window &window_input, c
uint8x8_t res1 = vmovn_u16(vcombine_u16(vmovn_u32(vres1), vmovn_u32(vres2)));
uint8x8_t res2 = vmovn_u16(vcombine_u16(vmovn_u32(vres3), vmovn_u32(vres4)));
+ if(input_qinfo != output_qinfo)
+ {
+ const auto requantized_output = vquantize(vdequantize(vcombine_u8(res1, res2), input_qinfo), output_qinfo);
+ res1 = vget_low_u8(requantized_output);
+ res2 = vget_high_u8(requantized_output);
+ }
// Store result
vst1_u8(output.ptr(), res1);
@@ -1733,7 +1754,7 @@ void NEPoolingLayerKernel::poolingMxN_qasymm8_nhwc(const Window &window_input, c
}
// Store result
- vst1q_u8(output.ptr(), vres);
+ vst1q_u8(output.ptr(), (input_qinfo != output_qinfo) ? vquantize(vdequantize(vres, input_qinfo), output_qinfo) : vres);
}
},
input, output);
diff --git a/tests/validation/CL/PoolingLayer.cpp b/tests/validation/CL/PoolingLayer.cpp
index 6afafe6c70..7d79f3f86c 100644
--- a/tests/validation/CL/PoolingLayer.cpp
+++ b/tests/validation/CL/PoolingLayer.cpp
@@ -74,6 +74,8 @@ framework::dataset::make("ExcludePadding", { true }));
constexpr AbsoluteTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for 32-bit floating-point type */
constexpr AbsoluteTolerance<float> tolerance_f16(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for 16-bit floating-point type */
constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for 8-bit asymmetric type */
+const auto pool_data_layout_dataset = framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC });
+
} // namespace
TEST_SUITE(CL)
@@ -133,7 +135,7 @@ FIXTURE_DATA_TEST_CASE(RunSpecial, CLSpecialPoolingLayerFixture<float>, framewor
FIXTURE_DATA_TEST_CASE(RunSmall, CLPoolingLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetFPSmall,
framework::dataset::make("DataType",
DataType::F32))),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -141,7 +143,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLPoolingLayerFixture<float>, framework::Datase
FIXTURE_DATA_TEST_CASE(RunLarge, CLPoolingLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetFP,
framework::dataset::make("DataType",
DataType::F32))),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -151,14 +153,14 @@ TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, CLPoolingLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetFPSmall,
framework::dataset::make("DataType", DataType::F16))),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16);
}
FIXTURE_DATA_TEST_CASE(RunLarge, CLPoolingLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetFP,
framework::dataset::make("DataType", DataType::F16))),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16);
@@ -172,20 +174,16 @@ template <typename T>
using CLPoolingLayerQuantizedFixture = PoolingLayerValidationQuantizedFixture<CLTensor, CLAccessor, CLPoolingLayer, T>;
TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetQASYMM8Small,
+FIXTURE_DATA_TEST_CASE(RunSmall, CLPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetQASYMM8Small,
framework::dataset::make("DataType", DataType::QASYMM8))),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255, 127),
- QuantizationInfo(7.f / 255, 123)
- })),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetQASYMM8,
+FIXTURE_DATA_TEST_CASE(RunLarge, CLPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetQASYMM8,
framework::dataset::make("DataType", DataType::QASYMM8))),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255, 0) })),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
diff --git a/tests/validation/NEON/PoolingLayer.cpp b/tests/validation/NEON/PoolingLayer.cpp
index 9a15775f8c..129f53bef2 100644
--- a/tests/validation/NEON/PoolingLayer.cpp
+++ b/tests/validation/NEON/PoolingLayer.cpp
@@ -67,6 +67,8 @@ constexpr AbsoluteTolerance<float> tolerance_f32(0.001f); /**< Tolerance value f
constexpr AbsoluteTolerance<float> tolerance_f16(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for float types */
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance value for comparing reference's output against implementation's output for 8-bit asymmetric type */
+const auto pool_data_layout_dataset = framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC });
+
} // namespace
TEST_SUITE(NEON)
@@ -124,7 +126,7 @@ FIXTURE_DATA_TEST_CASE(RunSpecial, NESpecialPoolingLayerFixture<float>, framewor
FIXTURE_DATA_TEST_CASE(RunSmall, NEPoolingLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetFPSmall,
framework::dataset::make("DataType",
DataType::F32))),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f32);
@@ -132,7 +134,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPoolingLayerFixture<float>, framework::Datase
FIXTURE_DATA_TEST_CASE(RunLarge, NEPoolingLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetFP,
framework::dataset::make("DataType",
DataType::F32))),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f32);
@@ -143,14 +145,14 @@ TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEPoolingLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetFPSmall,
framework::dataset::make("DataType", DataType::F16))),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f16);
}
FIXTURE_DATA_TEST_CASE(RunLarge, NEPoolingLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetFP,
framework::dataset::make("DataType", DataType::F16))),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f16);
@@ -165,20 +167,16 @@ template <typename T>
using NEPoolingLayerQuantizedFixture = PoolingLayerValidationQuantizedFixture<Tensor, Accessor, NEPoolingLayer, T>;
TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetQASYMM8Small,
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerDatasetQASYMM8Small,
framework::dataset::make("DataType", DataType::QASYMM8))),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255, 127),
- QuantizationInfo(7.f / 255, 123)
- })),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetQASYMM8,
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPoolingLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), combine(PoolingLayerDatasetQASYMM8,
framework::dataset::make("DataType", DataType::QASYMM8))),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(1.f / 255, 0) })),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+ pool_data_layout_dataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index 3e34f98271..1813ef4c84 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,6 +26,7 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/Tensor.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
@@ -47,13 +48,16 @@ class PoolingLayerValidationGenericFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
+ void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout)
{
- _quantization_info = quantization_info;
- _pool_info = pool_info;
-
- _target = compute_target(shape, pool_info, data_type, data_layout, quantization_info);
- _reference = compute_reference(shape, pool_info, data_type, quantization_info);
+ std::mt19937 gen(library->seed());
+ std::uniform_int_distribution<> offset_dis(0, 20);
+ const QuantizationInfo input_qinfo(1.f / 255.f, offset_dis(gen));
+ const QuantizationInfo output_qinfo(1.f / 255.f, offset_dis(gen));
+
+ _pool_info = pool_info;
+ _target = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo);
+ _reference = compute_reference(shape, pool_info, data_type, input_qinfo, output_qinfo);
}
protected:
@@ -72,7 +76,7 @@ protected:
}
TensorType compute_target(TensorShape shape, PoolingLayerInfo info,
- DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
+ DataType data_type, DataLayout data_layout, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo)
{
// Change shape in case of NHWC.
if(data_layout == DataLayout::NHWC)
@@ -81,8 +85,9 @@ protected:
}
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, data_type, 1, quantization_info, data_layout);
- TensorType dst;
+ TensorType src = create_tensor<TensorType>(shape, data_type, 1, input_qinfo, data_layout);
+ const TensorShape dst_shape = misc::shape_calculator::compute_pool_shape(*(src.info()), info);
+ TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, output_qinfo, data_layout);
// Create and configure function
FunctionType pool_layer;
@@ -107,21 +112,19 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, PoolingLayerInfo info,
- DataType data_type, QuantizationInfo quantization_info)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, PoolingLayerInfo info, DataType data_type, QuantizationInfo input_qinfo, QuantizationInfo output_qinfo)
{
// Create reference
- SimpleTensor<T> src{ shape, data_type, 1, quantization_info };
+ SimpleTensor<T> src{ shape, data_type, 1, input_qinfo };
// Fill reference
fill(src);
- return reference::pooling_layer<T>(src, info);
+ return reference::pooling_layer<T>(src, info, output_qinfo);
}
TensorType _target{};
SimpleTensor<T> _reference{};
- QuantizationInfo _quantization_info{};
PoolingLayerInfo _pool_info{};
};
@@ -133,7 +136,7 @@ public:
void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout)
{
PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, pad_stride_info, exclude_padding),
- data_type, data_layout, QuantizationInfo());
+ data_type, data_layout);
}
};
@@ -142,11 +145,10 @@ class PoolingLayerValidationQuantizedFixture : public PoolingLayerValidationGene
{
public:
template <typename...>
- void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type,
- QuantizationInfo quantization_info, DataLayout data_layout = DataLayout::NCHW)
+ void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout = DataLayout::NCHW)
{
PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, pad_stride_info, exclude_padding),
- data_type, data_layout, quantization_info);
+ data_type, data_layout);
}
};
@@ -157,7 +159,7 @@ public:
template <typename...>
void setup(TensorShape src_shape, PoolingLayerInfo pool_info, DataType data_type)
{
- PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW, QuantizationInfo());
+ PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW);
}
};
@@ -168,7 +170,7 @@ public:
template <typename...>
void setup(TensorShape shape, PoolingType pool_type, DataType data_type, DataLayout data_layout = DataLayout::NCHW)
{
- PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type), data_type, DataLayout::NCHW, QuantizationInfo());
+ PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type), data_type, DataLayout::NCHW);
}
};
diff --git a/tests/validation/reference/PoolingLayer.cpp b/tests/validation/reference/PoolingLayer.cpp
index e617c939b0..f4112a486d 100644
--- a/tests/validation/reference/PoolingLayer.cpp
+++ b/tests/validation/reference/PoolingLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,8 +38,9 @@ namespace reference
using namespace arm_compute::misc::shape_calculator;
template <typename T>
-SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info)
+SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo)
{
+ ARM_COMPUTE_UNUSED(output_qinfo); // requantization occurs in pooling_layer<uint8_t>
ARM_COMPUTE_ERROR_ON(info.is_global_pooling() && (src.shape().x() != src.shape().y()));
// Create reference
@@ -152,16 +153,16 @@ SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo
}
template <>
-SimpleTensor<uint8_t> pooling_layer<uint8_t>(const SimpleTensor<uint8_t> &src, const PoolingLayerInfo &info)
+SimpleTensor<uint8_t> pooling_layer<uint8_t>(const SimpleTensor<uint8_t> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo)
{
SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
- SimpleTensor<float> dst_tmp = pooling_layer<float>(src_tmp, info);
- SimpleTensor<uint8_t> dst = convert_to_asymmetric(dst_tmp, src.quantization_info());
+ SimpleTensor<float> dst_tmp = pooling_layer<float>(src_tmp, info, output_qinfo);
+ SimpleTensor<uint8_t> dst = convert_to_asymmetric(dst_tmp, output_qinfo);
return dst;
}
-template SimpleTensor<float> pooling_layer(const SimpleTensor<float> &src, const PoolingLayerInfo &info);
-template SimpleTensor<half> pooling_layer(const SimpleTensor<half> &src, const PoolingLayerInfo &info);
+template SimpleTensor<float> pooling_layer(const SimpleTensor<float> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo);
+template SimpleTensor<half> pooling_layer(const SimpleTensor<half> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/PoolingLayer.h b/tests/validation/reference/PoolingLayer.h
index 00977896a8..1c0b7ff40d 100644
--- a/tests/validation/reference/PoolingLayer.h
+++ b/tests/validation/reference/PoolingLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,7 +36,7 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info);
+SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo);
} // namespace reference
} // namespace validation
} // namespace test