/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CPP/Validate.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NEMath.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/utils/misc/SaturateCast.h"

#include <algorithm>
#include <arm_neon.h>
#include <cfloat>
#include <cmath>
#include <functional>
#include <utility>

namespace arm_compute
{
/** Maps an element type @p T and a lane count @p N to the matching NEON vector type.
 *
 * Only the specialisations generated by DECLARE_NEON_VEC_TYPE below exist;
 * any other combination is an incomplete type.
 */
template <typename T, int N>
struct vec_n_type;

#define DECLARE_NEON_VEC_TYPE(T, N, V) \
    template <>                        \
    struct vec_n_type<T, N>            \
    {                                  \
        using type = V;                \
    };

DECLARE_NEON_VEC_TYPE(uint8_t, 16, uint8x16_t)
DECLARE_NEON_VEC_TYPE(uint8_t, 8, uint8x8_t)

DECLARE_NEON_VEC_TYPE(int8_t, 16, int8x16_t)
DECLARE_NEON_VEC_TYPE(int8_t, 8, int8x8_t)

DECLARE_NEON_VEC_TYPE(uint16_t, 8, uint16x8_t)
DECLARE_NEON_VEC_TYPE(uint16_t, 4, uint16x4_t)

DECLARE_NEON_VEC_TYPE(int16_t, 8, int16x8_t)
DECLARE_NEON_VEC_TYPE(int16_t, 4, int16x4_t)

DECLARE_NEON_VEC_TYPE(int32_t, 4, int32x4_t)
DECLARE_NEON_VEC_TYPE(int32_t, 2, int32x2_t)

DECLARE_NEON_VEC_TYPE(uint32_t, 4, uint32x4_t)
DECLARE_NEON_VEC_TYPE(uint32_t, 2, uint32x2_t)

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_VEC_TYPE(float16_t, 8, float16x8_t)
DECLARE_NEON_VEC_TYPE(float16_t, 4, float16x4_t)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

DECLARE_NEON_VEC_TYPE(float, 4, float32x4_t)
DECLARE_NEON_VEC_TYPE(float, 2, float32x2_t)

/** NEON vector type holding @p N lanes of @p T. */
template <typename T, int N>
using vec_n_t = typename vec_n_type<T, N>::type;

/** NEON vector type that is @p N bytes wide and holds elements of @p T. */
template <typename T, int N>
using vec_n_byte_t = vec_n_t<T, N / sizeof(T)>;

/** 16-byte NEON vector of @p T. */
template <typename T>
using vec_16_byte_t = vec_n_byte_t<T, 16>;

/** 8-byte NEON vector of @p T. */
template <typename T>
using vec_8_byte_t = vec_n_byte_t<T, 8>;

template <typename T>
using const_ptr_t = const T *;

template <typename T>
using ptr_t = T *;

/* Forward declarations of vget_lane. elem_type_t (below) deduces a vector's
 * element type from the return type of vget_lane<0>, which must therefore be
 * declared before the definitions generated by DECLARE_NEON_FUNCTIONS_FOR_TYPE. */
#define FORWARD_DECLARE_VGET_LANE_FOR_TYPE(TYPE) \
    template <int lane>                          \
    TYPE vget_lane(vec_8_byte_t<TYPE> vec);      \
    template <int lane>                          \
    TYPE vget_lane(vec_16_byte_t<TYPE> vec);

FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint8_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int8_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint16_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int16_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint32_t)
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int32_t)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float16_t)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float)

/* Declared only so that elem_type_t<float32x4x4_t> resolves to float. */
template <int lane>
float vget_lane(float32x4x4_t vec);

/** Element type of the vector type @p V (the return type of vget_lane<0>). */
template <typename V>
using elem_type_t = decltype(vget_lane<0>(std::declval<V>()));

/** Number of elements in @p vec, computed from sizes only (the value of @p vec is not read). */
template <typename V>
constexpr size_t vec_size_of(const V &vec)
{
    return sizeof(vec) / sizeof(elem_type_t<V>);
}

/** Broadcast @p val to every lane of a vector of type @p V. */
template <typename V>
V vdup_n(elem_type_t<V> val);

/** Load a full vector of type @p V from @p ptr. */
template <typename V>
V vld(const_ptr_t<elem_type_t<V>> ptr);

