Diffstat (limited to 'src/cpu/kernels/CpuSoftmaxKernel.cpp')
 src/cpu/kernels/CpuSoftmaxKernel.cpp | 62
 1 file changed, 44 insertions(+), 18 deletions(-)
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp
index 68bc397acf..54ff858eeb 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -81,7 +81,7 @@ static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_ker
};
Status validate_arguments_softmax(
- const ITensorInfo &src, const ITensorInfo &dst, float beta, const ITensorInfo &tmp, bool is_log)
+ const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log)
{
ARM_COMPUTE_UNUSED(beta);
// Check input
@@ -89,6 +89,8 @@ Status validate_arguments_softmax(
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(axis < 0 || axis > 3);
+
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
// Check output if configured
@@ -124,10 +126,13 @@ const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> &CpuSoftmaxKernel::g
return available_kernels;
}
-void CpuSoftmaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, ITensorInfo *tmp)
+void CpuSoftmaxKernel::configure(
+ const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp)
{
+ _axis = axis;
+
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
// Configure kernel window
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
@@ -154,25 +159,40 @@ void CpuSoftmaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float
_run_method = uk->ukernel;
_name = kernel_name.append("/").append(uk->name);
- Window win = calculate_max_window(*dst, Steps());
+ Window win;
+
+ int vec_size = 16 / dst->element_size();
- /// TODO: Check dimensions > 0 for holes only. For this, we need
- /// a utility function checking if there are holes after some dimension.
- if (!has_holes(*dst, dst->num_dimensions() - 1))
+ if (_axis == 0)
+ {
+ win = calculate_max_window(*dst, Steps());
+
+ /// TODO: Check dimensions > 0 for holes only. For this, we need
+ /// a utility function checking if there are holes after some dimension.
+ if (!has_holes(*dst, dst->num_dimensions() - 1))
+ {
+ win = win.collapse(win, Window::DimY);
+ }
+ }
+ else if (_axis > 0 && _axis <= 3)
{
- win = win.collapse(win, Window::DimY);
+ win = calculate_max_window(*dst, Steps(vec_size));
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Invalid axis");
}
- win.set(Window::DimX, Window::Dimension(0, 1, 1)); // First dimension is the reduction axis
+ win.set(_axis, Window::Dimension(0, 1, 1));
ICpuKernel<CpuSoftmaxKernel>::configure(win);
}
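
The X-step used on the non-zero-axis path comes straight from the element size of the destination, so it varies per data type. A minimal standalone sketch of that arithmetic (illustrative only, not part of the patch):

    #include <cstddef>
    #include <cstdio>

    // 16-byte vectors hold 16 / element_size lanes, matching the
    // "vec_size = 16 / dst->element_size()" line added above.
    int main()
    {
        const struct { const char *name; std::size_t element_size; } types[] = {
            {"QASYMM8", 1}, {"QASYMM8_SIGNED", 1}, {"F16", 2}, {"F32", 4}};
        for (const auto &t : types)
        {
            std::printf("%-15s -> window step along X = %zu elements\n", t.name, 16 / t.element_size);
        }
        return 0;
    }

So an F32 destination is stepped 4 elements at a time, while the quantized types advance a full 16-element vector per iteration.
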
Status CpuSoftmaxKernel::validate(
- const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp)
+ const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
return Status{};
}
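
With both entry points now taking the reduction axis explicitly, callers pass it alongside beta and is_log; note that the two signatures order the new parameter differently (validate: beta, axis, is_log; configure: beta, is_log, axis). A hedged usage sketch, assuming a 4D F32 tensor and the usual arm_compute types; the shape, beta value and the choice of axis = 1 are illustrative, not taken from the patch:

    #include "arm_compute/core/Error.h"
    #include "arm_compute/core/TensorInfo.h"
    #include "src/cpu/kernels/CpuSoftmaxKernel.h"

    using namespace arm_compute;

    void softmax_kernel_example()
    {
        // Illustrative 4D shape; dimension 0 is the innermost (X) dimension.
        TensorInfo src(TensorShape(16U, 8U, 4U, 2U), 1, DataType::F32);
        TensorInfo dst(TensorShape(16U, 8U, 4U, 2U), 1, DataType::F32);
        TensorInfo tmp; // scratch is only needed for quantized inputs; left empty here

        const float beta   = 1.0f;
        const bool  is_log = false;
        const int   axis   = 1; // reduce along dimension 1 instead of the default X

        // validate(): src, dst, beta, axis, is_log, tmp
        const Status s = cpu::kernels::CpuSoftmaxKernel::validate(&src, &dst, beta, axis, is_log, &tmp);
        if (s.error_code() != ErrorCode::OK)
        {
            return;
        }

        // configure(): src, dst, beta, is_log, axis, tmp
        cpu::kernels::CpuSoftmaxKernel kernel;
        kernel.configure(&src, &dst, beta, is_log, axis, &tmp);
    }
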
@@ -188,19 +208,25 @@ void CpuSoftmaxKernel::run_op(ITensorPack &tensors, const Window &window, const
if (is_data_type_quantized_asymmetric(src->info()->data_type()))
{
- auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
-
- const unsigned int num_elems_processed_per_iteration = src->info()->valid_region().shape.x();
+ auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
+ unsigned int num_elems_processed_per_iteration;
+ if (_axis == 0)
+ {
+ num_elems_processed_per_iteration = src->info()->valid_region().shape[_axis];
+ }
+ else
+ {
+ // 16 QASYMM8/QASYMM8_SIGNED elements can fit into the 16-byte vectors.
+ num_elems_processed_per_iteration = 16;
+ }
const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
- ARM_COMPUTE_ERROR_ON(tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
-
void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
- _run_method(src, tmp_for_thread, dst, _beta, window);
+ _run_method(src, tmp_for_thread, dst, _beta, _axis, window);
}
else
{
- _run_method(src, nullptr, dst, _beta, window);
+ _run_method(src, nullptr, dst, _beta, _axis, window);
}
}
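
In the quantized branch the per-thread scratch size now depends on the axis: for axis 0 it is still one full row of the source, while for any other axis it falls back to a fixed 16 elements (one 16-byte vector). A small arithmetic sketch of the resulting sizes; the thread count, row width, and the assumption that tmp holds 4-byte dequantized elements are illustrative only:

    #include <cstdio>

    int main()
    {
        const unsigned int tmp_element_size = 4;  // assumed: tmp stores dequantized float values
        const unsigned int num_threads      = 4;  // example scheduler thread count

        // axis == 0: one source row per iteration, e.g. 120 elements wide.
        const unsigned int row_elems   = 120;
        const unsigned int bytes_axis0 = tmp_element_size * row_elems; // 480 bytes per thread
        std::printf("axis == 0: tmp must cover %u bytes\n", num_threads * bytes_axis0);

        // axis != 0: fixed 16 elements per iteration.
        const unsigned int bytes_axisN = tmp_element_size * 16;        // 64 bytes per thread
        std::printf("axis != 0: tmp must cover %u bytes\n", num_threads * bytes_axisN);

        return 0;
    }

Each thread offsets into tmp by thread_id * tmp_size_for_thread, so the buffer still has to be sized for all threads even though the explicit total_size assertion was dropped in this change.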