Diffstat (limited to 'src/core/NEON/kernels/NESoftmaxLayerKernel.cpp')
-rw-r--r--  src/core/NEON/kernels/NESoftmaxLayerKernel.cpp  162
1 file changed, 132 insertions(+), 30 deletions(-)
diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index 4144a1877b..1003ebd2e3 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -333,6 +333,19 @@ float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b)
return res;
}
+float32x4x4_t vsub_n(float32x4x4_t a, float val)
+{
+ auto scalar_vector = vdup_n<float32x4x4_t>(val);
+ float32x4x4_t res = { {
+ vsubq_f32(a.val[0], scalar_vector.val[0]),
+ vsubq_f32(a.val[1], scalar_vector.val[1]),
+ vsubq_f32(a.val[2], scalar_vector.val[2]),
+ vsubq_f32(a.val[3], scalar_vector.val[3])
+ }
+ };
+ return res;
+}
+
namespace
{
Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
@@ -590,6 +603,7 @@ elem_type_t<V> reduce_add(F add_fn, V vec)
return reduce_add_impl < elem_type_t<V>, N, 0, N - 1 >::reduce(add_fn, vec);
}
+template <bool is_log>
void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
{
const int start_x = in.info()->valid_region().anchor.x();
@@ -608,7 +622,8 @@ void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *cons
const auto out_ptr = reinterpret_cast<qasymm8_t *>(out_it.ptr()) + start_x;
const auto tmp_ptr = reinterpret_cast<float *>(tmp);
- float sum_inversed;
+ float sum{};
+ float sum_inversed{};
/* Compute exponentials and sum */
{
@@ -622,33 +637,55 @@ void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *cons
/* Loop over row and compute exponentials and sum */
int i = 0;
constexpr int vec_size = vec_size_of(vec_max);
+
for(; i <= (input_width - vec_size); i += vec_size)
{
auto vec_elements = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
vec_elements = vsubq_u8(vec_max, vec_elements);
auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
- vec_elements_flt = vexp(vmul_n(vec_elements_flt, scale_beta));
-
- vec_sum = vadd(vec_sum, vec_elements_flt);
+ if(is_log)
+ {
+ vec_elements_flt = vmul_n(vec_elements_flt, scale_beta);
+ vec_sum = vadd(vec_sum, vexp(vec_elements_flt));
+ }
+ else
+ {
+ vec_elements_flt = vexp(vmul_n(vec_elements_flt, scale_beta));
+ vec_sum = vadd(vec_sum, vec_elements_flt);
+ }
vst4q_f32(tmp_ptr + i, vec_elements_flt);
}
+
/* Reduce sum */
const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
const auto sum_8_byte = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
- float sum = reduce_add(std::plus<float>(), sum_8_byte);
+ sum = reduce_add(std::plus<float>(), sum_8_byte);
/* Run remaining elements */
for(; i < input_width; ++i)
{
- const float element = std::exp((max_val - in_ptr[i]) * scale_beta);
- sum += element;
+ float element{};
+ if(is_log)
+ {
+ element = (max_val - in_ptr[i]) * scale_beta;
+ sum += std::exp(element);
+ }
+ else
+ {
+ element = std::exp((max_val - in_ptr[i]) * scale_beta);
+ sum += element;
+ }
+
tmp_ptr[i] = element;
}
- sum_inversed = 256.f / sum;
+ if(!is_log)
+ {
+ sum_inversed = 256.f / sum;
+ }
}
/* Normalize exponentials */
@@ -657,24 +694,40 @@ void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *cons
int i = 0;
{
constexpr int vec_size = 16;
+
for(; i <= (input_width - vec_size); i += vec_size)
{
- float32x4x4_t vec_in = vld4q_f32(tmp_ptr + i);
- auto normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
+ float32x4x4_t vec_in = vld4q_f32(tmp_ptr + i);
+ vec_16_byte_t<qasymm8_t> normalized_value{};
+ if(is_log)
+ {
+ normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vsub_n(vec_in, sum));
+ }
+ else
+ {
+ normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
+ }
vst(out_ptr + i, normalized_value);
}
}
/* Run remaining elements */
for(; i < input_width; ++i)
{
- out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
+ if(is_log)
+ {
+ out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] - sum);
+ }
+ else
+ {
+ out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
+ }
}
}
},
in_it, max_it, out_it);
}
-template <typename T>
+template <typename T, bool is_log = false>
void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
ITensor &out, const float beta, const Window &window)
{
@@ -692,7 +745,8 @@ void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const
const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
const auto tmp_ptr = reinterpret_cast<T *>(tmp);
- T sum_inversed;
+ T sum{};
+ T sum_inversed{};
/* Compute exponentials and sum */
{
@@ -706,46 +760,87 @@ void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const
/* Loop over row and compute exponentials and sum */
int i = 0;
constexpr int vec_size = vec_size_of(vec_sum);
+
for(; i <= (input_width - vec_size); i += vec_size)
{
auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
vec_elements = vsub(vec_elements, vec_max);
- vec_elements = vexp(vmul_n(vec_elements, static_cast<T>(beta)));
- vec_sum = vadd(vec_sum, vec_elements);
+ if(is_log)
+ {
+ vec_elements = vmul_n(vec_elements, static_cast<T>(beta));
+ vec_sum = vadd(vec_sum, vexp(vec_elements));
+ }
+ else
+ {
+ vec_elements = vexp(vmul_n(vec_elements, static_cast<T>(beta)));
+ vec_sum = vadd(vec_sum, vec_elements);
+ }
vst(tmp_ptr + i, vec_elements);
}
+
/* Reduce sum */
const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
- T sum = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);
+ sum = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);
/* Run remaining elements */
+
for(; i < input_width; ++i)
{
- T element = std::exp((in_ptr[i] - max_val) * beta);
- sum += element;
+ T element{};
+
+ if(is_log)
+ {
+ element = (in_ptr[i] - max_val) * beta;
+ sum += std::exp(element);
+ }
+ else
+ {
+ element = std::exp((in_ptr[i] - max_val) * beta);
+ sum += element;
+ }
tmp_ptr[i] = element;
}
- sum_inversed = T(1) / sum;
+ if(!is_log)
+ {
+ sum_inversed = T(1) / sum;
+ }
}
/* Normalize exponentials */
{
/* Loop over row and compute softmax */
int i = 0;
+
{
constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
+
for(; i <= (input_width - vec_size); i += vec_size)
{
- auto vec_in = vld<vec_16_byte_t<T>>(tmp_ptr + i);
- vec_16_byte_t<T> normalized_value = vmul_n(vec_in, sum_inversed);
+ auto vec_in = vld<vec_16_byte_t<T>>(tmp_ptr + i);
+ vec_16_byte_t<T> normalized_value{};
+ if(is_log)
+ {
+ normalized_value = vsub(vec_in, vdup_n<vec_16_byte_t<T>>(sum));
+ }
+ else
+ {
+ normalized_value = vmul_n(vec_in, sum_inversed);
+ }
vst(out_ptr + i, normalized_value);
}
}
/* Run remaining elements */
for(; i < input_width; ++i)
{
- out_ptr[i] = tmp_ptr[i] * sum_inversed;
+ if(is_log)
+ {
+ out_ptr[i] = tmp_ptr[i] - sum;
+ }
+ else
+ {
+ out_ptr[i] = tmp_ptr[i] * sum_inversed;
+ }
}
}
},
@@ -753,12 +848,14 @@ void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const
}
} // namespace
-NELogits1DSoftmaxKernel::NELogits1DSoftmaxKernel()
+template <bool IS_LOG>
+NELogits1DSoftmaxKernel<IS_LOG>::NELogits1DSoftmaxKernel()
: _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
{
}
-void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
+template <bool IS_LOG>
+void NELogits1DSoftmaxKernel<IS_LOG>::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
@@ -771,15 +868,15 @@ void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max
switch(input->info()->data_type())
{
case DataType::QASYMM8:
- _func = &logits_1d_softmax_qasymm8;
+ _func = &logits_1d_softmax_qasymm8<IS_LOG>;
break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- _func = &logits_1d_softmax_float<float16_t>;
+ _func = &logits_1d_softmax_float<float16_t, IS_LOG>;
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
- _func = &logits_1d_softmax_float<float>;
+ _func = &logits_1d_softmax_float<float, IS_LOG>;
break;
default:
ARM_COMPUTE_ERROR("Unsupported data type.");
@@ -795,8 +892,9 @@ void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max
INEKernel::configure(win_config.second);
}
-Status NELogits1DSoftmaxKernel::validate(const ITensorInfo *input, const ITensorInfo *max,
- const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
+template <bool IS_LOG>
+Status NELogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *input, const ITensorInfo *max,
+ const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
@@ -806,7 +904,8 @@ Status NELogits1DSoftmaxKernel::validate(const ITensorInfo *input, const ITensor
return Status{};
}
-void NELogits1DSoftmaxKernel::run(const Window &window, const ThreadInfo &info)
+template <bool IS_LOG>
+void NELogits1DSoftmaxKernel<IS_LOG>::run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
@@ -822,4 +921,7 @@ void NELogits1DSoftmaxKernel::run(const Window &window, const ThreadInfo &info)
(*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
}
+template class NELogits1DSoftmaxKernel<true>;
+template class NELogits1DSoftmaxKernel<false>;
+
} // namespace arm_compute
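
For orientation, below is a minimal scalar sketch of the two normalizations the new is_log template parameter switches between: standard softmax divides each exp((x - max) * beta) by the reduced sum, while log softmax keeps the shifted, scaled logits and subtracts the log of that sum (textbook formulation). This sketch is illustrative only and not part of the patch; softmax_1d_reference is a hypothetical helper, and the NEON kernel above vectorizes the equivalent loops with float32x4x4_t arithmetic and a temporary buffer.

// Illustrative scalar reference for the two modes selected by is_log
// (hypothetical helper, not part of the patch). Uses log(sum) per the
// textbook definition of log softmax.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

template <bool is_log>
std::vector<float> softmax_1d_reference(const std::vector<float> &x, float beta)
{
    const float max_val = *std::max_element(x.begin(), x.end());
    std::vector<float> tmp(x.size());
    float sum = 0.f;
    // First pass: store exp((x - max) * beta) for softmax, or the shifted,
    // scaled logits for log softmax, accumulating the sum of exponentials.
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        const float shifted = (x[i] - max_val) * beta;
        tmp[i] = is_log ? shifted : std::exp(shifted);
        sum += is_log ? std::exp(shifted) : tmp[i];
    }
    // Second pass: normalize by 1/sum, or subtract log(sum) in the log case.
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        tmp[i] = is_log ? tmp[i] - std::log(sum) : tmp[i] / sum;
    }
    return tmp;
}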