/* Generates, for element TYPE with NEON suffix TAG, thin type-generic wrappers
 * over the corresponding NEON intrinsics (dup, load, store, max, pairwise max,
 * low/high halves and lane extraction). */
#define DECLARE_NEON_FUNCTIONS_FOR_TYPE(TYPE, TAG)                                \
    template <>                                                                   \
    inline vec_8_byte_t<TYPE> vdup_n<vec_8_byte_t<TYPE>>(TYPE val)                \
    {                                                                             \
        return vdup_n_##TAG(val);                                                 \
    }                                                                             \
    template <>                                                                   \
    inline vec_16_byte_t<TYPE> vdup_n<vec_16_byte_t<TYPE>>(TYPE val)              \
    {                                                                             \
        return vdupq_n_##TAG(val);                                                \
    }                                                                             \
    template <>                                                                   \
    inline vec_8_byte_t<TYPE> vld<vec_8_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)      \
    {                                                                             \
        return vld1_##TAG(ptr);                                                   \
    }                                                                             \
    template <>                                                                   \
    inline vec_16_byte_t<TYPE> vld<vec_16_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)    \
    {                                                                             \
        return vld1q_##TAG(ptr);                                                  \
    }                                                                             \
    inline void vst(ptr_t<TYPE> ptr, vec_8_byte_t<TYPE> vec)                      \
    {                                                                             \
        vst1_##TAG(ptr, vec);                                                     \
    }                                                                             \
    inline void vst(ptr_t<TYPE> ptr, vec_16_byte_t<TYPE> vec)                     \
    {                                                                             \
        vst1q_##TAG(ptr, vec);                                                    \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vmax(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vmaxq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vpmax(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)   \
    {                                                                             \
        return vpmax_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vget_low(vec_16_byte_t<TYPE> vec)                   \
    {                                                                             \
        return vget_low_##TAG(vec);                                               \
    }                                                                             \
    inline vec_8_byte_t<TYPE> vget_high(vec_16_byte_t<TYPE> vec)                  \
    {                                                                             \
        return vget_high_##TAG(vec);                                              \
    }                                                                             \
    template <int lane>                                                           \
    inline TYPE vget_lane(vec_8_byte_t<TYPE> vec)                                 \
    {                                                                             \
        static_assert(lane >= 0, "lane is out of bounds");                        \
        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
        return vget_lane_##TAG(vec, lane);                                        \
    }                                                                             \
    template <int lane>                                                           \
    inline TYPE vget_lane(vec_16_byte_t<TYPE> vec)                                \
    {                                                                             \
        static_assert(lane >= 0, "lane is out of bounds");                        \
        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
        return vgetq_lane_##TAG(vec, lane);                                       \
    }

template <typename T>
T sqadd(T a, T b);
template <typename T>
T sqsub(T a, T b);
template <typename T>
T sqmul(T a, T b);

/* Generates floating-point-only arithmetic wrappers (add, sub, multiply by scalar)
 * for element TYPE with NEON suffix TAG. */
#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG)                               \
    inline vec_8_byte_t<TYPE> vadd(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)    \
    {                                                                             \
        return vadd_##TAG(a, b);                                                  \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vadd(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vaddq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vsub(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
    {                                                                             \
        return vsubq_##TAG(a, b);                                                 \
    }                                                                             \
    inline vec_16_byte_t<TYPE> vmul_n(vec_16_byte_t<TYPE> vec, TYPE val)          \
    {                                                                             \
        return vmulq_n_##TAG(vec, val);                                           \
    }

DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint8_t, u8)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int8_t, s8)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint16_t, u16)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int16_t, s16)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint32_t, u32)
DECLARE_NEON_FUNCTIONS_FOR_TYPE(int32_t, s32)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32)

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16)
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float, f32)

/** Convert vector @p vec of type @p VI into a vector of type @p VO. */
template <typename VO, typename VI>
VO vcvt(VI vec);

/** Widen 16 uint8 lanes into four float32x4 vectors; val[k] holds input lanes 4k..4k+3. */
template <>
float32x4x4_t vcvt<float32x4x4_t>(uint8x16_t vec)
{
    const auto    low  = vmovl_u8(vget_low(vec));
    const auto    high = vmovl_u8(vget_high(vec));
    float32x4x4_t res  = { {
            vcvtq_f32_u32(vmovl_u16(vget_low(low))),
            vcvtq_f32_u32(vmovl_u16(vget_high(low))),
            vcvtq_f32_u32(vmovl_u16(vget_low(high))),
            vcvtq_f32_u32(vmovl_u16(vget_high(high)))
        }
    };
    return res;
}

