diff options
Diffstat (limited to 'src/cpu/kernels/elementwise_unary/generic/neon/impl.h')
-rw-r--r-- | src/cpu/kernels/elementwise_unary/generic/neon/impl.h | 244 |
1 files changed, 128 insertions, 116 deletions
diff --git a/src/cpu/kernels/elementwise_unary/generic/neon/impl.h b/src/cpu/kernels/elementwise_unary/generic/neon/impl.h index dbc1dde4fa..d54d3984cb 100644 --- a/src/cpu/kernels/elementwise_unary/generic/neon/impl.h +++ b/src/cpu/kernels/elementwise_unary/generic/neon/impl.h @@ -26,6 +26,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/Types.h" + #include "src/core/NEON/NEAsymm.h" #include "src/core/NEON/wrapper/intrinsics/intrinsics.h" @@ -36,7 +37,7 @@ namespace cpu template <typename ScalarType> inline ScalarType elementwise_op_scalar_imp(ElementWiseUnary op, const ScalarType &a) { - switch(op) + switch (op) { case ElementWiseUnary::RSQRT: return 1 / sqrt(a); @@ -60,7 +61,7 @@ inline ScalarType elementwise_op_scalar_imp(ElementWiseUnary op, const ScalarTyp template <typename ScalarType, typename VectorType> inline VectorType elementwise_op_imp(ElementWiseUnary op, const VectorType &a) { - switch(op) + switch (op) { case ElementWiseUnary::RSQRT: return wrapper::vinvsqrt(a); @@ -94,22 +95,24 @@ inline void elementwise_op(const ITensor *in, ITensor *out, const Window &window Iterator input(in, win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr()); - const auto input_ptr = reinterpret_cast<const ScalarType *>(input.ptr()); - - int x = window_start_x; - for(; x <= window_end_x - window_step_x; x += window_step_x) + execute_window_loop( + win, + [&](const Coordinates &) { - wrapper::vstore(output_ptr + x, elementwise_op_imp<ScalarType>(op, wrapper::vloadq(input_ptr + x))); - } - for(; x < window_end_x; ++x) - { - *(output_ptr + x) = elementwise_op_scalar_imp(op, *(input_ptr + x)); - } - }, - input, output); + auto output_ptr = reinterpret_cast<ScalarType *>(output.ptr()); + const auto input_ptr = reinterpret_cast<const ScalarType *>(input.ptr()); + + int x = window_start_x; + for (; x <= window_end_x - window_step_x; x += window_step_x) + { + wrapper::vstore(output_ptr + x, elementwise_op_imp<ScalarType>(op, wrapper::vloadq(input_ptr + x))); + } + for (; x < window_end_x; ++x) + { + *(output_ptr + x) = elementwise_op_scalar_imp(op, *(input_ptr + x)); + } + }, + input, output); } template <> @@ -128,75 +131,81 @@ inline void elementwise_op<int8_t>(const ITensor *in, ITensor *out, const Window Iterator input(in, win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - int8x16_t vout; - auto output_ptr = reinterpret_cast<int8_t *>(output.ptr()); - const auto input_ptr = reinterpret_cast<const int8_t *>(input.ptr()); - const auto vconst_0_f32 = vdupq_n_f32(0); - auto clamped_value = (op == ElementWiseUnary::LOG) ? min_clamped_value : max_clamped_value; - - int x = window_start_x; - for(; x <= window_end_x - window_step_x; x += window_step_x) + execute_window_loop( + win, + [&](const Coordinates &) { - const auto vin = wrapper::vloadq(input_ptr + x); - - // De-quantize - const auto vin_deq = vdequantize(vin, qi_in); + int8x16_t vout; + auto output_ptr = reinterpret_cast<int8_t *>(output.ptr()); + const auto input_ptr = reinterpret_cast<const int8_t *>(input.ptr()); + const auto vconst_0_f32 = vdupq_n_f32(0); + auto clamped_value = (op == ElementWiseUnary::LOG) ? min_clamped_value : max_clamped_value; - // Perform activation - float32x4x4_t vtmp_deq = + int x = window_start_x; + for (; x <= window_end_x - window_step_x; x += window_step_x) { - { + const auto vin = wrapper::vloadq(input_ptr + x); + + // De-quantize + const auto vin_deq = vdequantize(vin, qi_in); + + // Perform activation + float32x4x4_t vtmp_deq = {{ elementwise_op_imp<float>(op, vin_deq.val[0]), elementwise_op_imp<float>(op, vin_deq.val[1]), elementwise_op_imp<float>(op, vin_deq.val[2]), elementwise_op_imp<float>(op, vin_deq.val[3]), + }}; + + if ((op == ElementWiseUnary::LOG) || (op == ElementWiseUnary::RSQRT)) + { + vtmp_deq.val[0] = + vbslq_f32(vcleq_f32(vin_deq.val[0], vconst_0_f32), clamped_value, vtmp_deq.val[0]); + vtmp_deq.val[1] = + vbslq_f32(vcleq_f32(vin_deq.val[1], vconst_0_f32), clamped_value, vtmp_deq.val[1]); + vtmp_deq.val[2] = + vbslq_f32(vcleq_f32(vin_deq.val[2], vconst_0_f32), clamped_value, vtmp_deq.val[2]); + vtmp_deq.val[3] = + vbslq_f32(vcleq_f32(vin_deq.val[3], vconst_0_f32), clamped_value, vtmp_deq.val[3]); } - }; - if((op == ElementWiseUnary::LOG) || (op == ElementWiseUnary::RSQRT)) - { - vtmp_deq.val[0] = vbslq_f32(vcleq_f32(vin_deq.val[0], vconst_0_f32), clamped_value, vtmp_deq.val[0]); - vtmp_deq.val[1] = vbslq_f32(vcleq_f32(vin_deq.val[1], vconst_0_f32), clamped_value, vtmp_deq.val[1]); - vtmp_deq.val[2] = vbslq_f32(vcleq_f32(vin_deq.val[2], vconst_0_f32), clamped_value, vtmp_deq.val[2]); - vtmp_deq.val[3] = vbslq_f32(vcleq_f32(vin_deq.val[3], vconst_0_f32), clamped_value, vtmp_deq.val[3]); + // Re-quantize to new output space + vout = vquantize_signed(vtmp_deq, qi_out); + wrapper::vstore(output_ptr + x, vout); } - - // Re-quantize to new output space - vout = vquantize_signed(vtmp_deq, qi_out); - wrapper::vstore(output_ptr + x, vout); - } - for(; x < window_end_x; ++x) - { - qasymm8_signed_t in = *(reinterpret_cast<const qasymm8_signed_t *>(input_ptr + x)); - qasymm8_signed_t tmp = 0; - float tmp_f = dequantize_qasymm8_signed(in, qi_in); - if(tmp_f <= 0.0) + for (; x < window_end_x; ++x) { - if(op == ElementWiseUnary::LOG) - { - tmp_f = (-128 - qi_out.offset) * qi_out.scale; - } - else if(op == ElementWiseUnary::RSQRT) + qasymm8_signed_t in = *(reinterpret_cast<const qasymm8_signed_t *>(input_ptr + x)); + qasymm8_signed_t tmp = 0; + float tmp_f = dequantize_qasymm8_signed(in, qi_in); + if (tmp_f <= 0.0) { - tmp_f = (127 - qi_out.offset) * qi_out.scale; + if (op == ElementWiseUnary::LOG) + { + tmp_f = (-128 - qi_out.offset) * qi_out.scale; + } + else if (op == ElementWiseUnary::RSQRT) + { + tmp_f = (127 - qi_out.offset) * qi_out.scale; + } + else + { + tmp_f = elementwise_op_scalar_imp<float>(op, tmp_f); + } } else { tmp_f = elementwise_op_scalar_imp<float>(op, tmp_f); } + tmp = quantize_qasymm8_signed( + tmp_f, qi_out, + RoundingPolicy:: + TO_ZERO); // Set rounding policy TO_ZERO to be compatible with vquantize_signed() used above that follow same policy for armv7a. + // For aarch64 LUT is used and rounding to nearest is used + *(output_ptr + x) = tmp; } - else - { - tmp_f = elementwise_op_scalar_imp<float>(op, tmp_f); - } - tmp = quantize_qasymm8_signed(tmp_f, qi_out, RoundingPolicy::TO_ZERO); // Set rounding policy TO_ZERO to be compatible with vquantize_signed() used above that follow same policy for armv7a. - // For aarch64 LUT is used and rounding to nearest is used - *(output_ptr + x) = tmp; - } - }, - input, output); + }, + input, output); } template <> inline void elementwise_op<uint8_t>(const ITensor *in, ITensor *out, const Window &window, ElementWiseUnary op) @@ -215,71 +224,74 @@ inline void elementwise_op<uint8_t>(const ITensor *in, ITensor *out, const Windo Iterator input(in, win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - uint8x16_t vout; - auto clamped_value = (op == ElementWiseUnary::LOG) ? min_clamped_value : max_clamped_value; - auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr()); - const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr()); - int x = window_start_x; - for(; x <= window_end_x - window_step_x; x += window_step_x) + execute_window_loop( + win, + [&](const Coordinates &) { - const auto vin = wrapper::vloadq(input_ptr + x); + uint8x16_t vout; + auto clamped_value = (op == ElementWiseUnary::LOG) ? min_clamped_value : max_clamped_value; + auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr()); + const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr()); + int x = window_start_x; + for (; x <= window_end_x - window_step_x; x += window_step_x) + { + const auto vin = wrapper::vloadq(input_ptr + x); - // De-quantize - const auto vin_deq = vdequantize(vin, qi_in); + // De-quantize + const auto vin_deq = vdequantize(vin, qi_in); - // Perform activation - float32x4x4_t vtmp_deq = - { - { + // Perform activation + float32x4x4_t vtmp_deq = {{ elementwise_op_imp<float>(op, vin_deq.val[0]), elementwise_op_imp<float>(op, vin_deq.val[1]), elementwise_op_imp<float>(op, vin_deq.val[2]), elementwise_op_imp<float>(op, vin_deq.val[3]), + }}; + if ((op == ElementWiseUnary::LOG) || (op == ElementWiseUnary::RSQRT)) + { + vtmp_deq.val[0] = + vbslq_f32(vcleq_f32(vin_deq.val[0], vconst_0_f32), clamped_value, vtmp_deq.val[0]); + vtmp_deq.val[1] = + vbslq_f32(vcleq_f32(vin_deq.val[1], vconst_0_f32), clamped_value, vtmp_deq.val[1]); + vtmp_deq.val[2] = + vbslq_f32(vcleq_f32(vin_deq.val[2], vconst_0_f32), clamped_value, vtmp_deq.val[2]); + vtmp_deq.val[3] = + vbslq_f32(vcleq_f32(vin_deq.val[3], vconst_0_f32), clamped_value, vtmp_deq.val[3]); } - }; - if((op == ElementWiseUnary::LOG) || (op == ElementWiseUnary::RSQRT)) - { - vtmp_deq.val[0] = vbslq_f32(vcleq_f32(vin_deq.val[0], vconst_0_f32), clamped_value, vtmp_deq.val[0]); - vtmp_deq.val[1] = vbslq_f32(vcleq_f32(vin_deq.val[1], vconst_0_f32), clamped_value, vtmp_deq.val[1]); - vtmp_deq.val[2] = vbslq_f32(vcleq_f32(vin_deq.val[2], vconst_0_f32), clamped_value, vtmp_deq.val[2]); - vtmp_deq.val[3] = vbslq_f32(vcleq_f32(vin_deq.val[3], vconst_0_f32), clamped_value, vtmp_deq.val[3]); - } - // Re-quantize to new output space - vout = vquantize(vtmp_deq, qi_out); - wrapper::vstore(output_ptr + x, vout); - } - for(; x < window_end_x; ++x) - { - qasymm8_t in = *(reinterpret_cast<const qasymm8_t *>(input_ptr + x)); - qasymm8_t tmp = 0; - float tmp_f = dequantize_qasymm8(in, qi_in); - if(tmp_f <= 0.0) + // Re-quantize to new output space + vout = vquantize(vtmp_deq, qi_out); + wrapper::vstore(output_ptr + x, vout); + } + for (; x < window_end_x; ++x) { - if(op == ElementWiseUnary::LOG) + qasymm8_t in = *(reinterpret_cast<const qasymm8_t *>(input_ptr + x)); + qasymm8_t tmp = 0; + float tmp_f = dequantize_qasymm8(in, qi_in); + if (tmp_f <= 0.0) { - tmp_f = (0 - qi_out.offset) * qi_out.scale; - } - else if(op == ElementWiseUnary::RSQRT) - { - tmp_f = (255 - qi_out.offset) * qi_out.scale; + if (op == ElementWiseUnary::LOG) + { + tmp_f = (0 - qi_out.offset) * qi_out.scale; + } + else if (op == ElementWiseUnary::RSQRT) + { + tmp_f = (255 - qi_out.offset) * qi_out.scale; + } + else + { + tmp_f = elementwise_op_scalar_imp<float>(op, tmp_f); + } } else { tmp_f = elementwise_op_scalar_imp<float>(op, tmp_f); } + tmp = quantize_qasymm8(tmp_f, qi_out, RoundingPolicy::TO_ZERO); + *(output_ptr + x) = tmp; } - else - { - tmp_f = elementwise_op_scalar_imp<float>(op, tmp_f); - } - tmp = quantize_qasymm8(tmp_f, qi_out, RoundingPolicy::TO_ZERO); - *(output_ptr + x) = tmp; - } - }, - input, output); + }, + input, output); } } // namespace cpu |