aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/DequantizationLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/DequantizationLayerFixture.h')
-rw-r--r--tests/validation/fixtures/DequantizationLayerFixture.h55
1 files changed, 40 insertions, 15 deletions
diff --git a/tests/validation/fixtures/DequantizationLayerFixture.h b/tests/validation/fixtures/DequantizationLayerFixture.h
index 2e3712dff2..15f3711189 100644
--- a/tests/validation/fixtures/DequantizationLayerFixture.h
+++ b/tests/validation/fixtures/DequantizationLayerFixture.h
@@ -47,10 +47,11 @@ class DequantizationValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, QuantizationInfo qinfo)
+ void setup(TensorShape shape, DataType src_data_type, DataType dst_datatype)
{
- _target = compute_target(shape, data_type, qinfo);
- _reference = compute_reference(shape, data_type, qinfo);
+ _quantization_info = generate_quantization_info(src_data_type);
+ _target = compute_target(shape, src_data_type, dst_datatype);
+ _reference = compute_reference(shape, src_data_type);
}
protected:
@@ -60,11 +61,11 @@ protected:
library->fill_tensor_uniform(tensor, 0);
}
- TensorType compute_target(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo)
+ TensorType compute_target(const TensorShape &shape, DataType src_data_type, DataType dst_datatype)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, DataType::QASYMM8, 1, qinfo);
- TensorType dst = create_tensor<TensorType>(shape, data_type);
+ TensorType src = create_tensor<TensorType>(shape, src_data_type, 1, _quantization_info);
+ TensorType dst = create_tensor<TensorType>(shape, dst_datatype);
// Create and configure function
FunctionType dequantization_layer;
@@ -89,19 +90,43 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, DataType src_data_type)
{
- // Create reference
- SimpleTensor<uint8_t> src{ shape, DataType::QASYMM8, 1, qinfo };
-
- // Fill reference
- fill(src);
+ if(is_data_type_quantized_asymmetric(src_data_type))
+ {
+ SimpleTensor<uint8_t> src{ shape, src_data_type, 1, _quantization_info };
+ fill(src);
+ return reference::dequantization_layer<T>(src);
+ }
+ else
+ {
+ SimpleTensor<int8_t> src{ shape, src_data_type, 1, _quantization_info };
+ fill(src);
+ return reference::dequantization_layer<T>(src);
+ }
+ }
- return reference::dequantization_layer<T>(src);
+protected:
+ QuantizationInfo generate_quantization_info(DataType data_type)
+ {
+ std::uniform_int_distribution<> distribution(1, 127);
+ std::mt19937 gen(library.get()->seed());
+
+ switch(data_type)
+ {
+ case DataType::QSYMM8:
+ return QuantizationInfo(1.f / distribution(gen));
+ case DataType::QASYMM8:
+ return QuantizationInfo(1.f / distribution(gen), distribution(gen));
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type");
+ }
}
- TensorType _target{};
- SimpleTensor<T> _reference{};
+protected:
+ TensorType _target{};
+ SimpleTensor<T> _reference{};
+ QuantizationInfo _quantization_info{};
};
} // namespace validation
} // namespace test