From 3d13af8a39f408318328a95d5329bc17fd923438 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 4 Jun 2019 13:04:16 +0100 Subject: COMPMID-2235: Extend type support for CL/NEON DequantizationLayer. Adds support for: - QSYMM8 Change-Id: Ia0b839fc844ce0f968dad1b69a001f9a660dbcd5 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/1378 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Manuel Bottini Reviewed-by: Michalis Spyrou --- .../NEON/kernels/NEDequantizationLayerKernel.cpp | 74 ++++++++++++++++++++-- 1 file changed, 67 insertions(+), 7 deletions(-) (limited to 'src/core/NEON/kernels') diff --git a/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp b/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp index a6dc0977d2..bf0a2ca7bf 100644 --- a/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp @@ -42,7 +42,7 @@ namespace Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QSYMM8); if(output->tensor_shape().total_size() > 0) { @@ -95,9 +95,11 @@ inline void store_result(float16_t *ptr, const float32x4x4_t &v) #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ template -void run_dequantization(const ITensor *input, ITensor *output, const Window &window) +void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Window &window) { - const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform(); + const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform(); + const float scale = qinfo.scale; + const int32_t offset = qinfo.offset; const int window_step_x = 16; const auto window_start_x = static_cast(window.x().start()); @@ -120,7 +122,49 @@ void run_dequantization(const ITensor *input, ITensor *output, const Window &win for(; x <= (window_end_x - window_step_x); x += window_step_x) { const auto vin = wrapper::vloadq(in_ptr + x); - const auto vdeq = vdequantize(vin, qinfo); + const auto vdeq = vdequantize(vin, scale, offset); + + store_result(reinterpret_cast(out_ptr + x), vdeq); + } + + // Compute left-over elements + for(; x < window_end_x; ++x) + { + uint8_t val = *(in_ptr + x); + *(out_ptr + x) = static_cast(dequantize(val, scale, offset)); + } + }, + in, out); +} + +template +void run_dequantization_qsymm8(const ITensor *input, ITensor *output, const Window &window) +{ + const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform(); + const float scale = qinfo.scale; + + const int window_step_x = 16; + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Create iterators + Iterator in(input, win_collapsed); + Iterator out(output, win_collapsed); + + execute_window_loop(win_collapsed, [&](const Coordinates &) + { + const auto in_ptr = reinterpret_cast(in.ptr()); + const auto out_ptr = reinterpret_cast(out.ptr()); + + int x = window_start_x; + for(; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vin = wrapper::vloadq(in_ptr + x); + const auto vdeq = vdequantize(vin, scale); store_result(reinterpret_cast(out_ptr + x), vdeq); } @@ -129,11 +173,27 @@ void run_dequantization(const ITensor *input, ITensor *output, const Window &win for(; x < window_end_x; ++x) { uint8_t val = *(in_ptr + x); - *(out_ptr + x) = static_cast(dequantize_qasymm8(val, qinfo)); + *(out_ptr + x) = static_cast(dequantize(val, scale)); } }, in, out); } + +template +void run_dequantization_core(const ITensor *input, ITensor *output, const Window &window) +{ + switch(input->info()->data_type()) + { + case DataType::QASYMM8: + run_dequantization_qasymm8(input, output, window); + break; + case DataType::QSYMM8: + run_dequantization_qsymm8(input, output, window); + break; + default: + ARM_COMPUTE_ERROR("Unsupported data type."); + } +} } // namespace NEDequantizationLayerKernel::NEDequantizationLayerKernel() @@ -173,11 +233,11 @@ void NEDequantizationLayerKernel::run(const Window &window, const ThreadInfo &in switch(_output->info()->data_type()) { case DataType::F32: - run_dequantization(_input, _output, window); + run_dequantization_core(_input, _output, window); break; #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - run_dequantization(_input, _output, window); + run_dequantization_core(_input, _output, window); break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ default: -- cgit v1.2.1