aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLSoftmaxLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLSoftmaxLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLSoftmaxLayer.cpp7
1 files changed, 4 insertions, 3 deletions
diff --git a/src/runtime/CL/functions/CLSoftmaxLayer.cpp b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
index 720f9111a5..759c8706a1 100644
--- a/src/runtime/CL/functions/CLSoftmaxLayer.cpp
+++ b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
@@ -31,6 +31,7 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "src/core/helpers/SoftmaxHelpers.h"
namespace arm_compute
{
@@ -63,7 +64,7 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure(const CLCompileContext &compile_co
{
_memory_group.manage(&_input_permuted);
_memory_group.manage(&_output_permuted);
- _permute_input.configure(compile_context, input, &_input_permuted, get_permutation_vector_from_softmax_axis(actual_axis));
+ _permute_input.configure(compile_context, input, &_input_permuted, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
tmp_output = &_output_permuted;
}
@@ -99,7 +100,7 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure(const CLCompileContext &compile_co
_sum.allocator()->allocate();
if(_needs_permute)
{
- _permute_output.configure(compile_context, &_output_permuted, output, get_permutation_vector_from_softmax_axis(actual_axis));
+ _permute_output.configure(compile_context, &_output_permuted, output, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
_input_permuted.allocator()->allocate();
_output_permuted.allocator()->allocate();
}
@@ -117,7 +118,7 @@ Status CLSoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const I
const bool needs_permute = actual_axis != 0;
if(needs_permute)
{
- const PermutationVector permutation_vector = get_permutation_vector_from_softmax_axis(actual_axis);
+ const PermutationVector permutation_vector = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis);
const TensorShape permuted_shape = misc::shape_calculator::compute_permutation_output_shape(*input, permutation_vector);
TensorInfo input_permuted(input->clone()->set_tensor_shape(permuted_shape));
ARM_COMPUTE_RETURN_ON_ERROR(CLPermute::validate(input, &input_permuted, permutation_vector));