/** Narrow four float32x4 vectors back to 16 uint8 lanes.
 *
 * vcvtq_u32_f32 converts to unsigned integer and the two vqmovn steps
 * saturate, so values above 255 become 255.
 */
template <>
uint8x16_t vcvt<uint8x16_t>(float32x4x4_t vec)
{
    uint16x8x2_t resU16 = { {
            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[0])),
                         vqmovn_u32(vcvtq_u32_f32(vec.val[1]))),
            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[2])),
                         vqmovn_u32(vcvtq_u32_f32(vec.val[3])))
        }
    };

    uint8x16_t res = vcombine_u8(vqmovn_u16(resU16.val[0]), vqmovn_u16(resU16.val[1]));
    return res;
}

/** Element-wise exponential of each of the four float32x4 vectors. */
float32x4x4_t vexp(float32x4x4_t vec)
{
    float32x4x4_t res = { {
            vexpq_f32(vec.val[0]),
            vexpq_f32(vec.val[1]),
            vexpq_f32(vec.val[2]),
            vexpq_f32(vec.val[3])
        }
    };
    return res;
}

/** Element-wise exponential of a float32x4 vector. */
float32x4_t vexp(const float32x4_t &vec)
{
    return vexpq_f32(vec);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// TODO (COMPMID-1535) : Revisit FP16 approximations
/** Element-wise exponential of a float16x8 vector, evaluated in float32 and narrowed back. */
float16x8_t vexp(const float16x8_t &vec)
{
    float16x4x2_t res =
    {
        {
            vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_low_f16(vec)))),
            vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_high_f16(vec))))
        }
    };
    return vcombine_f16(res.val[0], res.val[1]);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Broadcast @p val into all 16 lanes of a float32x4x4_t. */
template <>
float32x4x4_t vdup_n<float32x4x4_t>(float val)
{
    float32x4x4_t res = { {
            vdupq_n_f32(val),
            vdupq_n_f32(val),
            vdupq_n_f32(val),
            vdupq_n_f32(val)
        }
    };
    return res;
}

/** Multiply every lane of @p vec by @p val. */
float32x4x4_t vmul_n(float32x4x4_t vec, float val)
{
    float32x4x4_t res = { {
            vmulq_n_f32(vec.val[0], val),
            vmulq_n_f32(vec.val[1], val),
            vmulq_n_f32(vec.val[2], val),
            vmulq_n_f32(vec.val[3], val)
        }
    };
    return res;
}

/** Lane-wise sum of @p a and @p b. */
float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b)
{
    float32x4x4_t res = { {
            vaddq_f32(a.val[0], b.val[0]),
            vaddq_f32(a.val[1], b.val[1]),
            vaddq_f32(a.val[2], b.val[2]),
            vaddq_f32(a.val[3], b.val[3])
        }
    };
    return res;
}

namespace
{
/** Number of input elements read per row by logits_1d_max.
 *
 * The row width rounded up to a whole number of 16-byte vectors. Shared by the
 * window configuration (padding request) and NELogits1DMaxKernel::configure
 * (border size) so both always agree.
 */
int logits_1d_max_num_elems_read(const ITensorInfo &input)
{
    const int input_width                       = input.valid_region().shape.x();
    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input.data_type());
    return ceil_to_multiple(input_width, num_elems_processed_per_iteration);
}

/** Check the arguments of the row-max kernel.
 *
 * @param[in] input  Source tensor info (QASYMM8, F16 or F32).
 * @param[in] output Destination tensor info; only checked when already initialised,
 *                   and then must match @p input's type and quantisation with dimension 0 equal to 1.
 *
 * @return A status.
 */
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    // Validate in case of configured output
    if(output.total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
    }

    return Status{};
}

/** Auto-initialise @p output and compute the execution window of the row-max kernel.
 *
 * Requests right padding on @p input so every row can be read in whole 16-byte vectors.
 *
 * @return The status ("Insufficient Padding!" if the window had to shrink) and the window.
 */
