aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/GEMMLowpFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/GEMMLowpFixture.h')
-rw-r--r--tests/validation/fixtures/GEMMLowpFixture.h12
1 files changed, 7 insertions, 5 deletions
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index 11a491faa7..6b7cbba92e 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -31,7 +31,7 @@
#include "tests/validation/Validation.h"
#include "tests/validation/reference/GEMMLowp.h"
#include "tests/validation/reference/ArithmeticOperations.h"
-#include "tests/validation/reference/QuantizationLayer.h"
+#include "tests/validation/reference/DequantizationLayer.h"
#include <cstdint>
#include <vector>
@@ -485,7 +485,7 @@ public:
const auto b_qinfo = QuantizationInfo(5.0f / 255, b_offset);
TensorFillInfo finfo;
_target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
- _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate);
+ _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
}
protected:
@@ -495,14 +495,16 @@ protected:
return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
}
- SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate)
+ SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate, const bool dynamic_qinfo)
{
+ QuantizationInfo s32_ref_output_quant_info = QuantizationInfo(a_qinfo.uniform().scale * b_qinfo.uniform().scale, 0, dynamic_qinfo);
+
SimpleTensor<int32_t> s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo,
DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, finfo);
+ s32_ref_output.quantization_info(s32_ref_output_quant_info);
SimpleTensor<float> f32_ref_output(s32_ref_output.shape(), DataType::F32);
- QuantizationInfo dst_quant_info = QuantizationInfo(a_qinfo.uniform().scale * b_qinfo.uniform().scale, 0);
- f32_ref_output = reference::quantization_layer<int32_t, float>(s32_ref_output, DataType::F32, dst_quant_info);
+ f32_ref_output = reference::dequantization_layer<float, int32_t>(s32_ref_output);
if (accumulate)
{