aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
diff options
context:
space:
mode:
authorPablo Palmier <Pablo.Palmier@arm.com>2017-10-05 15:01:34 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:42:17 +0000
commita2b89ca5407532257a959ad1852f29187e1be4ac (patch)
treea202070aea45a81ec1ea8a86fa4047035eb2d567 /src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
parent5948634bb97e05934e9eea180ba41dcddf874416 (diff)
downloadComputeLibrary-a2b89ca5407532257a959ad1852f29187e1be4ac.tar.gz
IVGCVSW-631 Neon support for Softmax beta parameter (F32 only)
Change-Id: Ibf6f038b39f1a4e557f5d04feb08e3d5ef54e223 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/112019 Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NESoftmaxLayerKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NESoftmaxLayerKernel.cpp26
1 files changed, 17 insertions, 9 deletions
diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index f1027590e4..a8a0f59a41 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -251,8 +251,10 @@ void NELogits1DMaxKernel::run(const Window &window, const ThreadInfo &info)
namespace
{
-void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
+void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
{
+ ARM_COMPUTE_UNUSED(beta);
+
Window window_max(window);
window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
@@ -313,8 +315,10 @@ void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor
}
while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
}
-void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
+void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
{
+ ARM_COMPUTE_UNUSED(beta);
+
Window window_max(window);
window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
@@ -375,7 +379,7 @@ void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
+void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
{
Window window_max(window);
window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
@@ -410,6 +414,7 @@ void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor
{
float16x8_t vec_elements = vld1q_f16(in_ptr);
vec_elements = vsubq_f16(vec_elements, vec_max);
+ vec_elements = vmulq_n_f16(vec_elements, beta);
vec_elements = vexpq_f16(vec_elements);
vst1q_f16(exp_ptr, vec_elements);
@@ -426,7 +431,7 @@ void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor
// Run remaining elements
for(int i = 0; i < small_steps; ++i)
{
- const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr));
+ const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr) * beta);
exp_ptr[i] = element;
sum += element;
}
@@ -436,7 +441,7 @@ void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
+void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
{
Window window_max(window);
window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
@@ -471,6 +476,7 @@ void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor
{
float32x4_t vec_elements = vld1q_f32(in_ptr);
vec_elements = vsubq_f32(vec_elements, vec_max);
+ vec_elements = vmulq_n_f32(vec_elements, beta);
vec_elements = vexpq_f32(vec_elements);
vst1q_f32(exp_ptr, vec_elements);
@@ -488,7 +494,7 @@ void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor
// Run remaining elements
for(int i = 0; i < small_steps; ++i)
{
- float element = std::exp(in_ptr[i] - *max_ptr);
+ float element = std::exp((in_ptr[i] - *max_ptr) * beta);
exp_ptr[i] = element;
sum += element;
}
@@ -500,14 +506,15 @@ void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor
} //namespace
NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
- : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr)
+ : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr), _beta(1.0f)
{
}
-void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum)
+void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum, float beta)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output);
+ ARM_COMPUTE_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input->info()->data_type()));
// Output auto initialization if not yet initialized
auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
@@ -545,6 +552,7 @@ void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor
_max = max;
_output = output;
_sum = sum;
+ _beta = beta;
// Configure kernel window
Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
@@ -568,7 +576,7 @@ void NELogits1DShiftExpSumKernel::run(const Window &window, const ThreadInfo &in
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
ARM_COMPUTE_ERROR_ON(_func == nullptr);
- (*_func)(_input, _max, _output, _sum, window);
+ (*_func)(_input, _max, _output, _sum, window, _beta);
}
namespace