diff options
Diffstat (limited to 'src/core/NEON/kernels/NEDequantizationLayerKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEDequantizationLayerKernel.cpp | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp b/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp index f555df3828..947f257bcb 100644 --- a/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp @@ -43,7 +43,7 @@ namespace Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8, DataType::QSYMM16); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8, DataType::QSYMM16); if(output->tensor_shape().total_size() > 0) { @@ -116,7 +116,7 @@ inline void store_result<float16_t>(float16_t *ptr, const float32x4x2_t &v) } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -template <typename T> +template <typename TOut, typename TIn> void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Window &window) { const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform(); @@ -137,8 +137,8 @@ void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Win execute_window_loop(win_collapsed, [&](const Coordinates &) { - const auto in_ptr = reinterpret_cast<const uint8_t *>(in.ptr()); - const auto out_ptr = reinterpret_cast<T *>(out.ptr()); + const auto in_ptr = reinterpret_cast<const TIn *>(in.ptr()); + const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()); int x = window_start_x; for(; x <= (window_end_x - window_step_x); x += window_step_x) @@ -146,14 +146,14 @@ void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Win const auto vin = wrapper::vloadq(in_ptr + x); const auto vdeq = vdequantize(vin, scale, offset); - store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq); + store_result(reinterpret_cast<TOut *>(out_ptr + x), vdeq); } // Compute left-over elements for(; x < window_end_x; ++x) { - uint8_t val = *(in_ptr + x); - *(out_ptr + x) = static_cast<T>(dequantize(val, scale, offset)); + auto val = *(in_ptr + x); + *(out_ptr + x) = static_cast<TOut>(Qasymm8QuantizationHelper<TIn>::dequantize(val, qinfo)); } }, in, out); @@ -340,7 +340,10 @@ void run_dequantization_core(const ITensor *input, ITensor *output, const Window switch(input->info()->data_type()) { case DataType::QASYMM8: - run_dequantization_qasymm8<T>(input, output, window); + run_dequantization_qasymm8<T, uint8_t>(input, output, window); + break; + case DataType::QASYMM8_SIGNED: + run_dequantization_qasymm8<T, int8_t>(input, output, window); break; case DataType::QSYMM8_PER_CHANNEL: input->info()->data_layout() == DataLayout::NHWC ? run_dequantization_qsymm8_per_channel_nhwc<T>(input, output, window) : run_dequantization_qsymm8_per_channel_nchw<T>(input, output, window); |