diff options
Diffstat (limited to 'src/runtime/CL/functions/CLSoftmaxLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLSoftmaxLayer.cpp | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/src/runtime/CL/functions/CLSoftmaxLayer.cpp b/src/runtime/CL/functions/CLSoftmaxLayer.cpp index 720f9111a5..759c8706a1 100644 --- a/src/runtime/CL/functions/CLSoftmaxLayer.cpp +++ b/src/runtime/CL/functions/CLSoftmaxLayer.cpp @@ -31,6 +31,7 @@ #include "arm_compute/core/Utils.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/CL/CLScheduler.h" +#include "src/core/helpers/SoftmaxHelpers.h" namespace arm_compute { @@ -63,7 +64,7 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure(const CLCompileContext &compile_co { _memory_group.manage(&_input_permuted); _memory_group.manage(&_output_permuted); - _permute_input.configure(compile_context, input, &_input_permuted, get_permutation_vector_from_softmax_axis(actual_axis)); + _permute_input.configure(compile_context, input, &_input_permuted, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis)); tmp_output = &_output_permuted; } @@ -99,7 +100,7 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure(const CLCompileContext &compile_co _sum.allocator()->allocate(); if(_needs_permute) { - _permute_output.configure(compile_context, &_output_permuted, output, get_permutation_vector_from_softmax_axis(actual_axis)); + _permute_output.configure(compile_context, &_output_permuted, output, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis)); _input_permuted.allocator()->allocate(); _output_permuted.allocator()->allocate(); } @@ -117,7 +118,7 @@ Status CLSoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const I const bool needs_permute = actual_axis != 0; if(needs_permute) { - const PermutationVector permutation_vector = get_permutation_vector_from_softmax_axis(actual_axis); + const PermutationVector permutation_vector = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis); const TensorShape permuted_shape = misc::shape_calculator::compute_permutation_output_shape(*input, permutation_vector); TensorInfo input_permuted(input->clone()->set_tensor_shape(permuted_shape)); ARM_COMPUTE_RETURN_ON_ERROR(CLPermute::validate(input, &input_permuted, permutation_vector)); |