aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/DequantizationLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/DequantizationLayer.cpp')
-rw-r--r--tests/validation/reference/DequantizationLayer.cpp18
1 files changed, 2 insertions, 16 deletions
diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp
index 74686bdaaf..69a49a3d6d 100644
--- a/tests/validation/reference/DequantizationLayer.cpp
+++ b/tests/validation/reference/DequantizationLayer.cpp
@@ -50,9 +50,9 @@ TOut dequantize(int16_t val, const UniformQuantizationInfo qinfo)
{
return static_cast<TOut>(dequantize_qsymm16(val, qinfo));
}
-
+} // namespace
template <typename TOut, typename TIn>
-SimpleTensor<TOut> dequantization_layer_nchw(const SimpleTensor<TIn> &src)
+SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src)
{
const DataType src_data_type = src.data_type();
const DataType dst_data_type = std::is_same<TOut, float>::value ? DataType::F32 : DataType::F16;
@@ -97,20 +97,6 @@ SimpleTensor<TOut> dequantization_layer_nchw(const SimpleTensor<TIn> &src)
return dst;
}
-} // namespace
-template <typename TOut, typename TIn>
-SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src)
-{
- if(src.data_layout() == DataLayout::NHWC && src.data_type() == DataType::QSYMM8_PER_CHANNEL)
- {
- SimpleTensor<TIn> src_nchw = reference::permute<TIn>(src, PermutationVector(1U, 2U, 0U));
- return reference::permute<TOut>(dequantization_layer_nchw<TOut>(src_nchw), PermutationVector(2U, 0U, 1U));
- }
- else
- {
- return dequantization_layer_nchw<TOut>(src);
- }
-}
template SimpleTensor<half> dequantization_layer(const SimpleTensor<uint8_t> &src);
template SimpleTensor<float> dequantization_layer(const SimpleTensor<uint8_t> &src);