std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo &input, ITensorInfo &output)
{
    // Softmax across the x dimension
    const TensorShape output_shape = TensorShape(input.tensor_shape()).set(0, 1);
    // Output auto initialization if not yet initialized
    auto_init_if_empty(output, output_shape, 1, input.data_type(), input.quantization_info());

    // Configure kernel window
    const int num_elems_read_per_iteration = logits_1d_max_num_elems_read(input);

    const ValidRegion out_valid_region(ValidRegion(input.valid_region()).set(0, 0, 1));
    output.set_valid_region(out_valid_region);

    Window win = calculate_max_window(output);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), num_elems_read_per_iteration);
    AccessWindowHorizontal output_access(&output, 0, 1);

    const bool window_changed = update_window_and_padding(win, input_access, output_access);

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

/** Horizontal maximum of all lanes of a 16-byte vector.
 *
 * The first vpmax folds the high and low halves into N/2 lanes; each further
 * vpmax(carry, carry) halves the number of distinct partial maxima until lane 0
 * holds the overall maximum.
 */
template <typename V>
auto reduce_max(V vec) -> elem_type_t<V>
{
    constexpr int N         = vec_size_of(vec);
    auto          carry_max = vpmax(vget_high(vec), vget_low(vec));

    for(int k = N / 2; k > 1; k /= 2)
    {
        carry_max = vpmax(carry_max, carry_max);
    }

    return vget_lane<0>(carry_max);
}

/** Write the maximum of each input row (along x) into the single-element output row.
 *
 * The row is read in whole 16-byte vectors, so up to one vector minus one element is
 * read past the valid width (the padding requested in
 * validate_and_configure_window_logits_1d_max). NOTE(review): the caller is presumably
 * expected to fill that border (see NELogits1DMaxKernel::border_size) with a value that
 * cannot win the max — confirm.
 */
template <typename T>
void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
{
    const auto   start_x     = in.info()->valid_region().anchor.x();
    const size_t input_width = in.info()->valid_region().shape.x();

    Iterator input(&in, window);
    Iterator output(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        // Get pointers
        const auto in_ptr  = reinterpret_cast<const T *>(input.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(output.ptr());

        // Init max value
        auto vec_max = vdup_n<vec_16_byte_t<T>>(support::cpp11::lowest<T>());

        // Loop over input row
        for(const T *it = in_ptr; it < (in_ptr + input_width); it += vec_size_of(vec_max))
        {
            const auto current_value = vld<vec_16_byte_t<T>>(it);
            vec_max                  = vmax(vec_max, current_value);
        }

        const T max_val = reduce_max(vec_max);
        *out_ptr        = max_val;
    },
    input, output);
}
} // namespace

NELogits1DMaxKernel::NELogits1DMaxKernel()
    : _func(nullptr), _border_size()
{
}

BorderSize NELogits1DMaxKernel::border_size() const
{
    return _border_size;
}

void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_1d_max(*input->info(), *output->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_max<qasymm8_t>;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_max<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_max<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
    }

    _input  = input;
    _output = output;

    // The right border covers exactly the elements logits_1d_max reads past the valid width.
    const int input_width                  = input->info()->valid_region().shape.x();
    const int num_elems_read_per_iteration = logits_1d_max_num_elems_read(*input->info());

    _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);

    INEKernel::configure(win_config.second);
}

Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
    // Report null arguments through the returned Status, like every other check in validate().
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(*input->clone(), *output->clone()).first);

    return Status{};
}

void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    (*_func)(*_input, *_output, window);
}

namespace
{
/** Check the arguments of the normalisation kernel.
 *
 * For QASYMM8 input the output (when configured) must be quantised with scale 1/256
 * and offset 0, and tmp must be F32; otherwise tmp has the input's data type.
 * tmp must have the same shape as the input.
 *
 * @return A status.
 */
Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
                                         const ITensorInfo &output, const float beta, const ITensorInfo &tmp)
{
    ARM_COMPUTE_UNUSED(beta);
    // Check input
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);

    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Check max
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);

    // Check output if configured
    if(output.total_size() != 0)
    {
        const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
        ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
    }

    // Check tmp if configured
    if(tmp.total_size() != 0)
    {
        const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
        ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
        // We could potentially reduce tmp memory if we could predict or make an assumption
        // on the maximum number of threads that will run in parallel.
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
    }

    return Status{};
}

