aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
diff options
context:
space:
mode:
author: Diego Lopez Recas <Diego.LopezRecas@arm.com> 2017-12-04 18:56:10 +0000
committer: Anthony Barbier <anthony.barbier@arm.com> 2018-11-02 16:45:00 +0000
commit: 35ceeb2199c569810a1524a0a21c2df2a3f5f29e (patch)
tree: 4a55f8626cb2960843547fabdb2431a70ec1029a /src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
parent: 97cf2497d2b617de3209330893ad51bd0cc126ce (diff)
download: ComputeLibrary-35ceeb2199c569810a1524a0a21c2df2a3f5f29e.tar.gz
IVGCVSW-798 Add Softmax NEON support for QASYMM8
Change-Id: I4f2cca52caf210fdb7d6bb7e9436ac51cb5088b4
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/112398
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NESoftmaxLayerKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NESoftmaxLayerKernel.cpp 1369
1 files changed, 713 insertions, 656 deletions
diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index b13fb0e87c..13d87a0989 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -33,285 +33,433 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/utility.h"
#include <algorithm>
#include <arm_neon.h>
#include <cfloat>
+#include <functional>
-using namespace arm_compute;
-
-namespace
-{
-Status validate_arguments_logits_1d_max(const ITensorInfo *input, const ITensorInfo *output)
+namespace arm_compute
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+template <typename T, int N>
+struct vec_n_type; // Trait mapping (element type T, lane count N) to a NEON vector type; only specializations made with DECLARE_NEON_VEC_TYPE are defined.
- // Checks performed when output is configured
- if(output->total_size() != 0)
- {
- // Softmax across the x dimension
- TensorShape output_shape{ input->tensor_shape() };
- output_shape.set(0, 1);
+#define DECLARE_NEON_VEC_TYPE(T, N, V) /* Specializes vec_n_type<T, N> so that its nested ::type is the vector type V. */ \
+ template <> \
+ struct vec_n_type<T, N> \
+ { \
+ using type = V; /* V: the vector type holding N lanes of element type T */ \
+ };
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
- }
+DECLARE_NEON_VEC_TYPE(uint8_t, 16, uint8x16_t)
+DECLARE_NEON_VEC_TYPE(uint8_t, 8, uint8x8_t)
- return Status{};
-}
+DECLARE_NEON_VEC_TYPE(int8_t, 16, int8x16_t)
+DECLARE_NEON_VEC_TYPE(int8_t, 8, int8x8_t)
-std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo *input, ITensorInfo *output)
-{
- // Configure kernel window
- constexpr unsigned int num_elems_written_per_row = 1;
- const int input_width = input->valid_region().shape.x();
+DECLARE_NEON_VEC_TYPE(uint16_t, 8, uint16x8_t)
+DECLARE_NEON_VEC_TYPE(uint16_t, 4, uint16x4_t)
- unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
- Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
- AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
- bool window_changed = false;
+DECLARE_NEON_VEC_TYPE(int16_t, 8, int16x8_t)
+DECLARE_NEON_VEC_TYPE(int16_t, 4, int16x4_t)
- if(output->total_size() != 0)
- {
- AccessWindowHorizontal output_access(output, 0, num_elems_written_per_row, 1.f / input_width);
- window_changed = update_window_and_padding(win, input_access, output_access);
- output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
- }
- else
- {
- window_changed = update_window_and_padding(win, input_access);
- }
+DECLARE_NEON_VEC_TYPE(int32_t, 4, int32x4_t)
+DECLARE_NEON_VEC_TYPE(int32_t, 2, int32x2_t)
- Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
-}
+DECLARE_NEON_VEC_TYPE(uint32_t, 4, uint32x4_t)
+DECLARE_NEON_VEC_TYPE(uint32_t, 2, uint32x2_t)
-Status validate_arguments_logits_1d_shift_exp_sum(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta)
-{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, max, sum, output);
- ARM_COMPUTE_RETURN_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input->data_type()));
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+DECLARE_NEON_VEC_TYPE(float16_t, 8, float16x8_t)
+DECLARE_NEON_VEC_TYPE(float16_t, 4, float16x4_t)
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- // Checks performed when output is configured
- if(output->total_size() != 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
- }
+DECLARE_NEON_VEC_TYPE(float, 4, float32x4_t)
+DECLARE_NEON_VEC_TYPE(float, 2, float32x2_t)
- // Checks performed when sum is configured
- if(sum->total_size() != 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, max, sum);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, max, sum);
- }
+template <typename T, int N>
+using vec_n_t = typename vec_n_type<T, N>::type; // Shorthand for the vector type registered for N lanes of element type T.
- return Status{};
-}
+template <typename T, int N>
+using vec_n_byte_t = vec_n_t < T, N / sizeof(T) >; // Vector of T that is N bytes wide, i.e. has N / sizeof(T) lanes.
+
+template <typename T>
+using vec_16_byte_t = vec_n_byte_t<T, 16>; // 16-byte vector of T.
+
+template <typename T>
+using vec_8_byte_t = vec_n_byte_t<T, 8>; // 8-byte vector of T.
+
+template <typename T>
+using const_ptr_t = const T *; // Pointer to read-only T.
-std::pair<Status, Window> validate_and_configure_window_logits_1d_shift_exp_sum(ITensorInfo *input, ITensorInfo *max, ITensorInfo *output, ITensorInfo *sum)
+template <typename T>
+using ptr_t = T *; // Pointer to mutable T.
+
+#define FORWARD_DECLARE_VGET_LANE_FOR_TYPE(TYPE) /* Declares vget_lane<lane>() for the 8- and 16-byte vectors of TYPE, so elem_type_t can name them before they are defined. */ \
+ template <int lane> \
+ TYPE vget_lane(vec_8_byte_t<TYPE> vec); \
+ template <int lane> \
+ TYPE vget_lane(vec_16_byte_t<TYPE> vec);
+
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint8_t)
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int8_t)
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint16_t)
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int16_t)
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint32_t)
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int32_t)
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float16_t)
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float)
+template <int lane>
+float vget_lane(float32x4x4_t vec); // Declaration only in this view; it makes elem_type_t<float32x4x4_t> resolve to float.
+
+template <typename V>
+using elem_type_t = decltype(vget_lane<0>(std::declval<V>())); // Element type of vector type V, taken from the return type of vget_lane<0> applied to a V (unevaluated).
+
+template <typename V>
+constexpr size_t vec_size_of(const V &vec) // Number of lanes in vec; only the type of vec is used, never its value.
{
- unsigned int num_elems_processed_per_iteration = input->valid_region().shape.x();
+ return sizeof(vec) / sizeof(elem_type_t<V>); // total bytes divided by bytes per element
+}
- // Configure kernel window
- Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
- AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
- AccessWindowHorizontal max_access(max, 0, 1);
- AccessWindowHorizontal sum_access(sum, 0, 1);
- bool window_changed = false;
+template <typename V>
+V vdup_n(elem_type_t<V> val); // Primary template (declaration only): build a V holding val; explicitly specialized per vector type.
+template <typename V>
+V vld(const_ptr_t<elem_type_t<V>> ptr); // Primary template (declaration only): load a V from ptr; explicitly specialized per vector type.
+
+#define DECLARE_NEON_FUNCTIONS_FOR_TYPE(TYPE, TAG) /* Generates inline wrappers, selected by argument type, that forward to the NEON intrinsics whose name ends in TAG, for 8- and 16-byte vectors of TYPE. */ \
+ template <> \
+ inline vec_8_byte_t<TYPE> vdup_n<vec_8_byte_t<TYPE>>(TYPE val) /* 8-byte vector built from val via vdup_n_<TAG> */ \
+ { \
+ return vdup_n_##TAG(val); \
+ } \
+ template <> \
+ inline vec_16_byte_t<TYPE> vdup_n<vec_16_byte_t<TYPE>>(TYPE val) /* 16-byte vector built from val via vdupq_n_<TAG> */ \
+ { \
+ return vdupq_n_##TAG(val); \
+ } \
+ template <> \
+ inline vec_8_byte_t<TYPE> vld<vec_8_byte_t<TYPE>>(const_ptr_t<TYPE> ptr) /* load an 8-byte vector from ptr */ \
+ { \
+ return vld1_##TAG(ptr); \
+ } \
+ template <> \
+ inline vec_16_byte_t<TYPE> vld<vec_16_byte_t<TYPE>>(const_ptr_t<TYPE> ptr) /* load a 16-byte vector from ptr */ \
+ { \
+ return vld1q_##TAG(ptr); \
+ } \
+ inline void vst(ptr_t<TYPE> ptr, vec_8_byte_t<TYPE> vec) /* store an 8-byte vector to ptr */ \
+ { \
+ vst1_##TAG(ptr, vec); \
+ } \
+ inline void vst(ptr_t<TYPE> ptr, vec_16_byte_t<TYPE> vec) /* store a 16-byte vector to ptr */ \
+ { \
+ vst1q_##TAG(ptr, vec); \
+ } \
+ inline vec_16_byte_t<TYPE> vmax(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) /* forwards to vmaxq_<TAG> (16-byte vectors only) */ \
+ { \
+ return vmaxq_##TAG(a, b); \
+ } \
+ inline vec_8_byte_t<TYPE> vpmax(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b) /* forwards to vpmax_<TAG> (8-byte vectors only) */ \
+ { \
+ return vpmax_##TAG(a, b); \
+ } \
+ inline vec_8_byte_t<TYPE> vget_low(vec_16_byte_t<TYPE> vec) /* low 8-byte half of a 16-byte vector */ \
+ { \
+ return vget_low_##TAG(vec); \
+ } \
+ inline vec_8_byte_t<TYPE> vget_high(vec_16_byte_t<TYPE> vec) /* high 8-byte half of a 16-byte vector */ \
+ { \
+ return vget_high_##TAG(vec); \
+ } \
+ template <int lane> \
+ inline TYPE vget_lane(vec_8_byte_t<TYPE> vec) /* extract lane `lane`; the index is bounds-checked at compile time */ \
+ { \
+ static_assert(lane >= 0, "lane is out of bounds"); \
+ static_assert(lane < vec_size_of(vec), "lane is out of bounds"); \
+ return vget_lane_##TAG(vec, lane); \
+ } \
+ template <int lane> \
+ inline TYPE vget_lane(vec_16_byte_t<TYPE> vec) /* 16-byte counterpart of the overload above */ \
+ { \
+ static_assert(lane >= 0, "lane is out of bounds"); \
+ static_assert(lane < vec_size_of(vec), "lane is out of bounds"); \
+ return vgetq_lane_##TAG(vec, lane); \
+ }
- if(output->total_size() != 0)
- {
- AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
- window_changed = update_window_and_padding(win, input_access, max_access, output_access, sum_access);
- output_access.set_valid_region(win, input->valid_region());
+template <typename T>
+T sqadd(T a, T b); // Primary template (declaration only); specialized by DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT.
+template <typename T>
+T sqsub(T a, T b); // Primary template (declaration only); specialized by DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT.
+template <typename T>
+T sqmul(T a, T b, int fixed_point_position); // Primary template (declaration only); specialized by DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT.
+
+#define DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(TYPET, TYPEU, TAGT, TAGU) /* Generates inline wrappers for a fixed-point element type TYPET (suffix TAGT) and a second type TYPEU (suffix TAGU) used by the add overloads. The *_q<TAG> callees are defined outside this view. */ \
+ inline vec_8_byte_t<TYPET> vqsub(vec_8_byte_t<TYPET> a, vec_8_byte_t<TYPET> b) /* forwards to vqsub_<TAGT> */ \
+ { \
+ return vqsub_##TAGT(a, b); \
+ } \
+ inline vec_8_byte_t<TYPEU> vqadd(vec_8_byte_t<TYPEU> a, vec_8_byte_t<TYPEU> b) /* forwards to vqadd_<TAGU>; note this overload is on TYPEU, not TYPET */ \
+ { \
+ return vqadd_##TAGU(a, b); \
+ } \
+ inline vec_16_byte_t<TYPEU> vqadd(vec_16_byte_t<TYPEU> a, vec_16_byte_t<TYPEU> b) /* 16-byte counterpart, forwards to vqaddq_<TAGU> */ \
+ { \
+ return vqaddq_##TAGU(a, b); \
+ } \
+ inline vec_8_byte_t<TYPET> vqexp(vec_8_byte_t<TYPET> vec, int fixed_point_position) /* forwards to vqexp_q<TAGT> */ \
+ { \
+ return vqexp_q##TAGT(vec, fixed_point_position); \
+ } \
+ inline auto vmovl(vec_8_byte_t<TYPET> vec)->decltype(vmovl_##TAGT(vec)) /* return type is whatever vmovl_<TAGT> returns */ \
+ { \
+ return vmovl_##TAGT(vec); \
+ } \
+ inline vec_16_byte_t<TYPET> vqrecip(vec_16_byte_t<TYPET> vec, int fixed_point_position) /* forwards to vqrecipq_q<TAGT> */ \
+ { \
+ return vqrecipq_q##TAGT(vec, fixed_point_position); \
+ } \
+ inline vec_16_byte_t<TYPET> vqmul(vec_16_byte_t<TYPET> a, vec_16_byte_t<TYPET> b, int fixed_point_position) /* forwards to vqmulq_q<TAGT> */ \
+ { \
+ return vqmulq_q##TAGT(a, b, fixed_point_position); \
+ } \
+ template <> \
+ inline TYPEU sqadd<TYPEU>(TYPEU a, TYPEU b) /* scalar add on TYPEU, forwards to sqadd_q<TAGU> */ \
+ { \
+ return sqadd_q##TAGU(a, b); \
+ } \
+ inline TYPET sqexp(TYPET val, int fixed_point_position) /* scalar, forwards to sqexp_q<TAGT> */ \
+ { \
+ return sqexp_q##TAGT(val, fixed_point_position); \
+ } \
+ template <> \
+ inline TYPET sqsub<TYPET>(TYPET a, TYPET b) /* scalar, forwards to sqsub_q<TAGT> */ \
+ { \
+ return sqsub_q##TAGT(a, b); \
+ } \
+ template <> \
+ inline TYPET sqmul<TYPET>(TYPET a, TYPET b, int fixed_point_position) /* scalar, forwards to sqmul_q<TAGT> */ \
+ { \
+ return sqmul_q##TAGT(a, b, fixed_point_position); \
}
- else
- {
- window_changed = update_window_and_padding(win, input_access, max_access, sum_access);
+
+#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG) /* Generates inline arithmetic wrappers for vectors of the floating-point type TYPE (intrinsic suffix TAG). */ \
+ inline vec_8_byte_t<TYPE> vadd(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b) /* forwards to vadd_<TAG> */ \
+ { \
+ return vadd_##TAG(a, b); \
+ } \
+ inline vec_16_byte_t<TYPE> vadd(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) /* forwards to vaddq_<TAG> */ \
+ { \
+ return vaddq_##TAG(a, b); \
+ } \
+ inline vec_16_byte_t<TYPE> vsub(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) /* forwards to vsubq_<TAG> */ \
+ { \
+ return vsubq_##TAG(a, b); \
+ } \
+ inline vec_16_byte_t<TYPE> vexp(vec_16_byte_t<TYPE> vec) /* forwards to vexpq_<TAG>, which is defined outside this view */ \
+ { \
+ return vexpq_##TAG(vec); \
+ } \
+ inline vec_16_byte_t<TYPE> vmul_n(vec_16_byte_t<TYPE> vec, TYPE val) /* forwards to vmulq_n_<TAG> (vector times scalar) */ \
+ { \
+ return vmulq_n_##TAG(vec, val); \
}
- sum_access.set_valid_region(win, ValidRegion(Coordinates(), sum->tensor_shape()));
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint8_t, u8)
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(int8_t, s8)
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint16_t, u16)
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(int16_t, s16)
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint32_t, u32)
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(int32_t, s32)
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(float16_t, f16)
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32)
- Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
-}
+DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(int8_t, int16_t, s8, s16)
+DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(int16_t, int32_t, s16, s32)
-Status validate_arguments_logits_1d_norm(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
-{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, sum, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::S32, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum);
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16)
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float, f32)
- // Checks performed when output is configured
- if(output->total_size() != 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
- }
+template <typename VO, typename VI>
+VO vcvt(VI vec); // Primary template (declaration only): convert a vector of type VI to vector type VO via explicit specializations.
- return Status{};
+// Converts sixteen uint8 lanes to float: each half is widened u8 -> u16 -> u32 with vmovl and
+// then converted with vcvtq_f32_u32. res.val[0..3] hold input lanes 0-3, 4-7, 8-11 and 12-15.
+// Declared inline, like the macro-generated helpers in this file, so this specialization
+// in namespace arm_compute has the same linkage as its sibling overloads.
+template <>
+inline float32x4x4_t vcvt<float32x4x4_t>(uint8x16_t vec)
+{
+ const auto low = vmovl_u8(vget_low(vec));
+ const auto high = vmovl_u8(vget_high(vec));
+ float32x4x4_t res = { {
+ vcvtq_f32_u32(vmovl_u16(vget_low(low))),
+ vcvtq_f32_u32(vmovl_u16(vget_high(low))),
+ vcvtq_f32_u32(vmovl_u16(vget_low(high))),
+ vcvtq_f32_u32(vmovl_u16(vget_high(high)))
+ }
+ };
+ return res;
}
-std::pair<Status, Window> validate_and_configure_window_logits_1d_norm(ITensorInfo *input, ITensorInfo *sum, ITensorInfo *output)
+// Converts sixteen float lanes (four float32x4_t) to uint8: each group goes float -> u32 with
+// vcvtq_u32_f32, then is narrowed u32 -> u16 -> u8 with the saturating vqmovn intrinsics.
+// Output lanes keep the input order (val[0] first).
+// Declared inline, like the macro-generated helpers in this file, so this specialization
+// in namespace arm_compute has the same linkage as its sibling overloads.
+template <>
+inline uint8x16_t vcvt<uint8x16_t>(float32x4x4_t vec)
{
- // Configure kernel window
- unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
- Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
+ uint16x8x2_t resU16 = { {
+ vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[0])),
+ vqmovn_u32(vcvtq_u32_f32(vec.val[1]))),
+ vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[2])),
+ vqmovn_u32(vcvtq_u32_f32(vec.val[3])))
+ }
+ };
- AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
- AccessWindowStatic sum_access(sum, 0, 0, 1, sum->dimension(1));
- bool window_changed = false;
+ uint8x16_t res = vcombine_u8(vqmovn_u16(resU16.val[0]), vqmovn_u16(resU16.val[1]));
+ return res;
+}
- if(output->total_size() != 0)
- {
- AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+// Applies vexpq_f32 to each of the four float32x4_t members of vec and returns the results
+// in the same order.
+// Declared inline, like the macro-generated vexp overloads in this file, so this generically
+// named helper in namespace arm_compute has the same linkage as its sibling overloads.
+inline float32x4x4_t vexp(float32x4x4_t vec)
+{
+ float32x4x4_t res = { {
+ vexpq_f32(vec.val[0]),
+ vexpq_f32(vec.val[1]),
+ vexpq_f32(vec.val[2]),
+ vexpq_f32(vec.val[3])
+ }
+ };
+ return res;
+}
- window_changed = update_window_and_padding(win, input_access, sum_access, output_access);
+// Builds a float32x4x4_t whose four members are each vdupq_n_f32(val).
+// Declared inline, matching the macro-generated vdup_n specializations in this file, so all
+// vdup_n specializations share the same linkage.
+template <>
+inline float32x4x4_t vdup_n<float32x4x4_t>(float val)
+{
+ float32x4x4_t res = { {
+ vdupq_n_f32(val),
+ vdupq_n_f32(val),
+ vdupq_n_f32(val),
+ vdupq_n_f32(val)
+ }
+ };
+ return res;
+}
- output_access.set_valid_region(win, input->valid_region());
- }
- else
- {
- window_changed = update_window_and_padding(win, input_access, sum_access);
- }
- Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
+// Multiplies each of the four float32x4_t members of vec by the scalar val (vmulq_n_f32).
+// Declared inline, like the macro-generated vmul_n overloads in this file, so this generically
+// named helper in namespace arm_compute has the same linkage as its sibling overloads.
+inline float32x4x4_t vmul_n(float32x4x4_t vec, float val)
+{
+ float32x4x4_t res = { {
+ vmulq_n_f32(vec.val[0], val),
+ vmulq_n_f32(vec.val[1], val),
+ vmulq_n_f32(vec.val[2], val),
+ vmulq_n_f32(vec.val[3], val)
+ }
+ };
+ return res;
}
-void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window)
+// Adds a and b member by member: res.val[i] = vaddq_f32(a.val[i], b.val[i]) for i in 0..3.
+// Declared inline, like the macro-generated vadd overloads in this file, so this generically
+// named helper in namespace arm_compute has the same linkage as its sibling overloads.
+inline float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b)
{
- Window in_slice = window.first_slice_window_1D();
+ float32x4x4_t res = { {
+ vaddq_f32(a.val[0], b.val[0]),
+ vaddq_f32(a.val[1], b.val[1]),
+ vaddq_f32(a.val[2], b.val[2]),
+ vaddq_f32(a.val[3], b.val[3])
+ }
+ };
+ return res;
+}
- Window window_max(window);
- window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
- Window max_slice = window_max.first_slice_window_1D();
+namespace
+{
+Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output) // Checks the input data type and, if output is already configured, that output is consistent with input; returns an empty Status when every check passes.
+{
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F32); // F16 is accepted only when the FP16 vector arithmetic feature macro is defined
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- do
+ // Validate in case of configured output (total_size() == 0 is treated as "not configured yet")
+ if(output.total_size() != 0)
{
- Iterator input(in, in_slice);
- Iterator output(out, max_slice);
-
- qint8x16_t vec_max = vdupq_n_s8(std::numeric_limits<qint8_t>::lowest());
-
- execute_window_loop(in_slice, [&](const Coordinates & id)
- {
- const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
- const qint8x16_t current_value = vld1q_qs8(in_ptr);
- vec_max = vmaxq_qs8(vec_max, current_value);
- },
- input);
-
- qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max));
- carry_max = vpmax_qs8(carry_max, carry_max);
- carry_max = vpmax_qs8(carry_max, carry_max);
- carry_max = vpmax_qs8(carry_max, carry_max);
-
- *(reinterpret_cast<qint8_t *>(output.ptr())) = vget_lane_s8(carry_max, 0);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1)); // expected output shape: the input shape with dimension 0 set to 1
}
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
+
+ return Status{};
}
-void logits_1d_max_qs16(const ITensor *in, ITensor *out, const Window &window)
+
+std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo &input, ITensorInfo &output) // Auto-initializes output if empty, sets its valid region, and builds the execution window; returns an error Status paired with the window when update_window_and_padding() reports a change.
{
- Window in_slice = window.first_slice_window_1D();
+ // Softmax across the x dimension
+ const TensorShape output_shape = TensorShape(input.tensor_shape()).set(0, 1);
+ // Output auto initialization if not yet initialized
+ auto_init_if_empty(output, output_shape, 1, input.data_type(), input.fixed_point_position(), input.quantization_info());
- Window window_max(window);
- window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
- Window max_slice = window_max.first_slice_window_1D();
+ // Configure kernel window
+ const int input_width = input.valid_region().shape.x();
+ const int num_elems_processed_per_iteration = 16U / data_size_from_type(input.data_type()); // elements per 16-byte vector
+ const int num_elems_read_per_iteration = ceil_to_multiple(input_width, num_elems_processed_per_iteration); // row width rounded up to whole vectors
- do
- {
- Iterator input(in, in_slice);
- Iterator output(out, max_slice);
+ const ValidRegion out_valid_region(ValidRegion(input.valid_region()).set(0, 0, 1));
+ output.set_valid_region(out_valid_region);
- qint16x8_t vec_max = vdupq_n_qs16(std::numeric_limits<qint16_t>::lowest());
+ Window win = calculate_max_window(output); // the window is computed from the output tensor info, not the input
- execute_window_loop(in_slice, [&](const Coordinates & id)
- {
- const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
- const qint16x8_t current_value = vld1q_qs16(in_ptr);
- vec_max = vmaxq_qs16(vec_max, current_value);
- },
- input);
+ AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), num_elems_read_per_iteration);
+ AccessWindowHorizontal output_access(&output, 0, 1);
- qint16x4_t carry_max = vpmax_qs16(vget_high_qs16(vec_max), vget_low_qs16(vec_max));
- carry_max = vpmax_qs16(carry_max, carry_max);
- carry_max = vpmax_qs16(carry_max, carry_max);
+ const bool window_changed = update_window_and_padding(win, input_access, output_access);
- *(reinterpret_cast<qint16_t *>(output.ptr())) = vget_lane_s16(carry_max, 0);
- }
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
+ const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-void logits_1d_max_f16(const ITensor *in, ITensor *out, const Window &window)
+template <typename V>
+auto reduce_max(V vec) -> elem_type_t<V> // Returns the largest lane of vec, reduced with repeated pairwise maxima (vpmax).
{
- Window in_slice = window.first_slice_window_1D();
+ constexpr int N = vec_size_of(vec); // lane count of V
- Window window_max(window);
- window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
- Window max_slice = window_max.first_slice_window_1D();
+ auto carry_max = vpmax(vget_high(vec), vget_low(vec)); // pairwise maxima of the two halves: N / 2 candidates remain
- do
+ for(int k = N / 2; k > 1; k /= 2) // k = candidates still to be merged; each vpmax of the vector with itself halves it
{
- Iterator input(in, in_slice);
- Iterator output(out, max_slice);
-
- float16x8_t vec_max = vdupq_n_f16(std::numeric_limits<float16_t>::lowest());
-
- execute_window_loop(in_slice, [&](const Coordinates & id)
- {
- const auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
- const float16x8_t current_value = vld1q_f16(in_ptr);
- vec_max = vmaxq_f16(vec_max, current_value);
- },
- input);
-
- float16x4_t carry_max = vpmax_f16(vget_high_f16(vec_max), vget_low_f16(vec_max));
- carry_max = vpmax_f16(carry_max, carry_max);
- carry_max = vpmax_f16(carry_max, carry_max);
-
- *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(carry_max, 0);
+ carry_max = vpmax(carry_max, carry_max);
}
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
+
+ return vget_lane<0>(carry_max); // after the loop lane 0 holds the overall maximum
}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window)
+template <typename T>
+void logits_1d_max(const ITensor &in, ITensor &out, const Window &window) // For every position of window, reduces one input row (the valid region along x) to its maximum and stores that single value through the output iterator.
{
- Window in_slice = window.first_slice_window_1D();
+ const auto start_x = in.info()->valid_region().anchor.x();
+ const size_t input_width = in.info()->valid_region().shape.x();
- Window window_max(window);
- window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
- Window max_slice = window_max.first_slice_window_1D();
+ Iterator input(&in, window);
+ Iterator output(&out, window);
- do
+ execute_window_loop(window, [&](const Coordinates &)
{
- Iterator input(in, in_slice);
- Iterator output(out, max_slice);
+ // Get pointers (the input pointer is advanced to the start of the valid region)
+ const auto in_ptr = reinterpret_cast<const T *>(input.ptr()) + start_x;
+ const auto out_ptr = reinterpret_cast<T *>(output.ptr());
- float32x4_t vec_max = vdupq_n_f32(-FLT_MAX);
+ // Init max value with the lowest value representable by T
+ auto vec_max = vdup_n<vec_16_byte_t<T>>(std::numeric_limits<T>::lowest());
- execute_window_loop(in_slice, [&](const Coordinates & id)
+ // Loop over input row one 16-byte vector at a time. The last load can extend past input_width - NOTE(review): presumably covered by the border/padding requested in configure(); confirm the border contents cannot exceed the row maximum.
+ for(const T *it = in_ptr; it < (in_ptr + input_width); it += vec_size_of(vec_max))
{
- const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
- const float32x4_t current_value = vld1q_f32(in_ptr);
- vec_max = vmaxq_f32(vec_max, current_value);
- },
- input);
-
- float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max));
- carry_max = vpmax_f32(carry_max, carry_max);
+ const auto current_value = vld<vec_16_byte_t<T>>(it);
+ vec_max = vmax(vec_max, current_value); // lane-wise running maximum
+ }
- *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(carry_max, 0);
- }
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
+ const T max_val = reduce_max(vec_max); // collapse the lane-wise maxima to a single value
+ *out_ptr = max_val;
+ },
+ input, output);
}
} // namespace
@@ -328,54 +476,54 @@ BorderSize NELogits1DMaxKernel::border_size() const
void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
-
- // Softmax across the x dimension
- TensorShape output_shape{ input->info()->tensor_shape() };
- output_shape.set(0, 1);
-
- // Output auto initialization if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
-
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info()); // the info objects are dereferenced just below
// Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(input->info(), output->info()));
-
- const int input_width = input->info()->valid_region().shape.x();
- unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
+ // Configure kernel window (this call also auto-initializes the output info when it is empty)
+ auto win_config = validate_and_configure_window_logits_1d_max(*input->info(), *output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
switch(input->info()->data_type())
{
+ case DataType::QASYMM8:
+ _func = &logits_1d_max<qasymm8_t>;
+ break;
case DataType::QS8:
- _func = &logits_1d_max_qs8;
+ _func = &logits_1d_max<qint8_t>;
break;
case DataType::QS16:
- _func = &logits_1d_max_qs16;
+ _func = &logits_1d_max<qint16_t>;
break;
- case DataType::F32:
- _func = &logits_1d_max_f32;
- break;
- case DataType::F16:
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- _func = &logits_1d_max_f16;
+ case DataType::F16:
+ _func = &logits_1d_max<float16_t>;
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ case DataType::F32:
+ _func = &logits_1d_max<float>;
+ break;
default:
ARM_COMPUTE_ERROR("Unsupported data type.");
}
- _input = input;
- _output = output;
- _border_size = BorderSize(0, num_elems_processed_per_iteration - (input_width % num_elems_processed_per_iteration), 0, 0);
+ _input = input;
+ _output = output;
+
+ const int input_width = input->info()->valid_region().shape.x();
+ const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type()); // elements per 16-byte vector
+ const int num_elems_read_per_iteration = ceil_to_multiple(input_width, num_elems_processed_per_iteration); // row width rounded up to whole vectors
+
+ _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0); // second argument = elements read beyond the valid row width; zero when the width is already a whole number of vectors
- // Configure kernel window
- auto win_config = validate_and_configure_window_logits_1d_max(input->info(), output->info());
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
INEKernel::configure(win_config.second);
}
Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(input, output));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(input->clone().get(), output->clone().get()).first);
+ // Static validation entry point: reports problems through the returned Status.
+ // Null arguments must therefore be turned into an error Status (RETURN variant) rather than
+ // asserted on, and this has to happen before either pointer is dereferenced below.
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
+ // Window configuration runs on clones so the caller's tensor infos are left untouched.
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(*input->clone(), *output->clone()).first);
return Status{};
}
@@ -387,297 +535,393 @@ void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
ARM_COMPUTE_ERROR_ON(_func == nullptr);
- (*_func)(_input, _output, window);
+ (*_func)(*_input, *_output, window);
}
namespace
{
-void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
+Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
+ const ITensorInfo &output, const float beta, const ITensorInfo &tmp)
{
- ARM_COMPUTE_UNUSED(beta);
+ // Check input
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F32);
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- Window window_max(window);
- window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
+ const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());
- Window max_slice = window_max.first_slice_window_1D();
- Window in_slice = window.first_slice_window_1D();
+ // Check max
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &max);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);
- constexpr int step = 8;
- const int long_steps = in->info()->valid_region().shape.x() / step;
- const int small_steps = in->info()->valid_region().shape.x() % step;
- const int fixed_point_position = in->info()->fixed_point_position();
+ // Check output if configured
+ if(output.total_size() != 0)
+ {
+ const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &output);
+ ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
+ }
+
+ // Check beta
+ ARM_COMPUTE_RETURN_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input.data_type()));
- do
+ // Check tmp if configured
+ if(tmp.total_size() != 0)
{
- Iterator input(in, in_slice);
- Iterator exp(out, in_slice);
- Iterator _max(max, max_slice);
- Iterator _sum(sum, max_slice);
+ const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
+ ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &tmp);
+ // We could potentially reduce tmp memory if we could predict or make an assumption
+ // on the maximum number of threads that will run in parallel.
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
+ }
- // Get pointers
- auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
- auto exp_ptr = reinterpret_cast<qint8_t *>(exp.ptr());
+ return Status{};
+}
- // Init sum to zero
- qint16x8_t vec_sum_value = vdupq_n_qs16(0);
+/** Auto-initialise the output and tmp tensor infos (when empty) and compute the execution window.
+ *
+ * The window is computed from @p max. A horizontal access window is then declared for every
+ * tensor and checked/updated through update_window_and_padding().
+ *
+ * @return The window paired with an "Insufficient Padding!" error Status if
+ *         update_window_and_padding() reported a change, or with an empty Status otherwise.
+ */
+std::pair<Status, Window> validate_and_configure_window_logits_softmax(ITensorInfo &input, ITensorInfo &max,
+ ITensorInfo &output, ITensorInfo &tmp)
+{
+ const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());
- // Get max value
- const auto max_ptr = reinterpret_cast<const qint8_t *>(_max.ptr());
- const qint8x8_t vec_max = vdup_n_qs8(*max_ptr);
+ // Output auto initialization if not yet initialized
+ // (asymmetric quantized input forces the output quantization to scale 1/256, offset 0)
+ const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
+ auto_init_if_empty(output, TensorInfo(input).set_quantization_info(output_quantization).reset_padding());
- // Run neon loop
- for(int i = 0; i < long_steps; ++i)
- {
- qint8x8_t vec_elements = vld1_qs8(in_ptr);
- vec_elements = vqsub_qs8(vec_elements, vec_max);
- vec_elements = vqexp_qs8(vec_elements, fixed_point_position);
+ // Tmp auto initialization if not yet initialized
+ // (tmp is F32 for asymmetric quantized input, otherwise it has the input data type)
+ const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
+ auto_init_if_empty(tmp, TensorInfo(input).set_data_type(tmp_data_type).reset_padding());
- vst1_qs8(exp_ptr, vec_elements);
- vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements));
+ const int input_width = input.valid_region().shape.x();
- in_ptr += step;
- exp_ptr += step;
- }
- // Reduce sum
- const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value));
- const qint16_t sum0 = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1));
- const qint16_t sum1 = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3));
- qint16_t sum = sqadd_qs16(sum0, sum1);
-
- // Run remaining elements
- for(int i = 0; i < small_steps; ++i)
- {
- qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position);
- exp_ptr[i] = element;
- sum = sqadd_qs16(sum, element);
- }
+ Window win = calculate_max_window(max);
- *(reinterpret_cast<qint8_t *>(_sum.ptr())) = sqmovn_qs16(sum);
- }
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
-}
-void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
-{
- ARM_COMPUTE_UNUSED(beta);
+ AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), input_width);
+ // One element per row is read from the max tensor, so the access is declared on max itself
+ // (it was previously declared on input, which left max out of the window/padding check).
+ AccessWindowHorizontal max_access(&max, 0, 1);
+ AccessWindowHorizontal output_access(&output, input.valid_region().anchor.x(), input_width);
+ AccessWindowHorizontal tmp_access(&tmp, input.valid_region().anchor.x(), input_width);
- Window window_max(window);
- window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
+ const bool window_changed = update_window_and_padding(win, input_access, max_access, output_access, tmp_access);
- Window max_slice = window_max.first_slice_window_1D();
- Window in_slice = window.first_slice_window_1D();
+ output.set_valid_region(input.valid_region());
- constexpr int step = 4;
- const int long_steps = in->info()->valid_region().shape.x() / step;
- const int small_steps = in->info()->valid_region().shape.x() % step;
- const int fixed_point_position = in->info()->fixed_point_position();
+ const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+ return std::make_pair(err, win);
+}
- do
+template <typename T, int N, int S, int E> // Recursive case: folds lanes S..E (inclusive) of an N-lane vector into one T using add_fn.
+struct reduce_add_impl
+{
+ template <typename F>
+ static T reduce(F add_fn, vec_n_t<T, N> vec)
+ {
+ constexpr int H = (S + E + 1) / 2; // first lane of the second half: the range splits into [S, H-1] and [H, E]
+ const auto reduced_high = reduce_add_impl < T, N, S, H - 1 >::reduce(add_fn, vec); // despite the name, this is the result for lanes [S, H-1]
+ const auto reduced_low = reduce_add_impl<T, N, H, E>::reduce(add_fn, vec); // result for lanes [H, E]
+ return add_fn(reduced_high, reduced_low); // combine the two partial results
+ }
+};
+template <typename T, int N, int I> // Terminating case of the recursion: the lane range has shrunk to the single lane I.
+struct reduce_add_impl<T, N, I, I>
+{
+ template <typename F>
+ static T reduce(F /*add_fn*/, vec_n_t<T, N> vec) // add_fn is unused here: there is nothing to combine for one lane
{
- Iterator input(in, in_slice);
- Iterator exp(out, in_slice);
- Iterator _max(max, max_slice);
- Iterator _sum(sum, max_slice);
+ return vget_lane<I>(vec); // value of lane I
+ }
+};
+template <typename V, typename F> // Entry point: folds every lane of vec into a single element using the binary functor add_fn.
+elem_type_t<V> reduce_add(F add_fn, V vec)
+{
+ constexpr int N = vec_size_of(vec); // used as the lane count; vec_size_of is defined outside this section
+ return reduce_add_impl < elem_type_t<V>, N, 0, N - 1 >::reduce(add_fn, vec); // reduce the full lane range [0, N-1]
+}
- // Get pointers
- auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
- auto exp_ptr = reinterpret_cast<qint16_t *>(exp.ptr());
+void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
+{ // Per window iteration: exponentials of one row go to the float scratch buffer tmp, then are normalized into out as qasymm8.
+ const int start_x = in.info()->valid_region().anchor.x();
+ const int input_width = in.info()->valid_region().shape.x();
- // Init sum to zero
- qint32x4_t vec_sum_value = vdupq_n_qs32(0);
+ const float scale_beta = -beta * in.info()->quantization_info().scale; // negated because the loops below compute (max - x): exp((max - x) * scale_beta) == exp(beta * scale * (x - max))
- // Get max value
- const auto max_ptr = reinterpret_cast<const qint16_t *>(_max.ptr());
- const qint16x4_t vec_max = vdup_n_qs16(*max_ptr);
+ Iterator in_it(&in, window);
+ Iterator max_it(&max, window);
+ Iterator out_it(&out, window);
- // Run neon loop
- for(int i = 0; i < long_steps; ++i)
- {
- qint16x4_t vec_elements = vld1_qs16(in_ptr);
- vec_elements = vqsub_qs16(vec_elements, vec_max);
- vec_elements = vqexp_qs16(vec_elements, fixed_point_position);
+ execute_window_loop(window, [&](const Coordinates &)
+ {
+ /* Get pointers: in/out are offset to the start of the valid region, tmp is indexed from 0 */
+ const auto in_ptr = reinterpret_cast<const qasymm8_t *>(in_it.ptr()) + start_x;
+ const auto out_ptr = reinterpret_cast<qasymm8_t *>(out_it.ptr()) + start_x;
+ const auto tmp_ptr = reinterpret_cast<float *>(tmp);
- vst1_qs16(exp_ptr, vec_elements);
- vec_sum_value = vqaddq_qs32(vec_sum_value, vmovl_s16(vec_elements));
+ float sum_inversed; // assigned at the end of the first stage, consumed by the second
- in_ptr += step;
- exp_ptr += step;
+ /* Compute exponentials and sum */
+ {
+ /* Get max value */
+ const auto max_val = *reinterpret_cast<const qasymm8_t *>(max_it.ptr());
+ const auto vec_max = vdup_n<vec_16_byte_t<qasymm8_t>>(max_val);
+
+ /* Init sum to zero */
+ auto vec_sum = vdup_n<float32x4x4_t>(0.f);
+
+ /* Loop over row and compute exponentials and sum (whole vectors only) */
+ int i = 0;
+ constexpr int vec_size = vec_size_of(vec_max);
+ for(; i <= (input_width - vec_size); i += vec_size)
+ {
+ auto vec_elements = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
+ vec_elements = vsubq_u8(vec_max, vec_elements); // unsigned (max - x); NOTE(review): presumably max_val is the row maximum so this cannot wrap - confirm against the max kernel
+
+ auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
+ vec_elements_flt = vexp(vmul_n(vec_elements_flt, scale_beta));
+
+ vec_sum = vadd(vec_sum, vec_elements_flt);
+
+ vst4q_f32(tmp_ptr + i, vec_elements_flt); // 4-way interleaving store; read back with the matching vld4q_f32 in the normalize loop
+ }
+ /* Reduce sum: add the four float32x4 accumulators, halve to two lanes, then fold to a scalar */
+ const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
+ vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
+ const auto sum_8_byte = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
+ float sum = reduce_add(std::plus<float>(), sum_8_byte);
+
+ /* Run remaining elements (scalar tail, same formula as the vector loop) */
+ for(; i < input_width; ++i)
+ {
+ const float element = std::exp((max_val - in_ptr[i]) * scale_beta);
+ sum += element;
+ tmp_ptr[i] = element;
+ }
+
+ sum_inversed = 256.f / sum; // 256 corresponds to the output quantization scale of 1/256 required for QASYMM8 by validate_arguments_logits_softmax
}
- // Reduce sum
- qint32x2_t carry_addition = vqadd_qs32(vget_high_s32(vec_sum_value), vget_low_s32(vec_sum_value));
- qint32_t sum = vget_lane_s32(carry_addition, 0) + vget_lane_s32(carry_addition, 1);
- // Run remaining elements
- for(int i = 0; i < small_steps; ++i)
+ /* Normalize exponentials */
{
- qint16_t element = sqexp_qs16(sqsub_qs16(in_ptr[i], *max_ptr), fixed_point_position);
- exp_ptr[i] = element;
- sum = sqadd_qs32(sum, element);
+ /* Loop over row and compute softmax */
+ int i = 0;
+ {
+ constexpr int vec_size = 16; // NOTE(review): must stay equal to vec_size of the exponentials loop so both loops cover the same chunks of tmp - confirm vec_size_of(vec_max) is 16
+ for(; i <= (input_width - vec_size); i += vec_size)
+ {
+ float32x4x4_t vec_in = vld4q_f32(tmp_ptr + i); // de-interleaving load, counterpart of the vst4q_f32 above
+ auto normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
+ vst(out_ptr + i, normalized_value);
+ }
+ }
+ /* Run remaining elements (scalar tail written by the scalar tail of the first stage) */
+ for(; i < input_width; ++i)
+ {
+ out_ptr[i] = utility::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
+ }
}
-
- *(reinterpret_cast<qint16_t *>(_sum.ptr())) = sqmovn_qs32(sum);
- }
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
+ },
+ in_it, max_it, out_it);
}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
+template <typename T, typename U>
+void logits_1d_softmax_fixed_point(const ITensor &in, const ITensor &max, void *const tmp,
+ ITensor &out, const float /*beta*/, const Window &window)
{
- Window window_max(window);
- window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
+ const int start_x = in.info()->valid_region().anchor.x();
+ const int input_width = in.info()->valid_region().shape.x();
- Window max_slice = window_max.first_slice_window_1D();
- Window in_slice = window.first_slice_window_1D();
+ const int fixed_point_position = in.info()->fixed_point_position();
- constexpr int step = 8;
- const int long_steps = in->info()->valid_region().shape.x() / step;
- const int small_steps = in->info()->valid_region().shape.x() % step;
+ Iterator in_it(&in, window);
+ Iterator max_it(&max, window);
+ Iterator out_it(&out, window);
- do
+ execute_window_loop(window, [&](const Coordinates &)
{
- Iterator input(in, in_slice);
- Iterator exp(out, in_slice);
- Iterator _max(max, max_slice);
- Iterator _sum(sum, max_slice);
-
- // Get pointers
- auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
- auto exp_ptr = reinterpret_cast<float16_t *>(exp.ptr());
+ /* Get pointers */
+ const auto in_ptr = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
+ const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
+ const auto tmp_ptr = reinterpret_cast<T *>(tmp);
- // Init sum to zero
- float16x8_t vec_sum_value = vdupq_n_f16(0);
+ vec_16_byte_t<T> vec_sum_inversed;
- // Get max value
- const auto max_ptr = reinterpret_cast<const float16_t *>(_max.ptr());
- const float16x8_t vec_max = vdupq_n_f16(*max_ptr);
-
- // Run neon loop
- for(int i = 0; i < long_steps; ++i)
+ /* Compute exponentials and sum */
{
- float16x8_t vec_elements = vld1q_f16(in_ptr);
- vec_elements = vsubq_f16(vec_elements, vec_max);
- vec_elements = vmulq_n_f16(vec_elements, beta);
- vec_elements = vexpq_f16(vec_elements);
-
- vst1q_f16(exp_ptr, vec_elements);
- vec_sum_value = vaddq_f16(vec_sum_value, vec_elements);
-
- in_ptr += step;
- exp_ptr += step;
+ /* Get max value */
+ const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
+ const auto vec_max = vdup_n<vec_8_byte_t<T>>(max_val);
+
+ /* Init sum to zero */
+ auto vec_sum = vdup_n<vec_16_byte_t<U>>(0);
+
+ /* Loop over row and compute exponentials and sum */
+ int i = 0;
+ constexpr int vec_size = vec_size_of(vec_sum);
+ for(; i <= (input_width - vec_size); i += vec_size)
+ {
+ auto vec_elements = vld<vec_8_byte_t<T>>(in_ptr + i);
+ vec_elements = vqsub(vec_elements, vec_max);
+ vec_elements = vqexp(vec_elements, fixed_point_position);
+ vec_sum = vqadd(vec_sum, vmovl(vec_elements));
+ vst(tmp_ptr + i, vec_elements);
+ }
+ /* Reduce sum */
+ const vec_8_byte_t<U> sum_8_byte = vqadd(vget_high(vec_sum), vget_low(vec_sum));
+ U sum = reduce_add(sqadd<U>, sum_8_byte);
+
+ /* Run remaining elements */
+ for(; i < input_width; ++i)
+ {
+ T element = sqexp(sqsub(in_ptr[i], max_val), fixed_point_position);
+ sum = sqadd<U>(sum, element);
+ tmp_ptr[i] = element;
+ }
+
+ const auto qsum = utility::saturate_cast<T>(sum);
+ vec_sum_inversed = vqrecip(vdup_n<vec_16_byte_t<T>>(qsum), fixed_point_position);
}
- // Reduce sum
- const float16x4_t sum_red = vadd_f16(vget_low_f16(vec_sum_value), vget_high_f16(vec_sum_value));
- const float16x4_t carry_addition = vpadd_f16(sum_red, sum_red);
- float16_t sum = vget_lane_f16(carry_addition, 0) + vget_lane_f16(carry_addition, 1);
- // Run remaining elements
- for(int i = 0; i < small_steps; ++i)
+ /* Normalize exponentials */
{
- const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr) * beta);
- exp_ptr[i] = element;
- sum += element;
+ /* Loop over row and compute softmax */
+ int i = 0;
+ constexpr int vec_size = vec_size_of(vec_sum_inversed);
+ for(; i <= (input_width - vec_size); i += vec_size)
+ {
+ const auto vec_in = vld<vec_16_byte_t<T>>(tmp_ptr + i);
+ const vec_16_byte_t<T> normalized_value = vqmul(vec_in, vec_sum_inversed, fixed_point_position);
+ vst(out_ptr + i, normalized_value);
+ }
+
+ const T sum_inversed = vget_lane<0>(vec_sum_inversed);
+
+ /* Run remaining elements */
+ for(; i < input_width; ++i)
+ {
+ out_ptr[i] = sqmul(tmp_ptr[i], sum_inversed, fixed_point_position);
+ }
}
- *(reinterpret_cast<float16_t *>(_sum.ptr())) = sum;
- }
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
+ },
+ in_it, max_it, out_it);
}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
+template <typename T>
+void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
+ ITensor &out, const float beta, const Window &window)
{
- Window window_max(window);
- window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
+ const int start_x = in.info()->valid_region().anchor.x();
+ const int input_width = in.info()->valid_region().shape.x();
- Window max_slice = window_max.first_slice_window_1D();
- Window in_slice = window.first_slice_window_1D();
+ Iterator in_it(&in, window);
+ Iterator max_it(&max, window);
+ Iterator out_it(&out, window);
- constexpr int step = 4;
- const int long_steps = in->info()->valid_region().shape.x() / step;
- const int small_steps = in->info()->valid_region().shape.x() % step;
-
- do
+ execute_window_loop(window, [&](const Coordinates &)
{
- Iterator input(in, in_slice);
- Iterator exp(out, in_slice);
- Iterator _max(max, max_slice);
- Iterator _sum(sum, max_slice);
-
- // Get pointers
- auto in_ptr = reinterpret_cast<const float *>(input.ptr());
- auto exp_ptr = reinterpret_cast<float *>(exp.ptr());
-
- // Init sum to zero
- float32x4_t vec_sum_value = vdupq_n_f32(0.0f);
+ /* Get pointers */
+ const auto in_ptr = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
+ const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
+ const auto tmp_ptr = reinterpret_cast<T *>(tmp);
- // Get max value
- const auto max_ptr = reinterpret_cast<const float *>(_max.ptr());
- const float32x4_t vec_max = vdupq_n_f32(*max_ptr);
+ T sum_inversed;
- // Run neon loop
- for(int i = 0; i < long_steps; ++i)
+ /* Compute exponentials and sum */
{
- float32x4_t vec_elements = vld1q_f32(in_ptr);
- vec_elements = vsubq_f32(vec_elements, vec_max);
- vec_elements = vmulq_n_f32(vec_elements, beta);
- vec_elements = vexpq_f32(vec_elements);
-
- vst1q_f32(exp_ptr, vec_elements);
- vec_sum_value = vaddq_f32(vec_elements, vec_sum_value);
-
- in_ptr += step;
- exp_ptr += step;
+ /* Get max value */
+ const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
+ const auto vec_max = vdup_n<vec_16_byte_t<T>>(max_val);
+
+ /* Init sum to zero */
+ auto vec_sum = vdup_n<vec_16_byte_t<T>>(0);
+
+ /* Loop over row and compute exponentials and sum */
+ int i = 0;
+ constexpr int vec_size = vec_size_of(vec_sum);
+ for(; i <= (input_width - vec_size); i += vec_size)
+ {
+ auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
+ vec_elements = vsub(vec_elements, vec_max);
+ vec_elements = vexp(vmul_n(vec_elements, beta));
+ vec_sum = vadd(vec_sum, vec_elements);
+ vst(tmp_ptr + i, vec_elements);
+ }
+ /* Reduce sum */
+ const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
+ T sum = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);
+
+ /* Run remaining elements */
+ for(; i < input_width; ++i)
+ {
+ T element = std::exp((in_ptr[i] - max_val) * beta);
+ sum += element;
+ tmp_ptr[i] = element;
+ }
+
+ sum_inversed = T(1) / sum;
}
- // Reduce sum
- float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value));
- carry_addition = vpadd_f32(carry_addition, carry_addition);
- float sum = vget_lane_f32(carry_addition, 0);
-
- // Run remaining elements
- for(int i = 0; i < small_steps; ++i)
+ /* Normalize exponentials */
{
- float element = std::exp((in_ptr[i] - *max_ptr) * beta);
- exp_ptr[i] = element;
- sum += element;
+ /* Loop over row and compute softmax */
+ int i = 0;
+ {
+ constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
+ for(; i <= (input_width - vec_size); i += vec_size)
+ {
+ auto vec_in = vld<vec_16_byte_t<T>>(tmp_ptr + i);
+ vec_16_byte_t<T> normalized_value = vmul_n(vec_in, sum_inversed);
+ vst(out_ptr + i, normalized_value);
+ }
+ }
+ /* Run remaining elements */
+ for(; i < input_width; ++i)
+ {
+ out_ptr[i] = tmp_ptr[i] * sum_inversed;
+ }
}
-
- *(reinterpret_cast<float *>(_sum.ptr())) = sum;
- }
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
+ },
+ in_it, max_it, out_it);
}
-} //namespace
+} // namespace
-NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
- : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr), _beta(1.0f)
+NELogits1DSoftmaxKernel::NELogits1DSoftmaxKernel()
+ : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}
-void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum, float beta)
+void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, sum, output);
-
- // Output auto initialization if not yet initialized
- auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
- auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
-
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
// Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_shift_exp_sum(input->info(), max->info(), output->info(), sum->info(), beta));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info()));
+ // Configure kernel window
+ auto win_config = validate_and_configure_window_logits_softmax(*input->info(), *max->info(), *output->info(), *tmp->info());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
switch(input->info()->data_type())
{
+ case DataType::QASYMM8:
+ _func = &logits_1d_softmax_qasymm8;
+ break;
case DataType::QS8:
- _func = &logits_1d_shift_exp_sum_qs8;
+ _func = &logits_1d_softmax_fixed_point<qint8_t, qint16_t>;
break;
case DataType::QS16:
- _func = &logits_1d_shift_exp_sum_qs16;
- break;
- case DataType::F32:
- _func = &logits_1d_shift_exp_sum_f32;
+ _func = &logits_1d_softmax_fixed_point<qint16_t, qint32_t>;
break;
- case DataType::F16:
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- _func = &logits_1d_shift_exp_sum_f16;
+ case DataType::F16:
+ _func = &logits_1d_softmax_float<float16_t>;
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ case DataType::F32:
+ _func = &logits_1d_softmax_float<float>;
+ break;
default:
ARM_COMPUTE_ERROR("Unsupported data type.");
break;
@@ -686,224 +930,37 @@ void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor
_input = input;
_max = max;
_output = output;
- _sum = sum;
_beta = beta;
+ _tmp = tmp;
- // Configure kernel window
- auto win_config = validate_and_configure_window_logits_1d_shift_exp_sum(input->info(), max->info(), output->info(), sum->info());
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
INEKernel::configure(win_config.second);
}
-Status NELogits1DShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta)
+Status NELogits1DSoftmaxKernel::validate(const ITensorInfo *input, const ITensorInfo *max,
+ const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_shift_exp_sum(input, max, output, sum, beta));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_shift_exp_sum(input->clone().get(), max->clone().get(), output->clone().get(), sum->clone().get()).first);
+ // validate() reports problems through the returned Status. Null arguments are therefore
+ // returned as an error (RETURN_ variant of the macro) before any pointer is dereferenced below.
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, max, output, tmp);
+
+ // First the argument checks, then a window configuration run on clones of the tensor infos;
+ // the first failing Status is returned, otherwise an empty Status.
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_softmax(*input->clone(), *max->clone(), *output->clone(), *tmp->clone()).first);
return Status{};
}
-void NELogits1DShiftExpSumKernel::run(const Window &window, const ThreadInfo &info)
+void NELogits1DSoftmaxKernel::run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- ARM_COMPUTE_ERROR_ON(_func == nullptr);
- (*_func)(_input, _max, _output, _sum, window, _beta);
-}
+ const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
+ const unsigned int tmp_size_for_thread = _tmp->info()->element_size() * num_elems_processed_per_iteration;
-namespace
-{
-void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
-{
- Window window_sum(window);
- window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
- Window sum_slice = window_sum.first_slice_window_1D();
- Window in_slice = window.first_slice_window_1D();
+ ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
- const int fixed_point_position = in->info()->fixed_point_position();
+ void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);
- do
- {
- Iterator input(in, in_slice);
- Iterator _sum(sum, sum_slice);
- Iterator output(out, in_slice);
-
- const int8_t sum_value = *reinterpret_cast<const qint8_t *>(_sum.ptr());
- const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position);
-
- execute_window_loop(in_slice, [&](const Coordinates & id)
- {
- const auto in_ptr = reinterpret_cast<const qint8_t *>(input.ptr());
- const auto out_ptr = reinterpret_cast<qint8_t *>(output.ptr());
-
- const qint8x16_t vec_in = vld1q_qs8(in_ptr);
- const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position);
-
- vst1q_qs8(out_ptr, normalized_value);
- },
- input, output);
- }
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
+ (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
}
-void logits_1d_norm_qs16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
-{
- Window window_sum(window);
- window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
- Window sum_slice = window_sum.first_slice_window_1D();
- Window in_slice = window.first_slice_window_1D();
- const int fixed_point_position = in->info()->fixed_point_position();
-
- do
- {
- Iterator input(in, in_slice);
- Iterator _sum(sum, sum_slice);
- Iterator output(out, in_slice);
-
- const int16_t sum_value = *reinterpret_cast<const qint16_t *>(_sum.ptr());
- const qint16x8_t vec_sum_inversed = vqrecipq_qs16(vdupq_n_qs16(sum_value), fixed_point_position);
-
- execute_window_loop(in_slice, [&](const Coordinates & id)
- {
- const auto in_ptr = reinterpret_cast<const qint16_t *>(input.ptr());
- const auto out_ptr = reinterpret_cast<qint16_t *>(output.ptr());
-
- const qint16x8_t vec_in = vld1q_qs16(in_ptr);
- const qint16x8_t normalized_value = vqmulq_qs16(vec_in, vec_sum_inversed, fixed_point_position);
-
- vst1q_qs16(out_ptr, normalized_value);
- },
- input, output);
- }
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
-}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-void logits_1d_norm_f16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
-{
- Window window_sum(window);
- window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
- Window sum_slice = window_sum.first_slice_window_1D();
- Window in_slice = window.first_slice_window_1D();
-
- do
- {
- Iterator input(in, in_slice);
- Iterator _sum(sum, sum_slice);
- Iterator output(out, in_slice);
-
- const float16_t sum_value = *reinterpret_cast<const qint16_t *>(_sum.ptr());
- const float16x8_t vec_sum_inversed = vdupq_n_f16(1.0f / sum_value);
-
- execute_window_loop(in_slice, [&](const Coordinates & id)
- {
- const auto in_ptr = reinterpret_cast<const float16_t *>(input.ptr());
- const auto out_ptr = reinterpret_cast<float16_t *>(output.ptr());
-
- const float16x8_t vec_in = vld1q_f16(in_ptr);
- const float16x8_t normalized_value = vmulq_f16(vec_in, vec_sum_inversed);
-
- vst1q_f16(out_ptr, normalized_value);
- },
- input, output);
- }
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
-{
- Window window_sum(window);
- window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
- Window sum_slice = window_sum.first_slice_window_1D();
- Window in_slice = window.first_slice_window_1D();
-
- do
- {
- Iterator input(in, in_slice);
- Iterator _sum(sum, sum_slice);
- Iterator output(out, in_slice);
-
- const float sum_value = *reinterpret_cast<const float *>(_sum.ptr());
- const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value);
-
- execute_window_loop(in_slice, [&](const Coordinates & id)
- {
- const auto in_ptr = reinterpret_cast<const float *>(input.ptr());
- const auto out_ptr = reinterpret_cast<float *>(output.ptr());
-
- const float32x4_t vec_in = vld1q_f32(in_ptr);
- const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed);
-
- vst1q_f32(out_ptr, normalized_value);
- },
- input, output);
- }
- while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
-}
-} // namespace
-
-NELogits1DNormKernel::NELogits1DNormKernel()
- : _func(nullptr), _input(nullptr), _sum(nullptr), _output(nullptr)
-{
-}
-
-void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
-{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input, sum, output);
-
- // Output auto initialization if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
-
- // Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_norm(input->info(), sum->info(), output->info()));
-
- _input = input;
- _sum = sum;
- _output = output;
-
- switch(input->info()->data_type())
- {
- case DataType::QS8:
- _func = &logits_1d_norm_qs8;
- break;
- case DataType::QS16:
- _func = &logits_1d_norm_qs16;
- break;
- case DataType::F32:
- _func = &logits_1d_norm_f32;
- break;
- case DataType::F16:
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- _func = &logits_1d_norm_f16;
- break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- ARM_COMPUTE_ERROR("Unsupported data type.");
- break;
- }
-
- // Configure kernel window
- auto win_config = validate_and_configure_window_logits_1d_norm(input->info(), sum->info(), output->info());
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- INEKernel::configure(win_config.second);
-}
-
-Status NELogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
-{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_norm(input, sum, output));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_norm(input->clone().get(), sum->clone().get(), output->clone().get()).first);
-
- return Status{};
-}
-
-void NELogits1DNormKernel::run(const Window &window, const ThreadInfo &info)
-{
- ARM_COMPUTE_UNUSED(info);
- ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
- ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- ARM_COMPUTE_ERROR_ON(_func == nullptr);
-
- (*_func)(_input, _sum, _output, window);
-}
+} // namespace arm_compute