diff options
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/DequantizationLayerFixture.h | 55 |
1 files changed, 40 insertions, 15 deletions
diff --git a/tests/validation/fixtures/DequantizationLayerFixture.h b/tests/validation/fixtures/DequantizationLayerFixture.h index 2e3712dff2..15f3711189 100644 --- a/tests/validation/fixtures/DequantizationLayerFixture.h +++ b/tests/validation/fixtures/DequantizationLayerFixture.h @@ -47,10 +47,11 @@ class DequantizationValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type, QuantizationInfo qinfo) + void setup(TensorShape shape, DataType src_data_type, DataType dst_datatype) { - _target = compute_target(shape, data_type, qinfo); - _reference = compute_reference(shape, data_type, qinfo); + _quantization_info = generate_quantization_info(src_data_type); + _target = compute_target(shape, src_data_type, dst_datatype); + _reference = compute_reference(shape, src_data_type); } protected: @@ -60,11 +61,11 @@ protected: library->fill_tensor_uniform(tensor, 0); } - TensorType compute_target(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo) + TensorType compute_target(const TensorShape &shape, DataType src_data_type, DataType dst_datatype) { // Create tensors - TensorType src = create_tensor<TensorType>(shape, DataType::QASYMM8, 1, qinfo); - TensorType dst = create_tensor<TensorType>(shape, data_type); + TensorType src = create_tensor<TensorType>(shape, src_data_type, 1, _quantization_info); + TensorType dst = create_tensor<TensorType>(shape, dst_datatype); // Create and configure function FunctionType dequantization_layer; @@ -89,19 +90,43 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo) + SimpleTensor<T> compute_reference(const TensorShape &shape, DataType src_data_type) { - // Create reference - SimpleTensor<uint8_t> src{ shape, DataType::QASYMM8, 1, qinfo }; - - // Fill reference - fill(src); + if(is_data_type_quantized_asymmetric(src_data_type)) + { + SimpleTensor<uint8_t> src{ shape, src_data_type, 1, _quantization_info }; + fill(src); + return reference::dequantization_layer<T>(src); + } + else + { + SimpleTensor<int8_t> src{ shape, src_data_type, 1, _quantization_info }; + fill(src); + return reference::dequantization_layer<T>(src); + } + } - return reference::dequantization_layer<T>(src); +protected: + QuantizationInfo generate_quantization_info(DataType data_type) + { + std::uniform_int_distribution<> distribution(1, 127); + std::mt19937 gen(library.get()->seed()); + + switch(data_type) + { + case DataType::QSYMM8: + return QuantizationInfo(1.f / distribution(gen)); + case DataType::QASYMM8: + return QuantizationInfo(1.f / distribution(gen), distribution(gen)); + default: + ARM_COMPUTE_ERROR("Unsupported data type"); + } } - TensorType _target{}; - SimpleTensor<T> _reference{}; +protected: + TensorType _target{}; + SimpleTensor<T> _reference{}; + QuantizationInfo _quantization_info{}; }; } // namespace validation } // namespace test |