diff options
Diffstat (limited to 'src/cpu/kernels/softmax/generic/neon/impl.cpp')
-rw-r--r-- | src/cpu/kernels/softmax/generic/neon/impl.cpp | 152 |
1 files changed, 96 insertions, 56 deletions
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.cpp b/src/cpu/kernels/softmax/generic/neon/impl.cpp index 5d6e6a4f80..487f6ae051 100644 --- a/src/cpu/kernels/softmax/generic/neon/impl.cpp +++ b/src/cpu/kernels/softmax/generic/neon/impl.cpp @@ -29,43 +29,76 @@ namespace arm_compute { namespace cpu { -template void neon_logits_1d_max<qasymm8_signed_t>(const ITensor *in, ITensor *out, const Window &window); -template void neon_logits_1d_max<qasymm8_t>(const ITensor *in, ITensor *out, const Window &window); - -template <typename T> -void neon_softmax_logits_1d_quantized( - const ITensor *in, const ITensor *max, void *const tmp, ITensor *out, float beta, bool is_log, const Window &window) +template <typename T, bool IS_LOG> +void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window) { static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value, "quantized type should be either qasymm8_t or qasymm8_signed_t."); - const int start_x = in->info()->valid_region().anchor.x(); const int input_width = in->info()->valid_region().shape.x(); - const float scale_beta = -beta * in->info()->quantization_info().uniform().scale; - const auto scale_beta_vec = vdupq_n_f32(scale_beta); + const float scale_beta = -beta * in->info()->quantization_info().uniform().scale; + const float32x4_t scale_beta_vec = vdupq_n_f32(scale_beta); + + Iterator in_it(in, window); + Iterator out_it(out, window); - Iterator in_it(in, window); - Iterator max_it(max, window); - Iterator out_it(out, window); constexpr int vec_size = 16; +#ifndef __aarch64__ + const int sum_stages = log2(vec_size >> 1); +#endif // __aarch64__ + + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>; + execute_window_loop( window, [&](const Coordinates &) { /* Get pointers */ - const auto in_ptr = reinterpret_cast<const T *>(in_it.ptr()) + start_x; - const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x; - const auto tmp_ptr = reinterpret_cast<float *>(tmp); + const T *in_ptr = reinterpret_cast<const T *>(in_it.ptr()); + T *out_ptr = reinterpret_cast<T *>(out_it.ptr()); + float *tmp_ptr = reinterpret_cast<float *>(tmp); + + T max_val; + + /* Compute Max */ + { + // Init max value + auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{}); + int x = 0; - float sum{}; - float sum_inversed{}; + for (; x <= (input_width - vec_size); x += vec_size) + { + const auto current_value = wrapper::vloadq(in_ptr + x); + vec_max = wrapper::vmax(vec_max, current_value); + } + +#ifdef __aarch64__ + max_val = wrapper::vmaxv(vec_max); +#else // __aarch64__ + auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max)); + + for (int i = 0; i < sum_stages; ++i) + { + carry_max = wrapper::vpmax(carry_max, carry_max); + } + + max_val = wrapper::vgetlane(carry_max, 0); +#endif // __aarch64__ + + // Compute left-over elements + for (; x < input_width; ++x) + { + max_val = std::max(*(in_ptr + x), max_val); + } + } // Compute Max + + float sum_transformed{}; /* Compute exponentials and sum */ { /* Get max value */ - const auto max_val = *reinterpret_cast<const T *>(max_it.ptr()); const auto vec_max = wrapper::vdup_n(max_val, wrapper::traits::vector_128_tag{}); /* Init sum to zero */ @@ -80,11 +113,11 @@ void neon_softmax_logits_1d_quantized( int x = 0; for (; x <= (input_width - vec_size); x += vec_size) { - auto vec_elements = wrapper::vloadq(in_ptr + x); - vec_elements = wrapper::vqsub(vec_max, vec_elements); - auto vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements); + auto vec_elements = wrapper::vloadq(in_ptr + x); + vec_elements = wrapper::vqsub(vec_max, vec_elements); + float32x4x4_t vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements); - if (is_log) + if (IS_LOG) { vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec); vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec); @@ -111,17 +144,24 @@ void neon_softmax_logits_1d_quantized( } /* Reduce sum */ - const auto sum_16_byte = + const float32x4_t sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3])); + + float sum; + +#ifdef __aarch64__ + sum = wrapper::vaddv(sum_16_byte); +#else // __aarch64__ auto sum_res = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte)); sum_res = vpadd_f32(sum_res, sum_res); sum = wrapper::vgetlane(sum_res, 0); +#endif // __aarch64__ /* Run remaining elements */ for (; x < input_width; ++x) { float element{}; - if (is_log) + if (IS_LOG) { element = (max_val - in_ptr[x]) * scale_beta; sum += std::exp(element); @@ -135,19 +175,22 @@ void neon_softmax_logits_1d_quantized( tmp_ptr[x] = element; } - if (!is_log) + if (!IS_LOG) { - sum_inversed = 256.f / sum; + sum_transformed = 256.f / sum; } else { - sum = std::log(sum); + sum_transformed = std::log(sum); } - } + } // Compute exponentials and sum /* Normalize exponentials */ { constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value; + + const float32x4_t sum_vec = vdupq_n_f32(sum_transformed); + /* Loop over row and compute softmax */ int x = 0; for (; x <= (input_width - vec_size); x += vec_size) @@ -155,23 +198,23 @@ void neon_softmax_logits_1d_quantized( using int_vec_type = wrapper::traits::neon_vector_t<T, 16>; float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x); int_vec_type normalized_value{}; - if (is_log) + if (IS_LOG) { const float32x4x4_t sub = { - vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)), - vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)), - vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)), - vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)), + vsubq_f32(vec_in.val[0], sum_vec), + vsubq_f32(vec_in.val[1], sum_vec), + vsubq_f32(vec_in.val[2], sum_vec), + vsubq_f32(vec_in.val[3], sum_vec), }; normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub); } else { float32x4x4_t mul = { - vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)), - vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)), - vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)), - vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)), + vmulq_f32(vec_in.val[0], sum_vec), + vmulq_f32(vec_in.val[1], sum_vec), + vmulq_f32(vec_in.val[2], sum_vec), + vmulq_f32(vec_in.val[3], sum_vec), }; if (is_qasymm8_signed) @@ -190,34 +233,31 @@ void neon_softmax_logits_1d_quantized( /* Run remaining elements */ for (; x < input_width; ++x) { - if (is_log) + if (IS_LOG) { - out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum); + out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum_transformed); } else { - out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_inversed) - + out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_transformed) - (is_qasymm8_signed ? 128.f : 0)); } } - } + } // Normalize exponentials }, - in_it, max_it, out_it); + in_it, out_it); } -template void neon_softmax_logits_1d_quantized<qasymm8_signed_t>(const ITensor *in, - const ITensor *max, - void *const tmp, - ITensor *out, - float beta, - bool is_log, - const Window &window); -template void neon_softmax_logits_1d_quantized<qasymm8_t>(const ITensor *in, - const ITensor *max, - void *const tmp, - ITensor *out, - float beta, - bool is_log, - const Window &window); +template void neon_softmax_quantized<qasymm8_signed_t, true>( + const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window); + +template void neon_softmax_quantized<qasymm8_signed_t, false>( + const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window); + +template void neon_softmax_quantized<qasymm8_t, true>( + const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window); + +template void neon_softmax_quantized<qasymm8_t, false>( + const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window); } // namespace cpu } // namespace arm_compute |