diff options
Diffstat (limited to 'src/cpu/kernels/softmax/generic/sve/impl.cpp')
-rw-r--r-- | src/cpu/kernels/softmax/generic/sve/impl.cpp | 211 |
1 files changed, 114 insertions, 97 deletions
diff --git a/src/cpu/kernels/softmax/generic/sve/impl.cpp b/src/cpu/kernels/softmax/generic/sve/impl.cpp index 2340a31cbd..24f1bb8143 100644 --- a/src/cpu/kernels/softmax/generic/sve/impl.cpp +++ b/src/cpu/kernels/softmax/generic/sve/impl.cpp @@ -23,6 +23,7 @@ */ #include "src/cpu/kernels/softmax/generic/sve/impl.h" + #include "src/core/NEON/wrapper/intrinsics/intrinsics.h" namespace arm_compute @@ -36,42 +37,48 @@ void sve_logits_1d_max(const ITensor *in, ITensor *out, const Window &window) const auto window_start_x = static_cast<int>(window.x().start()); const auto window_end_x = static_cast<int>(window.x().end()); - Window win{ window }; + Window win{window}; win.set(Window::DimX, Window::Dimension(0, 1, 1)); Iterator input(in, win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - // Get pointers - const auto in_ptr = reinterpret_cast<const ScalarType *>(input.ptr()); - const auto out_ptr = reinterpret_cast<ScalarType *>(output.ptr()); + execute_window_loop( + win, + [&](const Coordinates &) + { + // Get pointers + const auto in_ptr = reinterpret_cast<const ScalarType *>(input.ptr()); + const auto out_ptr = reinterpret_cast<ScalarType *>(output.ptr()); - // Init max value - auto vec_max = wrapper::svdup_n(support::cpp11::lowest<ScalarType>()); + // Init max value + auto vec_max = wrapper::svdup_n(support::cpp11::lowest<ScalarType>()); - int x = window_start_x; - svbool_t pg = wrapper::svwhilelt<ScalarType>(x, window_end_x); - do - { - const auto current_value = svld1(pg, in_ptr + x); - vec_max = svmax_m(pg, vec_max, current_value); + int x = window_start_x; + svbool_t pg = wrapper::svwhilelt<ScalarType>(x, window_end_x); + do + { + const auto current_value = svld1(pg, in_ptr + x); + vec_max = svmax_m(pg, vec_max, current_value); - x += wrapper::svcnt<ScalarType>(); - pg = wrapper::svwhilelt<ScalarType>(x, window_end_x); - } - while(svptest_any(all_true_pg, pg)); + x += wrapper::svcnt<ScalarType>(); + pg = wrapper::svwhilelt<ScalarType>(x, window_end_x); + } while (svptest_any(all_true_pg, pg)); - auto max_val = svmaxv(all_true_pg, vec_max); + auto max_val = svmaxv(all_true_pg, vec_max); - *out_ptr = max_val; - }, - input, output); + *out_ptr = max_val; + }, + input, output); } template <typename ScalarType> -void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *const tmp, - ITensor *out, const float beta, bool is_log, const Window &window) +void sve_softmax_logits_1d_float(const ITensor *in, + const ITensor *max, + void *const tmp, + ITensor *out, + const float beta, + bool is_log, + const Window &window) { const int start_x = in->info()->valid_region().anchor.x(); const int input_width = in->info()->valid_region().shape.x(); @@ -82,88 +89,88 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co const auto all_true_pg = wrapper::svptrue<ScalarType>(); - execute_window_loop(window, [&](const Coordinates &) - { - /* Get pointers */ - const auto in_ptr = reinterpret_cast<const ScalarType *>(in_it.ptr()) + start_x; - const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr()) + start_x; - const auto tmp_ptr = reinterpret_cast<ScalarType *>(tmp); - - ScalarType sum{ 0 }; - - /* Compute exponentials and sum */ + execute_window_loop( + window, + [&](const Coordinates &) { - /* Get max value */ - const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr()); - const auto vec_max = wrapper::svdup_n(max_val); - const auto vec_beta = wrapper::svdup_n(static_cast<ScalarType>(beta)); + /* Get pointers */ + const auto in_ptr = reinterpret_cast<const ScalarType *>(in_it.ptr()) + start_x; + const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr()) + start_x; + const auto tmp_ptr = reinterpret_cast<ScalarType *>(tmp); - /* Init sum to zero */ - auto vec_sum = wrapper::svdup_n(static_cast<ScalarType>(0)); + ScalarType sum{0}; - /* Loop over row and compute exponentials and sum */ - int x = 0; - svbool_t pg = wrapper::svwhilelt<ScalarType>(x, input_width); - do + /* Compute exponentials and sum */ { - auto vec_elements = svld1(pg, in_ptr + x); - vec_elements = svmul_z(pg, svsub_z(pg, vec_elements, vec_max), vec_beta); - if(!is_log) + /* Get max value */ + const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr()); + const auto vec_max = wrapper::svdup_n(max_val); + const auto vec_beta = wrapper::svdup_n(static_cast<ScalarType>(beta)); + + /* Init sum to zero */ + auto vec_sum = wrapper::svdup_n(static_cast<ScalarType>(0)); + + /* Loop over row and compute exponentials and sum */ + int x = 0; + svbool_t pg = wrapper::svwhilelt<ScalarType>(x, input_width); + do { - vec_elements = wrapper::svexp_z(pg, vec_elements); - vec_sum = svadd_m(pg, vec_sum, vec_elements); + auto vec_elements = svld1(pg, in_ptr + x); + vec_elements = svmul_z(pg, svsub_z(pg, vec_elements, vec_max), vec_beta); + if (!is_log) + { + vec_elements = wrapper::svexp_z(pg, vec_elements); + vec_sum = svadd_m(pg, vec_sum, vec_elements); + } + svst1(pg, tmp_ptr + x, vec_elements); + + if (is_log) + { + vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements)); + } + + x += wrapper::svcnt<ScalarType>(); + pg = wrapper::svwhilelt<ScalarType>(x, input_width); + } while (svptest_any(all_true_pg, pg)); + + /* Reduce sum */ + sum = svaddv(all_true_pg, vec_sum); + + if (is_log) + { + sum = static_cast<ScalarType>(std::log(sum)); } - svst1(pg, tmp_ptr + x, vec_elements); - - if(is_log) + else { - vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements)); + sum = ScalarType(1) / sum; } - - x += wrapper::svcnt<ScalarType>(); - pg = wrapper::svwhilelt<ScalarType>(x, input_width); } - while(svptest_any(all_true_pg, pg)); - /* Reduce sum */ - sum = svaddv(all_true_pg, vec_sum); - - if(is_log) - { - sum = static_cast<ScalarType>(std::log(sum)); - } - else - { - sum = ScalarType(1) / sum; - } - } - - /* Normalize exponentials */ - { - /* Loop over row and compute softmax */ - int x = 0; - svbool_t pg = wrapper::svwhilelt<ScalarType>(x, input_width); - do + /* Normalize exponentials */ { - auto vec_in = svld1(pg, tmp_ptr + x); - auto normalized_value = wrapper::svdup_n(static_cast<ScalarType>(0)); - if(is_log) - { - normalized_value = svsub_z(pg, vec_in, wrapper::svdup_n(static_cast<ScalarType>(sum))); - } - else + /* Loop over row and compute softmax */ + int x = 0; + svbool_t pg = wrapper::svwhilelt<ScalarType>(x, input_width); + do { - normalized_value = svmul_z(pg, vec_in, wrapper::svdup_n(static_cast<ScalarType>(sum))); - } - svst1(pg, out_ptr + x, normalized_value); - - x += wrapper::svcnt<ScalarType>(); - pg = wrapper::svwhilelt<ScalarType>(x, input_width); + auto vec_in = svld1(pg, tmp_ptr + x); + auto normalized_value = wrapper::svdup_n(static_cast<ScalarType>(0)); + if (is_log) + { + normalized_value = svsub_z(pg, vec_in, wrapper::svdup_n(static_cast<ScalarType>(sum))); + } + else + { + normalized_value = svmul_z(pg, vec_in, wrapper::svdup_n(static_cast<ScalarType>(sum))); + } + svst1(pg, out_ptr + x, normalized_value); + + x += wrapper::svcnt<ScalarType>(); + pg = wrapper::svwhilelt<ScalarType>(x, input_width); + } while (svptest_any(all_true_pg, pg)); } - while(svptest_any(all_true_pg, pg)); - } - }, - in_it, max_it, out_it); + }, + in_it, max_it, out_it); } template void sve_logits_1d_max<float>(const ITensor *in, ITensor *out, const Window &window); @@ -171,9 +178,19 @@ template void sve_logits_1d_max<float16_t>(const ITensor *in, ITensor *out, cons template void sve_logits_1d_max<qasymm8_t>(const ITensor *in, ITensor *out, const Window &window); template void sve_logits_1d_max<qasymm8_signed_t>(const ITensor *in, ITensor *out, const Window &window); -template void sve_softmax_logits_1d_float<float>(const ITensor *in, const ITensor *max, void *const tmp, - ITensor *out, const float beta, bool is_log, const Window &window); -template void sve_softmax_logits_1d_float<float16_t>(const ITensor *in, const ITensor *max, void *const tmp, - ITensor *out, const float beta, bool is_log, const Window &window); +template void sve_softmax_logits_1d_float<float>(const ITensor *in, + const ITensor *max, + void *const tmp, + ITensor *out, + const float beta, + bool is_log, + const Window &window); +template void sve_softmax_logits_1d_float<float16_t>(const ITensor *in, + const ITensor *max, + void *const tmp, + ITensor *out, + const float beta, + bool is_log, + const Window &window); } // namespace cpu } // namespace arm_compute |