/** Auto-initialise @p output / @p tmp and compute the execution window of the normalisation kernel.
 *
 * The window iterates over @p max (one step per row); each step accesses a full row of
 * input, output and tmp and one element of max.
 *
 * @return The status ("Insufficient Padding!" if the window had to shrink) and the window.
 */
std::pair<Status, Window> validate_and_configure_window_logits_softmax(ITensorInfo &input, ITensorInfo &max,
                                                                       ITensorInfo &output, ITensorInfo &tmp)
{
    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());

    // Output auto initialization if not yet initialized
    const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
    auto_init_if_empty(output, TensorInfo(input).set_quantization_info(output_quantization).reset_padding());

    // Tmp auto initialization if not yet initialized
    const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
    auto_init_if_empty(tmp, TensorInfo(input).set_data_type(tmp_data_type).reset_padding());

    const int input_width = input.valid_region().shape.x();

    Window win = calculate_max_window(max);

    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), input_width);
    // The kernel reads one element of max per row; register that access on max itself
    // (previously this was built on &input, duplicating part of input_access).
    AccessWindowHorizontal max_access(&max, 0, 1);
    AccessWindowHorizontal output_access(&output, input.valid_region().anchor.x(), input_width);
    AccessWindowHorizontal tmp_access(&tmp, input.valid_region().anchor.x(), input_width);

    const bool window_changed = update_window_and_padding(win, input_access, max_access, output_access, tmp_access);

    output.set_valid_region(input.valid_region());

    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}

/** Compile-time recursive horizontal reduction of lanes [S, E] of a vector of @p N lanes of @p T.
 *
 * The lane range is split in half, each half is reduced recursively and the two partial
 * results are combined with @p add_fn.
 */
template <typename T, int N, int S, int E>
struct reduce_add_impl
{
    template <typename F>
    static T reduce(F add_fn, vec_n_t<T, N> vec)
    {
        constexpr int H            = (S + E + 1) / 2;
        const auto    reduced_high = reduce_add_impl<T, N, S, H - 1>::reduce(add_fn, vec);
        const auto    reduced_low  = reduce_add_impl<T, N, H, E>::reduce(add_fn, vec);
        return add_fn(reduced_high, reduced_low);
    }
};

/** Base case of reduce_add_impl: a single lane @p I. */
template <typename T, int N, int I>
struct reduce_add_impl<T, N, I, I>
{
    template <typename F>
    static T reduce(F /*add_fn*/, vec_n_t<T, N> vec)
    {
        return vget_lane<I>(vec);
    }
};

/** Reduce all lanes of @p vec into one value using @p add_fn. */
template <typename V, typename F>
elem_type_t<V> reduce_add(F add_fn, V vec)
{
    constexpr int N = vec_size_of(vec);
    return reduce_add_impl<elem_type_t<V>, N, 0, N - 1>::reduce(add_fn, vec);
}

/** Softmax normalisation of QASYMM8 rows.
 *
 * For every row: tmp[i] = exp((x[i] - max) * -beta * input_scale) (negated via scale_beta
 * on the (max - x) difference), then out[i] = saturate(tmp[i] * 256 / sum(tmp)).
 *
 * The vectorised part stores into tmp with vst4q_f32 and reloads with vld4q_f32 over the
 * same 16-element blocks, so the interleaved layout inside tmp is undone on reload; the
 * scalar tail uses plain indexing for the same indices in both passes.
 *
 * @param[in]  tmp Per-thread scratch buffer of at least input_width floats.
 */
void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp,
                               ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    const float scale_beta = -beta * in.info()->quantization_info().uniform().scale;

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const qasymm8_t *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<qasymm8_t *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<float *>(tmp);

        float sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const qasymm8_t *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_16_byte_t<qasymm8_t>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<float32x4x4_t>(0.f);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_max);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements     = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
                vec_elements          = vsubq_u8(vec_max, vec_elements);
                auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
                vec_elements_flt      = vexp(vmul_n(vec_elements_flt, scale_beta));
                vec_sum               = vadd(vec_sum, vec_elements_flt);
                vst4q_f32(tmp_ptr + i, vec_elements_flt);
            }

            /* Reduce sum */
            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
                                               vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
            const auto sum_8_byte = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
            float      sum        = reduce_add(std::plus<float>(), sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                const float element = std::exp((max_val - in_ptr[i]) * scale_beta);
                sum += element;
                tmp_ptr[i] = element;
            }

            // Output quantisation is fixed to scale 1/256, offset 0 (see validate_arguments_logits_softmax).
            sum_inversed = 256.f / sum;
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int i = 0;
            {
                constexpr int vec_size = 16;
                for(; i <= (input_width - vec_size); i += vec_size)
                {
                    float32x4x4_t vec_in           = vld4q_f32(tmp_ptr + i);
                    auto          normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
                    vst(out_ptr + i, normalized_value);
                }
            }
            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
            }
        }
    },
    in_it, max_it, out_it);
}

