aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/DequantizationLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/DequantizationLayer.cpp')
-rw-r--r--tests/validation/reference/DequantizationLayer.cpp21
1 files changed, 15 insertions, 6 deletions
diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp
index 16f25c4427..7dd36402b3 100644
--- a/tests/validation/reference/DequantizationLayer.cpp
+++ b/tests/validation/reference/DequantizationLayer.cpp
@@ -36,18 +36,27 @@ namespace reference
namespace
{
template <typename TOut>
-TOut dequantize(int8_t val, const UniformQuantizationInfo qinfo)
+TOut dequantize(int8_t val, const UniformQuantizationInfo qinfo, DataType dt)
{
- return static_cast<TOut>(dequantize_qsymm8(val, qinfo));
+ if(dt == DataType::QSYMM8 || dt == DataType::QSYMM8_PER_CHANNEL)
+ {
+ return static_cast<TOut>(dequantize_qsymm8(val, qinfo));
+ }
+ else
+ {
+ return static_cast<TOut>(dequantize_qasymm8_signed(val, qinfo));
+ }
}
template <typename TOut>
-TOut dequantize(uint8_t val, const UniformQuantizationInfo qinfo)
+TOut dequantize(uint8_t val, const UniformQuantizationInfo qinfo, DataType dt)
{
+ ARM_COMPUTE_UNUSED(dt);
return static_cast<TOut>(dequantize_qasymm8(val, qinfo));
}
template <typename TOut>
-TOut dequantize(int16_t val, const UniformQuantizationInfo qinfo)
+TOut dequantize(int16_t val, const UniformQuantizationInfo qinfo, DataType dt)
{
+ ARM_COMPUTE_UNUSED(dt);
return static_cast<TOut>(dequantize_qsymm16(val, qinfo));
}
} // namespace
@@ -77,7 +86,7 @@ SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src)
// Dequantize slice
for(int s = 0; s < WH; ++s)
{
- dst[idx + s] = dequantize<TOut>(static_cast<TIn>(src[idx + s]), channel_qinfo);
+ dst[idx + s] = dequantize<TOut>(static_cast<TIn>(src[idx + s]), channel_qinfo, src_data_type);
}
}
}
@@ -89,7 +98,7 @@ SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src)
for(int i = 0; i < src.num_elements(); ++i)
{
- dst[i] = static_cast<TOut>(dequantize<TOut>(static_cast<TIn>(src[i]), quantization_info));
+ dst[i] = static_cast<TOut>(dequantize<TOut>(static_cast<TIn>(src[i]), quantization_info, src_data_type));
}
}