Diffstat (limited to 'src/core/NEON/kernels/NEQuantizationLayerKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEQuantizationLayerKernel.cpp | 100
1 file changed, 78 insertions(+), 22 deletions(-)
diff --git a/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp b/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
index 0aa34cd411..6a9c4ae14c 100644
--- a/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
@@ -34,9 +34,10 @@
#include "arm_compute/core/CPP/Validate.h"
#include <arm_neon.h>
+#include <map>
-using namespace arm_compute;
-
+namespace arm_compute
+{
namespace
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
@@ -45,7 +46,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape().total_size() == 0);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM16);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
return Status{};
@@ -71,7 +72,7 @@ inline const float32x4x4_t load_value(const float16_t *input_ptr)
} // namespace
NEQuantizationLayerKernel::NEQuantizationLayerKernel()
- : _input(nullptr), _output(nullptr)
+ : _input(nullptr), _output(nullptr), _func(nullptr)
{
}
@@ -83,6 +84,33 @@ void NEQuantizationLayerKernel::configure(const ITensor *input, ITensor *output)
_input = input;
_output = output;
+ static std::map<DataType, QuantizationFunctionExecutorPtr> quant_map_f32 =
+ {
+ { DataType::QASYMM8, &NEQuantizationLayerKernel::run_quantize_qasymm8<float> },
+ { DataType::QASYMM16, &NEQuantizationLayerKernel::run_quantize_qasymm16<float> },
+ };
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ static std::map<DataType, QuantizationFunctionExecutorPtr> quant_map_f16 =
+ {
+ { DataType::QASYMM8, &NEQuantizationLayerKernel::run_quantize_qasymm8<float16_t> },
+ { DataType::QASYMM16, &NEQuantizationLayerKernel::run_quantize_qasymm16<float16_t> },
+ };
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC*/
+
+ switch(input->info()->data_type())
+ {
+ case DataType::F32:
+ _func = quant_map_f32[output->info()->data_type()];
+ break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ _func = quant_map_f16[output->info()->data_type()];
+ break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ default:
+ ARM_COMPUTE_ERROR("Unsupported input data type.");
+ }
+
// Configure kernel window
Window win_config = calculate_max_window(*input->info(), Steps());
@@ -96,18 +124,17 @@ void NEQuantizationLayerKernel::configure(const ITensor *input, ITensor *output)
Status NEQuantizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output));
-
return Status{};
}
template <typename T>
-void NEQuantizationLayerKernel::quantize(const Window &window, const QuantizationInfo &qinfo)
+void NEQuantizationLayerKernel::run_quantize_qasymm8(const Window &window)
{
constexpr auto window_step = 16;
const auto window_start_x = static_cast<int>(window.x().start());
const auto window_end_x = static_cast<int>(window.x().end());
- const UniformQuantizationInfo uqinfo = qinfo.uniform();
+ const UniformQuantizationInfo uqinfo = _output->info()->quantization_info().uniform();
#ifdef __aarch64__
constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN;
#else //__aarch64__
@@ -139,25 +166,54 @@ void NEQuantizationLayerKernel::quantize(const Window &window, const Quantizatio
input, output);
}
+template <typename T>
+void NEQuantizationLayerKernel::run_quantize_qasymm16(const Window &window)
+{
+ constexpr auto window_step = 16;
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ const UniformQuantizationInfo uqinfo = _output->info()->quantization_info().uniform();
+#ifdef __aarch64__
+ constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN;
+#else //__aarch64__
+ constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO;
+#endif //__aarch64__
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator input(_input, win_collapsed);
+ Iterator output(_output, win_collapsed);
+ execute_window_loop(win_collapsed, [&](const Coordinates &)
+ {
+ auto input_ptr = reinterpret_cast<const T *>(input.ptr());
+ auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step); x += window_step)
+ {
+ uint16x8x2_t tmp = vquantize_qasymm16(load_value(&input_ptr[x]), uqinfo);
+ vst1q_u16(&output_ptr[x], tmp.val[0]);
+ vst1q_u16(&output_ptr[x + 8], tmp.val[1]);
+ }
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ output_ptr[x] = quantize_qasymm16(input_ptr[x], uqinfo, rounding_policy);
+ }
+ },
+ input, output);
+}
+
void NEQuantizationLayerKernel::run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+ ARM_COMPUTE_ERROR_ON(_func == nullptr);
- const QuantizationInfo &qinfo = _output->info()->quantization_info();
-
- switch(_input->info()->data_type())
- {
- case DataType::F32:
- NEQuantizationLayerKernel::quantize<float>(window, qinfo);
- break;
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- NEQuantizationLayerKernel::quantize<float16_t>(window, qinfo);
- break;
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- default:
- ARM_COMPUTE_ERROR("Unsupported data type.");
- }
+ (this->*_func)(window);
}
+} // namespace arm_compute
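
For context on the dispatch change: run() no longer switches on the input data type per call; configure() resolves a pointer-to-member-function once, and run() invokes it via (this->*_func)(window). The standalone sketch below is not part of the patch (the class, enum, and function names are invented for illustration); it only shows the same pattern in isolation.

#include <cstdio>
#include <map>

enum class OutType { QASYMM8, QASYMM16 };

class QuantizeKernelSketch
{
public:
    // Same shape as QuantizationFunctionExecutorPtr in the patch: a pointer to
    // a non-static member function, selected once at configure time.
    using ExecutorPtr = void (QuantizeKernelSketch::*)(int count);

    void configure(OutType out_type)
    {
        static const std::map<OutType, ExecutorPtr> dispatch =
        {
            { OutType::QASYMM8,  &QuantizeKernelSketch::run_u8  },
            { OutType::QASYMM16, &QuantizeKernelSketch::run_u16 },
        };
        _func = dispatch.at(out_type);
    }

    void run(int count)
    {
        (this->*_func)(count); // mirrors (this->*_func)(window) in the patch
    }

private:
    void run_u8(int count)  { std::printf("quantize %d values to 8-bit\n", count); }
    void run_u16(int count) { std::printf("quantize %d values to 16-bit\n", count); }

    ExecutorPtr _func = nullptr;
};

int main()
{
    QuantizeKernelSketch kernel;
    kernel.configure(OutType::QASYMM16);
    kernel.run(16); // prints: quantize 16 values to 16-bit
}

The sketch uses at(), which throws on an unknown key; the patch uses operator[], which would leave _func as nullptr for an unsupported output type, and run() guards against that with ARM_COMPUTE_ERROR_ON(_func == nullptr) after validate_arguments() has already rejected such types.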
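
The new QASYMM16 path stores results through the library's vquantize_qasymm16/quantize_qasymm16 helpers. As a rough scalar reference, asymmetric 16-bit quantization maps a float to clamp(round(x / scale) + offset, 0, 65535); the sketch below is an illustrative reimplementation under that assumption, not the library helper, and it rounds half away from zero whereas the kernel selects TO_NEAREST_EVEN on aarch64 and TO_ZERO elsewhere.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative scalar reference (an assumption, not arm_compute's helper):
// q = clamp(round(value / scale) + offset, 0, 65535).
static uint16_t quantize_qasymm16_ref(float value, float scale, int32_t offset)
{
    const int32_t q = static_cast<int32_t>(std::lround(value / scale)) + offset;
    return static_cast<uint16_t>(std::min(65535, std::max(0, q)));
}

int main()
{
    // With scale = 0.5 and offset = 10, the value 3.0f maps to 3.0/0.5 + 10 = 16.
    std::printf("%u\n", quantize_qasymm16_ref(3.0f, 0.5f, 10));
}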