/** Softmax normalisation of floating-point rows (T is float or float16_t).
 *
 * For every row: tmp[i] = exp((x[i] - max) * beta), then out[i] = tmp[i] / sum(tmp).
 *
 * @param[in]  tmp Per-thread scratch buffer of at least input_width elements of T.
 */
template <typename T>
void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
                             ITensor &out, const float beta, const Window &window)
{
    const int start_x     = in.info()->valid_region().anchor.x();
    const int input_width = in.info()->valid_region().shape.x();

    Iterator in_it(&in, window);
    Iterator max_it(&max, window);
    Iterator out_it(&out, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<T *>(tmp);

        T sum_inversed;

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = vdup_n<vec_16_byte_t<T>>(max_val);

            /* Init sum to zero */
            auto vec_sum = vdup_n<vec_16_byte_t<T>>(0);

            /* Loop over row and compute exponentials and sum */
            int           i        = 0;
            constexpr int vec_size = vec_size_of(vec_sum);
            for(; i <= (input_width - vec_size); i += vec_size)
            {
                auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
                vec_elements      = vsub(vec_elements, vec_max);
                vec_elements      = vexp(vmul_n(vec_elements, static_cast<T>(beta)));
                vec_sum           = vadd(vec_sum, vec_elements);
                vst(tmp_ptr + i, vec_elements);
            }

            /* Reduce sum */
            const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
            T          sum        = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);

            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                T element = std::exp((in_ptr[i] - max_val) * beta);
                sum += element;
                tmp_ptr[i] = element;
            }

            sum_inversed = T(1) / sum;
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int i = 0;
            {
                constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
                for(; i <= (input_width - vec_size); i += vec_size)
                {
                    auto             vec_in           = vld<vec_16_byte_t<T>>(tmp_ptr + i);
                    vec_16_byte_t<T> normalized_value = vmul_n(vec_in, sum_inversed);
                    vst(out_ptr + i, normalized_value);
                }
            }
            /* Run remaining elements */
            for(; i < input_width; ++i)
            {
                out_ptr[i] = tmp_ptr[i] * sum_inversed;
            }
        }
    },
    in_it, max_it, out_it);
}
} // namespace

NELogits1DSoftmaxKernel::NELogits1DSoftmaxKernel()
    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}

void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info()));
    // Configure kernel window
    auto win_config = validate_and_configure_window_logits_softmax(*input->info(), *max->info(), *output->info(), *tmp->info());
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);

    switch(input->info()->data_type())
    {
        case DataType::QASYMM8:
            _func = &logits_1d_softmax_qasymm8;
            break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
            _func = &logits_1d_softmax_float<float16_t>;
            break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        case DataType::F32:
            _func = &logits_1d_softmax_float<float>;
            break;
        default:
            ARM_COMPUTE_ERROR("Unsupported data type.");
            break;
    }

    _input  = input;
    _max    = max;
    _output = output;
    _beta   = beta;
    _tmp    = tmp;

    INEKernel::configure(win_config.second);
}

Status NELogits1DSoftmaxKernel::validate(const ITensorInfo *input, const ITensorInfo *max,
                                         const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
{
    // Report null arguments through the returned Status, like every other check in validate().
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, max, output, tmp);

    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_softmax(*input->clone(), *max->clone(), *output->clone(), *tmp->clone()).first);

    return Status{};
}

void NELogits1DSoftmaxKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_func == nullptr);

    // Each thread works on its own row-sized slice of tmp, selected by thread_id.
    const unsigned int num_elems_processed_per_iteration = _input->info()->valid_region().shape.x();
    const unsigned int tmp_size_for_thread               = _tmp->info()->element_size() * num_elems_processed_per_iteration;

    ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));

    void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);

    (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
}

} // namespace arm_compute