Diffstat (limited to 'src/runtime/NEON/functions/NESoftmaxLayer.cpp')
-rw-r--r--  src/runtime/NEON/functions/NESoftmaxLayer.cpp | 23
1 file changed, 16 insertions(+), 7 deletions(-)
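
In short, this patch turns the concrete NESoftmaxLayer class into the class template NESoftmaxLayerGeneric<IS_LOG>, so a single implementation serves both regular softmax (IS_LOG = false) and log-softmax (IS_LOG = true); both instantiations are emitted explicitly at the end of the translation unit.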
diff --git a/src/runtime/NEON/functions/NESoftmaxLayer.cpp b/src/runtime/NEON/functions/NESoftmaxLayer.cpp
index 79a94961d8..f530a87d05 100644
--- a/src/runtime/NEON/functions/NESoftmaxLayer.cpp
+++ b/src/runtime/NEON/functions/NESoftmaxLayer.cpp
@@ -33,13 +33,15 @@
namespace arm_compute
{
-NESoftmaxLayer::NESoftmaxLayer(std::shared_ptr<IMemoryManager> memory_manager)
+template <bool IS_LOG>
+NESoftmaxLayerGeneric<IS_LOG>::NESoftmaxLayerGeneric(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(std::move(memory_manager)), _max_kernel(), _softmax_kernel(), _flat_or_reshape_kernel_ptr(nullptr), _fill_border_kernel(), _reshape_kernel(), _max(), _tmp(), _input_flattened(),
_output_flattened(), _needs_flattening(false)
{
}

-void NESoftmaxLayer::configure_reshape_input_kernel(const ITensor *input, const ITensor *output, size_t axis)
+template <bool IS_LOG>
+void NESoftmaxLayerGeneric<IS_LOG>::configure_reshape_input_kernel(const ITensor *input, const ITensor *output, size_t axis)
{
// Flatten the input
const TensorShape shape_flatten = misc::shape_calculator::compute_softmax_shape(input->info(), axis);
@@ -68,11 +70,12 @@ void NESoftmaxLayer::configure_reshape_input_kernel(const ITensor *input, const
auto_init_if_empty(*output->info(), *input->info()->clone());
}

-void NESoftmaxLayer::configure(ITensor *input, ITensor *output, float beta, size_t axis)
+template <bool IS_LOG>
+void NESoftmaxLayerGeneric<IS_LOG>::configure(ITensor *input, ITensor *output, float beta, size_t axis)
{
// Perform validation step
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_ERROR_THROW_ON(NESoftmaxLayer::validate(input->info(), output->info(), beta, axis));
+ ARM_COMPUTE_ERROR_THROW_ON(NESoftmaxLayerGeneric::validate(input->info(), output->info(), beta, axis));

// We don't need flattening only in the case the input is 2D and axis is 1
_needs_flattening = axis != 1;
@@ -138,7 +141,8 @@ void NESoftmaxLayer::configure(ITensor *input, ITensor *output, float beta, size
_tmp.allocator()->allocate();
}

-Status NESoftmaxLayer::validate(const ITensorInfo *input, const ITensorInfo *output, float beta, size_t axis)
+template <bool IS_LOG>
+Status NESoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const ITensorInfo *output, float beta, size_t axis)
{
// Perform validation step
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
@@ -173,12 +177,13 @@ Status NESoftmaxLayer::validate(const ITensorInfo *input, const ITensorInfo *out
}
ARM_COMPUTE_RETURN_ON_ERROR(NELogits1DMaxKernel::validate(input, &tensor_info_max_sum));
- ARM_COMPUTE_RETURN_ON_ERROR(NELogits1DSoftmaxKernel::validate(&tensor_info_tmp, &tensor_info_max_sum, output, beta, &dont_care));
+ ARM_COMPUTE_RETURN_ON_ERROR(NELogits1DSoftmaxKernel<IS_LOG>::validate(&tensor_info_tmp, &tensor_info_max_sum, output, beta, &dont_care));
return Status{};
}

-void NESoftmaxLayer::run()
+template <bool IS_LOG>
+void NESoftmaxLayerGeneric<IS_LOG>::run()
{
MemoryGroupResourceScope scope_mg(_memory_group);
@@ -196,4 +201,8 @@ void NESoftmaxLayer::run()
NEScheduler::get().schedule(&_reshape_kernel, Window::DimY);
}
}
+
+template class NESoftmaxLayerGeneric<false>;
+template class NESoftmaxLayerGeneric<true>;
+
} // namespace arm_compute
\ No newline at end of file
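
For readers unfamiliar with the calling convention, here is a minimal usage sketch of the templated function. It assumes the standard Compute Library init/configure/allocate/run flow and that the memory_manager constructor parameter is defaulted in the header (so the class is default-constructible); the tensor shapes and variable names are illustrative, not taken from this diff.

    // Usage sketch (not part of this patch); shapes and names are illustrative.
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/TensorShape.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NESoftmaxLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        // 2D input: 128 logits per row, 32 rows.
        Tensor src{};
        Tensor dst{};
        src.allocator()->init(TensorInfo(TensorShape(128U, 32U), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(128U, 32U), 1, DataType::F32));

        // IS_LOG = false selects the standard softmax. With a 2D input and
        // axis == 1, configure() above sets _needs_flattening to false and
        // skips the flatten/reshape kernels entirely.
        NESoftmaxLayerGeneric<false> softmax; // assumes a defaulted memory_manager
        softmax.configure(&src, &dst, /* beta */ 1.0f, /* axis */ 1);

        // Backing memory is allocated only after configure(), per the usual
        // Compute Library flow.
        src.allocator()->allocate();
        dst.allocator()->allocate();

        // ... fill src with logits ...

        softmax.run();
        return 0;
    }

NESoftmaxLayerGeneric<true> is the log-softmax flavour with the same interface; as the validate() change above shows, the flag only selects which NELogits1DSoftmaxKernel instantiation is dispatched to. The header presumably also provides friendlier aliases for the two instantiations (e.g. NESoftmaxLayer and NELogSoftmaxLayer), but only the explicit instantiations are visible in this file.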