diff options
Diffstat (limited to 'src/runtime/CL/functions/CLSoftmaxLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLSoftmaxLayer.cpp | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/src/runtime/CL/functions/CLSoftmaxLayer.cpp b/src/runtime/CL/functions/CLSoftmaxLayer.cpp index e01d2c75ca..b0b2117cd9 100644 --- a/src/runtime/CL/functions/CLSoftmaxLayer.cpp +++ b/src/runtime/CL/functions/CLSoftmaxLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -44,6 +44,12 @@ CLSoftmaxLayerGeneric<IS_LOG>::CLSoftmaxLayerGeneric(std::shared_ptr<IMemoryMana template <bool IS_LOG> void CLSoftmaxLayerGeneric<IS_LOG>::configure_reshape_input_kernel(const ICLTensor *input, const ICLTensor *output, size_t axis) { + configure_reshape_input_kernel(CLKernelLibrary::get().get_compile_context(), input, output, axis); +} + +template <bool IS_LOG> +void CLSoftmaxLayerGeneric<IS_LOG>::configure_reshape_input_kernel(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *output, size_t axis) +{ // Flatten the input const TensorShape shape_flatten = misc::shape_calculator::compute_softmax_shape(input->info(), axis); @@ -56,13 +62,13 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure_reshape_input_kernel(const ICLTens if(axis != 3) { auto reshape_kernel_ptr = support::cpp14::make_unique<CLReshapeLayerKernel>(); - reshape_kernel_ptr->configure(input, &_input_flattened); + reshape_kernel_ptr->configure(compile_context, input, &_input_flattened); _flatten_kernel_ptr = std::move(reshape_kernel_ptr); } else { auto flatten_kernel_ptr = support::cpp14::make_unique<CLFlattenLayerKernel>(); - flatten_kernel_ptr->configure(input, &_input_flattened); + flatten_kernel_ptr->configure(compile_context, input, &_input_flattened); _flatten_kernel_ptr = std::move(flatten_kernel_ptr); } @@ -74,6 +80,12 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure_reshape_input_kernel(const ICLTens template <bool IS_LOG> void CLSoftmaxLayerGeneric<IS_LOG>::configure(const ICLTensor *input, ICLTensor *output, float beta, size_t axis) { + configure(CLKernelLibrary::get().get_compile_context(), input, output, beta, axis); +} + +template <bool IS_LOG> +void CLSoftmaxLayerGeneric<IS_LOG>::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output, float beta, size_t axis) +{ // Perform validation step ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); ARM_COMPUTE_ERROR_THROW_ON(CLSoftmaxLayerGeneric<IS_LOG>::validate(input->info(), output->info(), beta, axis)); @@ -123,7 +135,7 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure(const ICLTensor *input, ICLTensor softmax_info.input_data_type = input_2D->info()->data_type(); // Configure kernels - _max_shift_exp_sum_kernel.configure(input_2D, &_max, &_tmp, &_sum, softmax_info); + _max_shift_exp_sum_kernel.configure(compile_context, input_2D, &_max, &_tmp, &_sum, softmax_info); if(_needs_flattening) { @@ -131,10 +143,10 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure(const ICLTensor *input, ICLTensor _memory_group.manage(&_output_flattened); // The normalization kernel stores the result in a flat output tensor - _norm_kernel.configure(&_tmp, &_sum, &_output_flattened, softmax_info); + _norm_kernel.configure(compile_context, &_tmp, &_sum, &_output_flattened, softmax_info); // Reshape the flat output into a the requested (4D) output - _reshape_kernel.configure(&_output_flattened, output); + _reshape_kernel.configure(compile_context, &_output_flattened, output); // Allocate the intermediate flat tensors _input_flattened.allocator()->allocate(); @@ -143,7 +155,7 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure(const ICLTensor *input, ICLTensor else { // Softmax 2D case - _norm_kernel.configure(&_tmp, &_sum, output, softmax_info); + _norm_kernel.configure(compile_context, &_tmp, &_sum, output, softmax_info); } // Allocate intermediate buffers @@ -203,7 +215,7 @@ Status CLSoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const I } template <bool IS_LOG> -void CLSoftmaxLayerGeneric<IS_LOG>::run() +void CLSoftmaxLayerGeneric<IS_LOG>::run() { MemoryGroupResourceScope scope_mg(_memory_group); |