aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-04 13:04:16 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-24 14:56:23 +0000
commit3d13af8a39f408318328a95d5329bc17fd923438 (patch)
treeb0d9c82062e229f8938d2c9f762ee67758196bf3 /tests/validation/reference
parentdb09b3783ff9af67c6d373b12aa9a6aff3c5d0f1 (diff)
downloadComputeLibrary-3d13af8a39f408318328a95d5329bc17fd923438.tar.gz
COMPMID-2235: Extend type support for CL/NEON DequantizationLayer.
Adds support for: - QSYMM8 Change-Id: Ia0b839fc844ce0f968dad1b69a001f9a660dbcd5 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/c/1378 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'tests/validation/reference')
-rw-r--r--tests/validation/reference/DequantizationLayer.cpp74
-rw-r--r--tests/validation/reference/DequantizationLayer.h4
2 files changed, 69 insertions, 9 deletions
diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp
index 286a609d79..d07371c883 100644
--- a/tests/validation/reference/DequantizationLayer.cpp
+++ b/tests/validation/reference/DequantizationLayer.cpp
@@ -23,6 +23,8 @@
*/
#include "DequantizationLayer.h"
+#include "Permute.h"
+
namespace arm_compute
{
namespace test
@@ -31,24 +33,82 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> dequantization_layer(const SimpleTensor<uint8_t> &src)
+namespace
+{
+template <typename TOut>
+TOut dequantize(int8_t val, const UniformQuantizationInfo qinfo)
+{
+ return static_cast<TOut>(dequantize_qsymm8(val, qinfo));
+}
+template <typename TOut>
+TOut dequantize(uint8_t val, const UniformQuantizationInfo qinfo)
+{
+ return static_cast<TOut>(dequantize_qasymm8(val, qinfo));
+}
+
+template <typename TOut, typename TIn>
+SimpleTensor<TOut> dequantization_layer_nchw(const SimpleTensor<TIn> &src)
{
- const DataType dst_data_type = std::is_same<T, float>::value ? DataType::F32 : DataType::F16;
- const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
+ const DataType src_data_type = src.data_type();
+ const DataType dst_data_type = std::is_same<TOut, float>::value ? DataType::F32 : DataType::F16;
- SimpleTensor<T> dst{ src.shape(), dst_data_type };
+ SimpleTensor<TOut> dst{ src.shape(), dst_data_type };
- for(int i = 0; i < src.num_elements(); ++i)
+ if(src_data_type == DataType::QSYMM8_PER_CHANNEL)
{
- dst[i] = static_cast<T>(dequantize_qasymm8(src[i], quantization_info));
+ const int WH = src.shape().x() * src.shape().y();
+ const int C = src.shape().z();
+ const int N = src.shape().total_size() / (WH * C);
+
+ const std::vector<float> qscales = src.quantization_info().scale();
+
+ for(int n = 0; n < N; ++n)
+ {
+ for(int c = 0; c < C; ++c)
+ {
+ const size_t idx = n * C * WH + c * WH;
+ const UniformQuantizationInfo channel_qinfo = { qscales[c], 0 };
+
+ // Dequantize slice
+ for(int s = 0; s < WH; ++s)
+ {
+ dst[idx + s] = dequantize<TOut>(src[idx + s], channel_qinfo);
+ }
+ }
+ }
+ }
+ else
+ {
+ const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
+ ARM_COMPUTE_ERROR_ON(quantization_info.offset != 0 && src_data_type == DataType::QSYMM8);
+
+ for(int i = 0; i < src.num_elements(); ++i)
+ {
+ dst[i] = static_cast<TOut>(dequantize<TOut>(src[i], quantization_info));
+ }
}
return dst;
}
+} // namespace
+template <typename TOut, typename TIn>
+SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src)
+{
+ if(src.data_layout() == DataLayout::NHWC && src.data_type() == DataType::QSYMM8_PER_CHANNEL)
+ {
+ SimpleTensor<TIn> src_nchw = reference::permute<TIn>(src, PermutationVector(1U, 2U, 0U));
+ return reference::permute<TOut>(dequantization_layer_nchw<TOut>(src_nchw), PermutationVector(2U, 0U, 1U));
+ }
+ else
+ {
+ return dequantization_layer_nchw<TOut>(src);
+ }
+}
template SimpleTensor<half> dequantization_layer(const SimpleTensor<uint8_t> &src);
template SimpleTensor<float> dequantization_layer(const SimpleTensor<uint8_t> &src);
+template SimpleTensor<half> dequantization_layer(const SimpleTensor<int8_t> &src);
+template SimpleTensor<float> dequantization_layer(const SimpleTensor<int8_t> &src);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/DequantizationLayer.h b/tests/validation/reference/DequantizationLayer.h
index 1d0e54b442..8c780849fd 100644
--- a/tests/validation/reference/DequantizationLayer.h
+++ b/tests/validation/reference/DequantizationLayer.h
@@ -35,8 +35,8 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> dequantization_layer(const SimpleTensor<uint8_t> &src);
+template <typename TOut, typename TIn>
+SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src);
} // namespace reference
} // namespace validation
} // namespace test