aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-04 13:04:16 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-24 14:56:23 +0000
commit3d13af8a39f408318328a95d5329bc17fd923438 (patch)
treeb0d9c82062e229f8938d2c9f762ee67758196bf3 /tests/validation/fixtures
parentdb09b3783ff9af67c6d373b12aa9a6aff3c5d0f1 (diff)
downloadComputeLibrary-3d13af8a39f408318328a95d5329bc17fd923438.tar.gz
COMPMID-2235: Extend type support for CL/NEON DequantizationLayer.
Adds support for: - QSYMM8 Change-Id: Ia0b839fc844ce0f968dad1b69a001f9a660dbcd5 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/c/1378 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r--tests/validation/fixtures/DequantizationLayerFixture.h55
1 files changed, 40 insertions, 15 deletions
diff --git a/tests/validation/fixtures/DequantizationLayerFixture.h b/tests/validation/fixtures/DequantizationLayerFixture.h
index 2e3712dff2..15f3711189 100644
--- a/tests/validation/fixtures/DequantizationLayerFixture.h
+++ b/tests/validation/fixtures/DequantizationLayerFixture.h
@@ -47,10 +47,11 @@ class DequantizationValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, QuantizationInfo qinfo)
+ void setup(TensorShape shape, DataType src_data_type, DataType dst_datatype)
{
- _target = compute_target(shape, data_type, qinfo);
- _reference = compute_reference(shape, data_type, qinfo);
+ _quantization_info = generate_quantization_info(src_data_type);
+ _target = compute_target(shape, src_data_type, dst_datatype);
+ _reference = compute_reference(shape, src_data_type);
}
protected:
@@ -60,11 +61,11 @@ protected:
library->fill_tensor_uniform(tensor, 0);
}
- TensorType compute_target(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo)
+ TensorType compute_target(const TensorShape &shape, DataType src_data_type, DataType dst_datatype)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, DataType::QASYMM8, 1, qinfo);
- TensorType dst = create_tensor<TensorType>(shape, data_type);
+ TensorType src = create_tensor<TensorType>(shape, src_data_type, 1, _quantization_info);
+ TensorType dst = create_tensor<TensorType>(shape, dst_datatype);
// Create and configure function
FunctionType dequantization_layer;
@@ -89,19 +90,43 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, DataType src_data_type)
{
- // Create reference
- SimpleTensor<uint8_t> src{ shape, DataType::QASYMM8, 1, qinfo };
-
- // Fill reference
- fill(src);
+ if(is_data_type_quantized_asymmetric(src_data_type))
+ {
+ SimpleTensor<uint8_t> src{ shape, src_data_type, 1, _quantization_info };
+ fill(src);
+ return reference::dequantization_layer<T>(src);
+ }
+ else
+ {
+ SimpleTensor<int8_t> src{ shape, src_data_type, 1, _quantization_info };
+ fill(src);
+ return reference::dequantization_layer<T>(src);
+ }
+ }
- return reference::dequantization_layer<T>(src);
+protected:
+ QuantizationInfo generate_quantization_info(DataType data_type)
+ {
+ std::uniform_int_distribution<> distribution(1, 127);
+ std::mt19937 gen(library.get()->seed());
+
+ switch(data_type)
+ {
+ case DataType::QSYMM8:
+ return QuantizationInfo(1.f / distribution(gen));
+ case DataType::QASYMM8:
+ return QuantizationInfo(1.f / distribution(gen), distribution(gen));
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type");
+ }
}
- TensorType _target{};
- SimpleTensor<T> _reference{};
+protected:
+ TensorType _target{};
+ SimpleTensor<T> _reference{};
+ QuantizationInfo _quantization_info{};
};
} // namespace validation
} // namespace test