From 9247c92bd8c53be4d0c4ae931f51ca8f88e4150b Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 28 Jun 2017 18:29:47 +0100 Subject: COMPMID-428: Port NESoftmaxLayer to 16-bit fixed point. Change-Id: I65122950bab9124b9758c27096c0f458b77aeabb Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79365 Reviewed-by: Moritz Pflanzer Tested-by: Kaizen Reviewed-by: Steven Niu --- src/core/NEON/kernels/NESoftmaxLayerKernel.cpp | 305 +++++++++++++++++-------- 1 file changed, 216 insertions(+), 89 deletions(-) (limited to 'src/core/NEON/kernels/NESoftmaxLayerKernel.cpp') diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp index 854fd84845..fe62d7b575 100644 --- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp +++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp @@ -43,7 +43,7 @@ using namespace arm_compute; namespace { -void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window) +void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window) { Window in_slice = window.first_slice_window_1D(); @@ -56,25 +56,57 @@ void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window) Iterator input(in, in_slice); Iterator output(out, max_slice); - float32x4_t vec_max = vdupq_n_f32(-FLT_MAX); + qint8x16_t vec_max = vdupq_n_s8(std::numeric_limits::lowest()); execute_window_loop(in_slice, [&](const Coordinates & id) { - const auto in_ptr = reinterpret_cast(input.ptr()); - const float32x4_t current_value = vld1q_f32(in_ptr); - vec_max = vmaxq_f32(vec_max, current_value); + const auto in_ptr = reinterpret_cast(input.ptr()); + const qint8x16_t current_value = vld1q_qs8(in_ptr); + vec_max = vmaxq_qs8(vec_max, current_value); }, input); - float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max)); - carry_max = vpmax_f32(carry_max, carry_max); + qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max)); + carry_max = vpmax_qs8(carry_max, carry_max); + carry_max = vpmax_qs8(carry_max, carry_max); + carry_max = vpmax_qs8(carry_max, carry_max); - *(reinterpret_cast(output.ptr())) = vget_lane_f32(carry_max, 0); + *(reinterpret_cast(output.ptr())) = vget_lane_s8(carry_max, 0); } while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice)); } +void logits_1d_max_qs16(const ITensor *in, ITensor *out, const Window &window) +{ + Window in_slice = window.first_slice_window_1D(); -void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window) + Window window_max(window); + window_max.set(Window::DimX, Window::Dimension(0, 0, 0)); + Window max_slice = window_max.first_slice_window_1D(); + + do + { + Iterator input(in, in_slice); + Iterator output(out, max_slice); + + qint16x8_t vec_max = vdupq_n_qs16(std::numeric_limits::lowest()); + + execute_window_loop(in_slice, [&](const Coordinates & id) + { + const auto in_ptr = reinterpret_cast(input.ptr()); + const qint16x8_t current_value = vld1q_qs16(in_ptr); + vec_max = vmaxq_qs16(vec_max, current_value); + }, + input); + + qint16x4_t carry_max = vpmax_qs16(vget_high_qs16(vec_max), vget_low_qs16(vec_max)); + carry_max = vpmax_qs16(carry_max, carry_max); + carry_max = vpmax_qs16(carry_max, carry_max); + + *(reinterpret_cast(output.ptr())) = vget_lane_s16(carry_max, 0); + } + while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice)); +} +void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window) { Window in_slice = window.first_slice_window_1D(); @@ -87,22 +119,20 @@ void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window) Iterator input(in, in_slice); Iterator output(out, max_slice); - qint8x16_t vec_max = vdupq_n_s8(-1); + float32x4_t vec_max = vdupq_n_f32(-FLT_MAX); execute_window_loop(in_slice, [&](const Coordinates & id) { - const auto in_ptr = reinterpret_cast(input.ptr()); - const qint8x16_t current_value = vld1q_qs8(in_ptr); - vec_max = vmaxq_qs8(vec_max, current_value); + const auto in_ptr = reinterpret_cast(input.ptr()); + const float32x4_t current_value = vld1q_f32(in_ptr); + vec_max = vmaxq_f32(vec_max, current_value); }, input); - qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max)); - carry_max = vpmax_qs8(carry_max, carry_max); - carry_max = vpmax_qs8(carry_max, carry_max); - carry_max = vpmax_qs8(carry_max, carry_max); + float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max)); + carry_max = vpmax_f32(carry_max, carry_max); - *(reinterpret_cast(output.ptr())) = vget_lane_s8(carry_max, 0); + *(reinterpret_cast(output.ptr())) = vget_lane_f32(carry_max, 0); } while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice)); } @@ -120,7 +150,7 @@ BorderSize NELogits1DMaxKernel::border_size() const void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32); ARM_COMPUTE_ERROR_ON_NULLPTR(output); // Softmax across the x dimension @@ -135,17 +165,18 @@ void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output) ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape); const int input_width = input->info()->valid_region().shape.x(); - unsigned int num_elems_processed_per_iteration = 0; + unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type()); switch(input->info()->data_type()) { case DataType::QS8: - _func = &logits_1d_max_qs8; - num_elems_processed_per_iteration = 16; + _func = &logits_1d_max_qs8; + break; + case DataType::QS16: + _func = &logits_1d_max_qs16; break; case DataType::F32: - num_elems_processed_per_iteration = 4; - _func = &logits_1d_max_f32; + _func = &logits_1d_max_f32; break; default: ARM_COMPUTE_ERROR("Unsupported data type."); @@ -180,7 +211,7 @@ void NELogits1DMaxKernel::run(const Window &window) namespace { -void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window) +void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window) { Window window_max(window); window_max.set(Window::DimX, Window::Dimension(0, 0, 0)); @@ -188,9 +219,10 @@ void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor Window max_slice = window_max.first_slice_window_1D(); Window in_slice = window.first_slice_window_1D(); - constexpr int step = 4; - const int long_steps = in->info()->valid_region().shape.x() / step; - const int small_steps = in->info()->valid_region().shape.x() % step; + constexpr int step = 8; + const int long_steps = in->info()->valid_region().shape.x() / step; + const int small_steps = in->info()->valid_region().shape.x() % step; + const int fixed_point_position = in->info()->fixed_point_position(); do { @@ -200,48 +232,48 @@ void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor Iterator _sum(sum, max_slice); // Get pointers - auto in_ptr = reinterpret_cast(input.ptr()); - auto exp_ptr = reinterpret_cast(exp.ptr()); + auto in_ptr = reinterpret_cast(input.ptr()); + auto exp_ptr = reinterpret_cast(exp.ptr()); // Init sum to zero - float32x4_t vec_sum_value = vdupq_n_f32(0.0f); + qint16x8_t vec_sum_value = vdupq_n_qs16(0); // Get max value - const auto max_ptr = reinterpret_cast(_max.ptr()); - const float32x4_t vec_max = vdupq_n_f32(*max_ptr); + const auto max_ptr = reinterpret_cast(_max.ptr()); + const qint8x8_t vec_max = vdup_n_qs8(*max_ptr); // Run neon loop for(int i = 0; i < long_steps; ++i) { - float32x4_t vec_elements = vld1q_f32(in_ptr); - vec_elements = vsubq_f32(vec_elements, vec_max); - vec_elements = vexpq_f32(vec_elements); + qint8x8_t vec_elements = vld1_qs8(in_ptr); + vec_elements = vqsub_qs8(vec_elements, vec_max); + vec_elements = vqexp_qs8(vec_elements, fixed_point_position); - vst1q_f32(exp_ptr, vec_elements); - vec_sum_value = vaddq_f32(vec_elements, vec_sum_value); + vst1_qs8(exp_ptr, vec_elements); + vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements)); in_ptr += step; exp_ptr += step; } - // Reduce sum - float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value)); - carry_addition = vpadd_f32(carry_addition, carry_addition); - float sum = vget_lane_f32(carry_addition, 0); + const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value)); + const qint16_t sum0 = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1)); + const qint16_t sum1 = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3)); + qint16_t sum = sqadd_qs16(sum0, sum1); // Run remaining elements for(int i = 0; i < small_steps; ++i) { - float element = std::exp(in_ptr[i] - *max_ptr); - exp_ptr[i] = element; - sum += element; + qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position); + exp_ptr[i] = element; + sum = sqadd_qs16(sum, element); } - *(reinterpret_cast(_sum.ptr())) = sum; + *(reinterpret_cast(_sum.ptr())) = sqmovn_qs16(sum); } while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice)); } -void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window) +void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window) { Window window_max(window); window_max.set(Window::DimX, Window::Dimension(0, 0, 0)); @@ -249,7 +281,7 @@ void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor Window max_slice = window_max.first_slice_window_1D(); Window in_slice = window.first_slice_window_1D(); - constexpr int step = 8; + constexpr int step = 4; const int long_steps = in->info()->valid_region().shape.x() / step; const int small_steps = in->info()->valid_region().shape.x() % step; const int fixed_point_position = in->info()->fixed_point_position(); @@ -262,44 +294,103 @@ void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor Iterator _sum(sum, max_slice); // Get pointers - auto in_ptr = reinterpret_cast(input.ptr()); - auto exp_ptr = reinterpret_cast(exp.ptr()); + auto in_ptr = reinterpret_cast(input.ptr()); + auto exp_ptr = reinterpret_cast(exp.ptr()); // Init sum to zero - qint16x8_t vec_sum_value = vdupq_n_qs16(0); + qint32x4_t vec_sum_value = vdupq_n_qs32(0); // Get max value - const auto max_ptr = reinterpret_cast(_max.ptr()); - const qint8x8_t vec_max = vdup_n_qs8(*max_ptr); + const auto max_ptr = reinterpret_cast(_max.ptr()); + const qint16x4_t vec_max = vdup_n_qs16(*max_ptr); // Run neon loop for(int i = 0; i < long_steps; ++i) { - qint8x8_t vec_elements = vld1_qs8(in_ptr); - vec_elements = vqsub_qs8(vec_elements, vec_max); - vec_elements = vqexp_qs8(vec_elements, fixed_point_position); + qint16x4_t vec_elements = vld1_qs16(in_ptr); + vec_elements = vqsub_qs16(vec_elements, vec_max); + vec_elements = vqexp_qs16(vec_elements, fixed_point_position); - vst1_qs8(exp_ptr, vec_elements); - vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements)); + vst1_qs16(exp_ptr, vec_elements); + vec_sum_value = vqaddq_qs32(vec_sum_value, vmovl_s16(vec_elements)); in_ptr += step; exp_ptr += step; } // Reduce sum - const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value)); - const qint16_t sum0 = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1)); - const qint16_t sum1 = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3)); - qint16_t sum = sqadd_qs16(sum0, sum1); + qint32x2_t carry_addition = vqadd_qs32(vget_high_s32(vec_sum_value), vget_low_s32(vec_sum_value)); + qint32_t sum = vget_lane_s32(carry_addition, 0) + vget_lane_s32(carry_addition, 1); // Run remaining elements for(int i = 0; i < small_steps; ++i) { - qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position); - exp_ptr[i] = element; - sum = sqadd_qs16(sum, element); + qint16_t element = sqexp_qs16(sqsub_qs16(in_ptr[i], *max_ptr), fixed_point_position); + exp_ptr[i] = element; + sum = sqadd_qs32(sum, element); } - *(reinterpret_cast(_sum.ptr())) = sqmovn_qs16(sum); + *(reinterpret_cast(_sum.ptr())) = sqmovn_qs32(sum); + } + while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice)); +} +void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window) +{ + Window window_max(window); + window_max.set(Window::DimX, Window::Dimension(0, 0, 0)); + + Window max_slice = window_max.first_slice_window_1D(); + Window in_slice = window.first_slice_window_1D(); + + constexpr int step = 4; + const int long_steps = in->info()->valid_region().shape.x() / step; + const int small_steps = in->info()->valid_region().shape.x() % step; + + do + { + Iterator input(in, in_slice); + Iterator exp(out, in_slice); + Iterator _max(max, max_slice); + Iterator _sum(sum, max_slice); + + // Get pointers + auto in_ptr = reinterpret_cast(input.ptr()); + auto exp_ptr = reinterpret_cast(exp.ptr()); + + // Init sum to zero + float32x4_t vec_sum_value = vdupq_n_f32(0.0f); + + // Get max value + const auto max_ptr = reinterpret_cast(_max.ptr()); + const float32x4_t vec_max = vdupq_n_f32(*max_ptr); + + // Run neon loop + for(int i = 0; i < long_steps; ++i) + { + float32x4_t vec_elements = vld1q_f32(in_ptr); + vec_elements = vsubq_f32(vec_elements, vec_max); + vec_elements = vexpq_f32(vec_elements); + + vst1q_f32(exp_ptr, vec_elements); + vec_sum_value = vaddq_f32(vec_elements, vec_sum_value); + + in_ptr += step; + exp_ptr += step; + } + + // Reduce sum + float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value)); + carry_addition = vpadd_f32(carry_addition, carry_addition); + float sum = vget_lane_f32(carry_addition, 0); + + // Run remaining elements + for(int i = 0; i < small_steps; ++i) + { + float element = std::exp(in_ptr[i] - *max_ptr); + exp_ptr[i] = element; + sum += element; + } + + *(reinterpret_cast(_sum.ptr())) = sum; } while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice)); } @@ -312,7 +403,7 @@ NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel() void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32); ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output); // Output auto initialization if not yet initialized @@ -331,6 +422,9 @@ void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor case DataType::QS8: _func = &logits_1d_shift_exp_sum_qs8; break; + case DataType::QS16: + _func = &logits_1d_shift_exp_sum_qs16; + break; case DataType::F32: _func = &logits_1d_shift_exp_sum_f32; break; @@ -369,37 +463,39 @@ void NELogits1DShiftExpSumKernel::run(const Window &window) namespace { -void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window) +void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window) { Window window_sum(window); window_sum.set(Window::DimX, Window::Dimension(0, 0, 0)); Window sum_slice = window_sum.first_slice_window_1D(); Window in_slice = window.first_slice_window_1D(); + const int fixed_point_position = in->info()->fixed_point_position(); + do { Iterator input(in, in_slice); Iterator _sum(sum, sum_slice); Iterator output(out, in_slice); - const float sum_value = *reinterpret_cast(_sum.ptr()); - const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value); + const int8_t sum_value = *reinterpret_cast(_sum.ptr()); + const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position); execute_window_loop(in_slice, [&](const Coordinates & id) { - const auto in_ptr = reinterpret_cast(input.ptr()); - const auto out_ptr = reinterpret_cast(output.ptr()); + const auto in_ptr = reinterpret_cast(input.ptr()); + const auto out_ptr = reinterpret_cast(output.ptr()); - const float32x4_t vec_in = vld1q_f32(in_ptr); - const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed); + const qint8x16_t vec_in = vld1q_qs8(in_ptr); + const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position); - vst1q_f32(out_ptr, normalized_value); + vst1q_qs8(out_ptr, normalized_value); }, input, output); } while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice)); } -void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window) +void logits_1d_norm_qs16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window) { Window window_sum(window); window_sum.set(Window::DimX, Window::Dimension(0, 0, 0)); @@ -414,18 +510,48 @@ void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, con Iterator _sum(sum, sum_slice); Iterator output(out, in_slice); - const int8_t sum_value = *reinterpret_cast(_sum.ptr()); - const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position); + const int16_t sum_value = *reinterpret_cast(_sum.ptr()); + const qint16x8_t vec_sum_inversed = vqrecipq_qs16(vdupq_n_qs16(sum_value), fixed_point_position); execute_window_loop(in_slice, [&](const Coordinates & id) { - const auto in_ptr = reinterpret_cast(input.ptr()); - const auto out_ptr = reinterpret_cast(output.ptr()); + const auto in_ptr = reinterpret_cast(input.ptr()); + const auto out_ptr = reinterpret_cast(output.ptr()); - const qint8x16_t vec_in = vld1q_qs8(in_ptr); - const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position); + const qint16x8_t vec_in = vld1q_qs16(in_ptr); + const qint16x8_t normalized_value = vqmulq_qs16(vec_in, vec_sum_inversed, fixed_point_position); - vst1q_qs8(out_ptr, normalized_value); + vst1q_qs16(out_ptr, normalized_value); + }, + input, output); + } + while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice)); +} +void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window) +{ + Window window_sum(window); + window_sum.set(Window::DimX, Window::Dimension(0, 0, 0)); + Window sum_slice = window_sum.first_slice_window_1D(); + Window in_slice = window.first_slice_window_1D(); + + do + { + Iterator input(in, in_slice); + Iterator _sum(sum, sum_slice); + Iterator output(out, in_slice); + + const float sum_value = *reinterpret_cast(_sum.ptr()); + const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value); + + execute_window_loop(in_slice, [&](const Coordinates & id) + { + const auto in_ptr = reinterpret_cast(input.ptr()); + const auto out_ptr = reinterpret_cast(output.ptr()); + + const float32x4_t vec_in = vld1q_f32(in_ptr); + const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed); + + vst1q_f32(out_ptr, normalized_value); }, input, output); } @@ -440,7 +566,7 @@ NELogits1DNormKernel::NELogits1DNormKernel() void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::QS8); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32); ARM_COMPUTE_ERROR_ON_NULLPTR(sum, output); // Output auto initialization if not yet initialized @@ -455,17 +581,18 @@ void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, I _output = output; // Configure kernel window - unsigned int num_elems_processed_per_iteration = 0; + unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type()); switch(input->info()->data_type()) { case DataType::QS8: - _func = &logits_1d_norm_qs8; - num_elems_processed_per_iteration = 16; + _func = &logits_1d_norm_qs8; + break; + case DataType::QS16: + _func = &logits_1d_norm_qs16; break; case DataType::F32: - num_elems_processed_per_iteration = 4; - _func = &logits_1d_norm_f32; + _func = &logits_1d_norm_f32; break; default: ARM_COMPUTE_ERROR("Unsupported data type."); -- cgit v1.2.1