aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/DequantizationLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/DequantizationLayer.cpp')
-rw-r--r--tests/validation/reference/DequantizationLayer.cpp74
1 files changed, 67 insertions, 7 deletions
diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp
index 286a609d79..d07371c883 100644
--- a/tests/validation/reference/DequantizationLayer.cpp
+++ b/tests/validation/reference/DequantizationLayer.cpp
@@ -23,6 +23,8 @@
*/
#include "DequantizationLayer.h"
+#include "Permute.h"
+
namespace arm_compute
{
namespace test
@@ -31,24 +33,82 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> dequantization_layer(const SimpleTensor<uint8_t> &src)
+namespace
+{
+template <typename TOut>
+TOut dequantize(int8_t val, const UniformQuantizationInfo qinfo)
+{
+ return static_cast<TOut>(dequantize_qsymm8(val, qinfo));
+}
+template <typename TOut>
+TOut dequantize(uint8_t val, const UniformQuantizationInfo qinfo)
+{
+ return static_cast<TOut>(dequantize_qasymm8(val, qinfo));
+}
+
+template <typename TOut, typename TIn>
+SimpleTensor<TOut> dequantization_layer_nchw(const SimpleTensor<TIn> &src)
{
- const DataType dst_data_type = std::is_same<T, float>::value ? DataType::F32 : DataType::F16;
- const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
+ const DataType src_data_type = src.data_type();
+ const DataType dst_data_type = std::is_same<TOut, float>::value ? DataType::F32 : DataType::F16;
- SimpleTensor<T> dst{ src.shape(), dst_data_type };
+ SimpleTensor<TOut> dst{ src.shape(), dst_data_type };
- for(int i = 0; i < src.num_elements(); ++i)
+ if(src_data_type == DataType::QSYMM8_PER_CHANNEL)
{
- dst[i] = static_cast<T>(dequantize_qasymm8(src[i], quantization_info));
+ const int WH = src.shape().x() * src.shape().y();
+ const int C = src.shape().z();
+ const int N = src.shape().total_size() / (WH * C);
+
+ const std::vector<float> qscales = src.quantization_info().scale();
+
+ for(int n = 0; n < N; ++n)
+ {
+ for(int c = 0; c < C; ++c)
+ {
+ const size_t idx = n * C * WH + c * WH;
+ const UniformQuantizationInfo channel_qinfo = { qscales[c], 0 };
+
+ // Dequantize slice
+ for(int s = 0; s < WH; ++s)
+ {
+ dst[idx + s] = dequantize<TOut>(src[idx + s], channel_qinfo);
+ }
+ }
+ }
+ }
+ else
+ {
+ const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
+ ARM_COMPUTE_ERROR_ON(quantization_info.offset != 0 && src_data_type == DataType::QSYMM8);
+
+ for(int i = 0; i < src.num_elements(); ++i)
+ {
+ dst[i] = static_cast<TOut>(dequantize<TOut>(src[i], quantization_info));
+ }
}
return dst;
}
+} // namespace
+template <typename TOut, typename TIn>
+SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src)
+{
+ if(src.data_layout() == DataLayout::NHWC && src.data_type() == DataType::QSYMM8_PER_CHANNEL)
+ {
+ SimpleTensor<TIn> src_nchw = reference::permute<TIn>(src, PermutationVector(1U, 2U, 0U));
+ return reference::permute<TOut>(dequantization_layer_nchw<TOut>(src_nchw), PermutationVector(2U, 0U, 1U));
+ }
+ else
+ {
+ return dequantization_layer_nchw<TOut>(src);
+ }
+}
template SimpleTensor<half> dequantization_layer(const SimpleTensor<uint8_t> &src);
template SimpleTensor<float> dequantization_layer(const SimpleTensor<uint8_t> &src);
+template SimpleTensor<half> dequantization_layer(const SimpleTensor<int8_t> &src);
+template SimpleTensor<float> dequantization_layer(const SimpleTensor<int8_t> &src);
} // namespace reference
} // namespace validation
} // namespace test