about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
diff options
context:
space:
mode:
author: Manuel Bottini <manuel.bottini@arm.com>, 2019-10-29 17:20:09 +0000
committer: Manuel Bottini <manuel.bottini@arm.com>, 2019-11-25 18:13:09 +0000
commit: 21079dd320c00068208acdfd59177895265a53f2 (patch)
tree: 76a9f889260146a40cb50023925941418c3b4704 /src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
parent: 6d8b94ac6864dfd7ad38bc110006bdca5ee0f266 (diff)
download: ComputeLibrary-21079dd320c00068208acdfd59177895265a53f2.tar.gz
COMPMID-2700: Use NEON wrapper on SoftmaxLayer
Change-Id: Id8901e865c9f355dcf7b2a1a539493099591377e
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2186
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NESoftmaxLayerKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NESoftmaxLayerKernel.cpp 562
1 files changed, 129 insertions, 433 deletions
diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index 1003ebd2e3..a3ecce3a1e 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -30,6 +30,7 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NEMath.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
@@ -43,309 +44,6 @@
namespace arm_compute
{
-template <typename T, int N>
-struct vec_n_type;
-
-#define DECLARE_NEON_VEC_TYPE(T, N, V) \
- template <> \
- struct vec_n_type<T, N> \
- { \
- using type = V; \
- };
-
-DECLARE_NEON_VEC_TYPE(uint8_t, 16, uint8x16_t)
-DECLARE_NEON_VEC_TYPE(uint8_t, 8, uint8x8_t)
-
-DECLARE_NEON_VEC_TYPE(int8_t, 16, int8x16_t)
-DECLARE_NEON_VEC_TYPE(int8_t, 8, int8x8_t)
-
-DECLARE_NEON_VEC_TYPE(uint16_t, 8, uint16x8_t)
-DECLARE_NEON_VEC_TYPE(uint16_t, 4, uint16x4_t)
-
-DECLARE_NEON_VEC_TYPE(int16_t, 8, int16x8_t)
-DECLARE_NEON_VEC_TYPE(int16_t, 4, int16x4_t)
-
-DECLARE_NEON_VEC_TYPE(int32_t, 4, int32x4_t)
-DECLARE_NEON_VEC_TYPE(int32_t, 2, int32x2_t)
-
-DECLARE_NEON_VEC_TYPE(uint32_t, 4, uint32x4_t)
-DECLARE_NEON_VEC_TYPE(uint32_t, 2, uint32x2_t)
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-DECLARE_NEON_VEC_TYPE(float16_t, 8, float16x8_t)
-DECLARE_NEON_VEC_TYPE(float16_t, 4, float16x4_t)
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-DECLARE_NEON_VEC_TYPE(float, 4, float32x4_t)
-DECLARE_NEON_VEC_TYPE(float, 2, float32x2_t)
-
-template <typename T, int N>
-using vec_n_t = typename vec_n_type<T, N>::type;
-
-template <typename T, int N>
-using vec_n_byte_t = vec_n_t < T, N / sizeof(T) >;
-
-template <typename T>
-using vec_16_byte_t = vec_n_byte_t<T, 16>;
-
-template <typename T>
-using vec_8_byte_t = vec_n_byte_t<T, 8>;
-
-template <typename T>
-using const_ptr_t = const T *;
-
-template <typename T>
-using ptr_t = T *;
-
-#define FORWARD_DECLARE_VGET_LANE_FOR_TYPE(TYPE) \
- template <int lane> \
- TYPE vget_lane(vec_8_byte_t<TYPE> vec); \
- template <int lane> \
- TYPE vget_lane(vec_16_byte_t<TYPE> vec);
-
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint8_t)
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int8_t)
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint16_t)
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int16_t)
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint32_t)
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int32_t)
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float16_t)
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float)
-template <int lane>
-float vget_lane(float32x4x4_t vec);
-
-template <typename V>
-using elem_type_t = decltype(vget_lane<0>(std::declval<V>()));
-
-template <typename V>
-constexpr size_t vec_size_of(const V &vec)
-{
- return sizeof(vec) / sizeof(elem_type_t<V>);
-}
-
-template <typename V>
-V vdup_n(elem_type_t<V> val);
-template <typename V>
-V vld(const_ptr_t<elem_type_t<V>> ptr);
-
-#define DECLARE_NEON_FUNCTIONS_FOR_TYPE(TYPE, TAG) \
- template <> \
- inline vec_8_byte_t<TYPE> vdup_n<vec_8_byte_t<TYPE>>(TYPE val) \
- { \
- return vdup_n_##TAG(val); \
- } \
- template <> \
- inline vec_16_byte_t<TYPE> vdup_n<vec_16_byte_t<TYPE>>(TYPE val) \
- { \
- return vdupq_n_##TAG(val); \
- } \
- template <> \
- inline vec_8_byte_t<TYPE> vld<vec_8_byte_t<TYPE>>(const_ptr_t<TYPE> ptr) \
- { \
- return vld1_##TAG(ptr); \
- } \
- template <> \
- inline vec_16_byte_t<TYPE> vld<vec_16_byte_t<TYPE>>(const_ptr_t<TYPE> ptr) \
- { \
- return vld1q_##TAG(ptr); \
- } \
- inline void vst(ptr_t<TYPE> ptr, vec_8_byte_t<TYPE> vec) \
- { \
- vst1_##TAG(ptr, vec); \
- } \
- inline void vst(ptr_t<TYPE> ptr, vec_16_byte_t<TYPE> vec) \
- { \
- vst1q_##TAG(ptr, vec); \
- } \
- inline vec_16_byte_t<TYPE> vmax(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
- { \
- return vmaxq_##TAG(a, b); \
- } \
- inline vec_8_byte_t<TYPE> vpmax(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b) \
- { \
- return vpmax_##TAG(a, b); \
- } \
- inline vec_8_byte_t<TYPE> vget_low(vec_16_byte_t<TYPE> vec) \
- { \
- return vget_low_##TAG(vec); \
- } \
- inline vec_8_byte_t<TYPE> vget_high(vec_16_byte_t<TYPE> vec) \
- { \
- return vget_high_##TAG(vec); \
- } \
- template <int lane> \
- inline TYPE vget_lane(vec_8_byte_t<TYPE> vec) \
- { \
- static_assert(lane >= 0, "lane is out of bounds"); \
- static_assert(lane < vec_size_of(vec), "lane is out of bounds"); \
- return vget_lane_##TAG(vec, lane); \
- } \
- template <int lane> \
- inline TYPE vget_lane(vec_16_byte_t<TYPE> vec) \
- { \
- static_assert(lane >= 0, "lane is out of bounds"); \
- static_assert(lane < vec_size_of(vec), "lane is out of bounds"); \
- return vgetq_lane_##TAG(vec, lane); \
- }
-
-template <typename T>
-T sqadd(T a, T b);
-template <typename T>
-T sqsub(T a, T b);
-template <typename T>
-T sqmul(T a, T b);
-
-#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG) \
- inline vec_8_byte_t<TYPE> vadd(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b) \
- { \
- return vadd_##TAG(a, b); \
- } \
- inline vec_16_byte_t<TYPE> vadd(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
- { \
- return vaddq_##TAG(a, b); \
- } \
- inline vec_16_byte_t<TYPE> vsub(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
- { \
- return vsubq_##TAG(a, b); \
- } \
- inline vec_16_byte_t<TYPE> vmul_n(vec_16_byte_t<TYPE> vec, TYPE val) \
- { \
- return vmulq_n_##TAG(vec, val); \
- }
-
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint8_t, u8)
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(int8_t, s8)
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint16_t, u16)
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(int16_t, s16)
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint32_t, u32)
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(int32_t, s32)
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(float16_t, f16)
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32)
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16)
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float, f32)
-
-template <typename VO, typename VI>
-VO vcvt(VI vec);
-
-template <>
-float32x4x4_t vcvt<float32x4x4_t>(uint8x16_t vec)
-{
- const auto low = vmovl_u8(vget_low(vec));
- const auto high = vmovl_u8(vget_high(vec));
- float32x4x4_t res = { {
- vcvtq_f32_u32(vmovl_u16(vget_low(low))),
- vcvtq_f32_u32(vmovl_u16(vget_high(low))),
- vcvtq_f32_u32(vmovl_u16(vget_low(high))),
- vcvtq_f32_u32(vmovl_u16(vget_high(high)))
- }
- };
- return res;
-}
-
-template <>
-uint8x16_t vcvt<uint8x16_t>(float32x4x4_t vec)
-{
- uint16x8x2_t resU16 = { {
- vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[0])),
- vqmovn_u32(vcvtq_u32_f32(vec.val[1]))),
- vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[2])),
- vqmovn_u32(vcvtq_u32_f32(vec.val[3])))
- }
- };
-
- uint8x16_t res = vcombine_u8(vqmovn_u16(resU16.val[0]), vqmovn_u16(resU16.val[1]));
- return res;
-}
-
-float32x4x4_t vexp(float32x4x4_t vec)
-{
- float32x4x4_t res = { {
- vexpq_f32(vec.val[0]),
- vexpq_f32(vec.val[1]),
- vexpq_f32(vec.val[2]),
- vexpq_f32(vec.val[3])
- }
- };
- return res;
-}
-
-float32x4_t vexp(const float32x4_t &vec)
-{
- return vexpq_f32(vec);
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-// TODO (COMPMID-1535) : Revisit FP16 approximations
-float16x8_t vexp(const float16x8_t &vec)
-{
- float16x4x2_t res =
- {
- {
- vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_low_f16(vec)))),
- vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_high_f16(vec))))
- }
- };
- return vcombine_f16(res.val[0], res.val[1]);
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-template <>
-float32x4x4_t vdup_n<float32x4x4_t>(float val)
-{
- float32x4x4_t res = { {
- vdupq_n_f32(val),
- vdupq_n_f32(val),
- vdupq_n_f32(val),
- vdupq_n_f32(val)
- }
- };
- return res;
-}
-
-float32x4x4_t vmul_n(float32x4x4_t vec, float val)
-{
- float32x4x4_t res = { {
- vmulq_n_f32(vec.val[0], val),
- vmulq_n_f32(vec.val[1], val),
- vmulq_n_f32(vec.val[2], val),
- vmulq_n_f32(vec.val[3], val)
- }
- };
- return res;
-}
-
-float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b)
-{
- float32x4x4_t res = { {
- vaddq_f32(a.val[0], b.val[0]),
- vaddq_f32(a.val[1], b.val[1]),
- vaddq_f32(a.val[2], b.val[2]),
- vaddq_f32(a.val[3], b.val[3])
- }
- };
- return res;
-}
-
-float32x4x4_t vsub_n(float32x4x4_t a, float val)
-{
- auto scalar_vector = vdup_n<float32x4x4_t>(val);
- float32x4x4_t res = { {
- vsubq_f32(a.val[0], scalar_vector.val[0]),
- vsubq_f32(a.val[1], scalar_vector.val[1]),
- vsubq_f32(a.val[2], scalar_vector.val[2]),
- vsubq_f32(a.val[3], scalar_vector.val[3])
- }
- };
- return res;
-}
-
namespace
{
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
@@ -390,30 +88,20 @@ std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInf
return std::make_pair(err, win);
}
-template <typename V>
-auto reduce_max(V vec) -> elem_type_t<V>
-{
- constexpr int N = vec_size_of(vec);
-
- auto carry_max = vpmax(vget_high(vec), vget_low(vec));
-
- for(int k = N / 2; k > 1; k /= 2)
- {
- carry_max = vpmax(carry_max, carry_max);
- }
-
- return vget_lane<0>(carry_max);
-}
-
template <typename T>
void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
{
const auto start_x = in.info()->valid_region().anchor.x();
const size_t input_width = in.info()->valid_region().shape.x();
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
Iterator input(&in, window);
Iterator output(&out, window);
+ constexpr int window_step_x = 16 / sizeof(T);
+ const int sum_stages = log2(window_step_x / 2);
execute_window_loop(window, [&](const Coordinates &)
{
// Get pointers
@@ -421,16 +109,22 @@ void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
const auto out_ptr = reinterpret_cast<T *>(output.ptr());
// Init max value
- auto vec_max = vdup_n<vec_16_byte_t<T>>(support::cpp11::lowest<T>());
+ auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
// Loop over input row
- for(const T *it = in_ptr; it < (in_ptr + input_width); it += vec_size_of(vec_max))
+ for(const T *it = in_ptr; it < (in_ptr + input_width); it += window_step_x)
{
- const auto current_value = vld<vec_16_byte_t<T>>(it);
- vec_max = vmax(vec_max, current_value);
+ const auto current_value = wrapper::vloadq(it);
+ vec_max = wrapper::vmax(vec_max, current_value);
}
- const T max_val = reduce_max(vec_max);
+ auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));
+
+ for(int i = 0; i < sum_stages; ++i)
+ {
+ carry_max = wrapper::vpmax(carry_max, carry_max);
+ }
+ const T max_val = wrapper::vgetlane(carry_max, 0);
*out_ptr = max_val;
},
input, output);
@@ -575,45 +269,19 @@ std::pair<Status, Window> validate_and_configure_window_logits_softmax(ITensorIn
return std::make_pair(err, win);
}
-template <typename T, int N, int S, int E>
-struct reduce_add_impl
-{
- template <typename F>
- static T reduce(F add_fn, vec_n_t<T, N> vec)
- {
- constexpr int H = (S + E + 1) / 2;
- const auto reduced_high = reduce_add_impl < T, N, S, H - 1 >::reduce(add_fn, vec);
- const auto reduced_low = reduce_add_impl<T, N, H, E>::reduce(add_fn, vec);
- return add_fn(reduced_high, reduced_low);
- }
-};
-template <typename T, int N, int I>
-struct reduce_add_impl<T, N, I, I>
-{
- template <typename F>
- static T reduce(F /*add_fn*/, vec_n_t<T, N> vec)
- {
- return vget_lane<I>(vec);
- }
-};
-template <typename V, typename F>
-elem_type_t<V> reduce_add(F add_fn, V vec)
-{
- constexpr int N = vec_size_of(vec);
- return reduce_add_impl < elem_type_t<V>, N, 0, N - 1 >::reduce(add_fn, vec);
-}
-
template <bool is_log>
void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
{
const int start_x = in.info()->valid_region().anchor.x();
const int input_width = in.info()->valid_region().shape.x();
- const float scale_beta = -beta * in.info()->quantization_info().uniform().scale;
+ const float scale_beta = -beta * in.info()->quantization_info().uniform().scale;
+ const auto scale_beta_vec = vdupq_n_f32(scale_beta);
- Iterator in_it(&in, window);
- Iterator max_it(&max, window);
- Iterator out_it(&out, window);
+ Iterator in_it(&in, window);
+ Iterator max_it(&max, window);
+ Iterator out_it(&out, window);
+ constexpr int vec_size = 16;
execute_window_loop(window, [&](const Coordinates &)
{
@@ -629,57 +297,73 @@ void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *cons
{
/* Get max value */
const auto max_val = *reinterpret_cast<const qasymm8_t *>(max_it.ptr());
- const auto vec_max = vdup_n<vec_16_byte_t<qasymm8_t>>(max_val);
+ const auto vec_max = vdupq_n_u8(max_val);
/* Init sum to zero */
- auto vec_sum = vdup_n<float32x4x4_t>(0.f);
+ float32x4x4_t vec_sum =
+ {
+ vdupq_n_f32(0.f),
+ vdupq_n_f32(0.f),
+ vdupq_n_f32(0.f),
+ vdupq_n_f32(0.f),
+ };
/* Loop over row and compute exponentials and sum */
- int i = 0;
- constexpr int vec_size = vec_size_of(vec_max);
-
- for(; i <= (input_width - vec_size); i += vec_size)
+ int x = 0;
+ for(; x <= (input_width - vec_size); x += vec_size)
{
- auto vec_elements = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
- vec_elements = vsubq_u8(vec_max, vec_elements);
-
- auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
+ auto vec_elements = wrapper::vloadq(in_ptr + x);
+ vec_elements = vsubq_u8(vec_max, vec_elements);
+ auto vec_elements_flt = convert_uint8x16_to_float32x4x4(vec_elements);
if(is_log)
{
- vec_elements_flt = vmul_n(vec_elements_flt, scale_beta);
- vec_sum = vadd(vec_sum, vexp(vec_elements_flt));
+ vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
+ vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
+ vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
+ vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
+ vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
+ vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
+ vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
+ vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
}
else
{
- vec_elements_flt = vexp(vmul_n(vec_elements_flt, scale_beta));
- vec_sum = vadd(vec_sum, vec_elements_flt);
+ vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
+ vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
+ vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
+ vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
+ vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
+ vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
+ vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
+ vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
}
- vst4q_f32(tmp_ptr + i, vec_elements_flt);
+
+ vst4q_f32(tmp_ptr + x, vec_elements_flt);
}
/* Reduce sum */
- const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
- vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
- const auto sum_8_byte = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
- sum = reduce_add(std::plus<float>(), sum_8_byte);
+ const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
+ auto sum_res = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
+ sum_res = vpadd_f32(sum_res, sum_res);
+ sum = wrapper::vgetlane(sum_res, 0);
/* Run remaining elements */
- for(; i < input_width; ++i)
+ for(; x < input_width; ++x)
{
float element{};
if(is_log)
{
- element = (max_val - in_ptr[i]) * scale_beta;
+ element = (max_val - in_ptr[x]) * scale_beta;
sum += std::exp(element);
}
else
{
- element = std::exp((max_val - in_ptr[i]) * scale_beta);
+ element = std::exp((max_val - in_ptr[x]) * scale_beta);
sum += element;
}
- tmp_ptr[i] = element;
+ tmp_ptr[x] = element;
}
if(!is_log)
@@ -691,35 +375,45 @@ void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *cons
/* Normalize exponentials */
{
/* Loop over row and compute softmax */
- int i = 0;
+ int x = 0;
+ for(; x <= (input_width - vec_size); x += vec_size)
{
- constexpr int vec_size = 16;
-
- for(; i <= (input_width - vec_size); i += vec_size)
+ float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
+ uint8x16_t normalized_value{};
+ if(is_log)
{
- float32x4x4_t vec_in = vld4q_f32(tmp_ptr + i);
- vec_16_byte_t<qasymm8_t> normalized_value{};
- if(is_log)
+ const float32x4x4_t sub =
{
- normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vsub_n(vec_in, sum));
- }
- else
+ vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)),
+ vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)),
+ vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)),
+ vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)),
+ };
+ convert_float32x4x4_to_unit8x16(sub, normalized_value);
+ }
+ else
+ {
+ const float32x4x4_t mul =
{
- normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
- }
- vst(out_ptr + i, normalized_value);
+ vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)),
+ vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)),
+ vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)),
+ vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)),
+ };
+ convert_float32x4x4_to_unit8x16(mul, normalized_value);
}
+ vst1q_u8(out_ptr + x, normalized_value);
}
/* Run remaining elements */
- for(; i < input_width; ++i)
+ for(; x < input_width; ++x)
{
if(is_log)
{
- out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] - sum);
+ out_ptr[x] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[x] - sum);
}
else
{
- out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
+ out_ptr[x] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[x] * sum_inversed);
}
}
}
@@ -738,6 +432,12 @@ void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const
Iterator max_it(&max, window);
Iterator out_it(&out, window);
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+ constexpr int vec_size = 16 / sizeof(T);
+ const int sum_stages = log2(vec_size / 2);
+
execute_window_loop(window, [&](const Coordinates &)
{
/* Get pointers */
@@ -752,53 +452,54 @@ void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const
{
/* Get max value */
const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
- const auto vec_max = vdup_n<vec_16_byte_t<T>>(max_val);
+ const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});
/* Init sum to zero */
- auto vec_sum = vdup_n<vec_16_byte_t<T>>(0);
+ auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
/* Loop over row and compute exponentials and sum */
- int i = 0;
- constexpr int vec_size = vec_size_of(vec_sum);
-
- for(; i <= (input_width - vec_size); i += vec_size)
+ int x = 0;
+ for(; x <= (input_width - vec_size); x += vec_size)
{
- auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
- vec_elements = vsub(vec_elements, vec_max);
+ auto vec_elements = wrapper::vloadq(in_ptr + x);
+ vec_elements = wrapper::vsub(vec_elements, vec_max);
if(is_log)
{
- vec_elements = vmul_n(vec_elements, static_cast<T>(beta));
- vec_sum = vadd(vec_sum, vexp(vec_elements));
+ vec_elements = wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}));
+ vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
}
else
{
- vec_elements = vexp(vmul_n(vec_elements, static_cast<T>(beta)));
- vec_sum = vadd(vec_sum, vec_elements);
+ vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{})));
+ vec_sum = wrapper::vadd(vec_sum, vec_elements);
}
- vst(tmp_ptr + i, vec_elements);
+ wrapper::vstore(tmp_ptr + x, vec_elements);
}
/* Reduce sum */
- const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
- sum = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);
+ auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
+ for(int i = 0; i < sum_stages; ++i)
+ {
+ sum_res = wrapper::vpadd(sum_res, sum_res);
+ }
+ sum = wrapper::vgetlane(sum_res, 0);
/* Run remaining elements */
-
- for(; i < input_width; ++i)
+ for(; x < input_width; ++x)
{
T element{};
if(is_log)
{
- element = (in_ptr[i] - max_val) * beta;
+ element = (in_ptr[x] - max_val) * beta;
sum += std::exp(element);
}
else
{
- element = std::exp((in_ptr[i] - max_val) * beta);
+ element = std::exp((in_ptr[x] - max_val) * beta);
sum += element;
}
- tmp_ptr[i] = element;
+ tmp_ptr[x] = element;
}
if(!is_log)
@@ -810,36 +511,31 @@ void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const
/* Normalize exponentials */
{
/* Loop over row and compute softmax */
- int i = 0;
-
+ int x = 0;
+ for(; x <= (input_width - vec_size); x += vec_size)
{
- constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
-
- for(; i <= (input_width - vec_size); i += vec_size)
+ auto vec_in = wrapper::vloadq(tmp_ptr + x);
+ auto normalized_value = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+ if(is_log)
{
- auto vec_in = vld<vec_16_byte_t<T>>(tmp_ptr + i);
- vec_16_byte_t<T> normalized_value{};
- if(is_log)
- {
- normalized_value = vsub(vec_in, vdup_n<vec_16_byte_t<T>>(sum));
- }
- else
- {
- normalized_value = vmul_n(vec_in, sum_inversed);
- }
- vst(out_ptr + i, normalized_value);
+ normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast<T>(sum), ExactTagType{}));
+ }
+ else
+ {
+ normalized_value = wrapper::vmul(vec_in, wrapper::vdup_n(static_cast<T>(sum_inversed), ExactTagType{}));
}
+ wrapper::vstore(out_ptr + x, normalized_value);
}
/* Run remaining elements */
- for(; i < input_width; ++i)
+ for(; x < input_width; ++x)
{
if(is_log)
{
- out_ptr[i] = tmp_ptr[i] - sum;
+ out_ptr[x] = tmp_ptr[x] - sum;
}
else
{
- out_ptr[i] = tmp_ptr[i] * sum_inversed;
+ out_ptr[x] = tmp_ptr[x] * sum_inversed;
}
}
}