From 3d13af8a39f408318328a95d5329bc17fd923438 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 4 Jun 2019 13:04:16 +0100 Subject: COMPMID-2235: Extend type support for CL/NEON DequantizationLayer. Adds support for: - QSYMM8 Change-Id: Ia0b839fc844ce0f968dad1b69a001f9a660dbcd5 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/1378 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Manuel Bottini Reviewed-by: Michalis Spyrou --- .../fixtures/DequantizationLayerFixture.h | 55 ++++++++++++++++------ 1 file changed, 40 insertions(+), 15 deletions(-) (limited to 'tests/validation/fixtures') diff --git a/tests/validation/fixtures/DequantizationLayerFixture.h b/tests/validation/fixtures/DequantizationLayerFixture.h index 2e3712dff2..15f3711189 100644 --- a/tests/validation/fixtures/DequantizationLayerFixture.h +++ b/tests/validation/fixtures/DequantizationLayerFixture.h @@ -47,10 +47,11 @@ class DequantizationValidationFixture : public framework::Fixture { public: template - void setup(TensorShape shape, DataType data_type, QuantizationInfo qinfo) + void setup(TensorShape shape, DataType src_data_type, DataType dst_datatype) { - _target = compute_target(shape, data_type, qinfo); - _reference = compute_reference(shape, data_type, qinfo); + _quantization_info = generate_quantization_info(src_data_type); + _target = compute_target(shape, src_data_type, dst_datatype); + _reference = compute_reference(shape, src_data_type); } protected: @@ -60,11 +61,11 @@ protected: library->fill_tensor_uniform(tensor, 0); } - TensorType compute_target(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo) + TensorType compute_target(const TensorShape &shape, DataType src_data_type, DataType dst_datatype) { // Create tensors - TensorType src = create_tensor(shape, DataType::QASYMM8, 1, qinfo); - TensorType dst = create_tensor(shape, data_type); + TensorType src = create_tensor(shape, src_data_type, 1, _quantization_info); + TensorType dst = create_tensor(shape, dst_datatype); // Create and configure function FunctionType dequantization_layer; @@ -89,19 +90,43 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo) + SimpleTensor compute_reference(const TensorShape &shape, DataType src_data_type) { - // Create reference - SimpleTensor src{ shape, DataType::QASYMM8, 1, qinfo }; - - // Fill reference - fill(src); + if(is_data_type_quantized_asymmetric(src_data_type)) + { + SimpleTensor src{ shape, src_data_type, 1, _quantization_info }; + fill(src); + return reference::dequantization_layer(src); + } + else + { + SimpleTensor src{ shape, src_data_type, 1, _quantization_info }; + fill(src); + return reference::dequantization_layer(src); + } + } - return reference::dequantization_layer(src); +protected: + QuantizationInfo generate_quantization_info(DataType data_type) + { + std::uniform_int_distribution<> distribution(1, 127); + std::mt19937 gen(library.get()->seed()); + + switch(data_type) + { + case DataType::QSYMM8: + return QuantizationInfo(1.f / distribution(gen)); + case DataType::QASYMM8: + return QuantizationInfo(1.f / distribution(gen), distribution(gen)); + default: + ARM_COMPUTE_ERROR("Unsupported data type"); + } } - TensorType _target{}; - SimpleTensor _reference{}; +protected: + TensorType _target{}; + SimpleTensor _reference{}; + QuantizationInfo _quantization_info{}; }; } // namespace validation } // namespace test -- cgit v1.2.1