diff options
Diffstat (limited to 'src/runtime/NEON/functions/NESoftmaxLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NESoftmaxLayer.cpp | 27 |
1 files changed, 12 insertions, 15 deletions
diff --git a/src/runtime/NEON/functions/NESoftmaxLayer.cpp b/src/runtime/NEON/functions/NESoftmaxLayer.cpp index 5509edec87..5cd6a550af 100644 --- a/src/runtime/NEON/functions/NESoftmaxLayer.cpp +++ b/src/runtime/NEON/functions/NESoftmaxLayer.cpp @@ -32,8 +32,8 @@ namespace arm_compute { template <bool IS_LOG> NESoftmaxLayerGeneric<IS_LOG>::NESoftmaxLayerGeneric(std::shared_ptr<IMemoryManager> memory_manager) - : _memory_group(std::move(memory_manager)), _max_kernel(), _softmax_kernel(), _flat_or_reshape_kernel_ptr(nullptr), _fill_border_kernel(), _reshape_kernel(), _max(), _tmp(), _input_flattened(), - _output_flattened(), _needs_flattening(false) + : _memory_group(std::move(memory_manager)), _max_kernel(), _softmax_kernel(), _flat_or_reshape_ptr(nullptr), _fill_border_kernel(), _reshape(), _max(), _tmp(), _input_flattened(), _output_flattened(), + _needs_flattening(false) { } @@ -46,23 +46,20 @@ void NESoftmaxLayerGeneric<IS_LOG>::configure_reshape_input_kernel(const ITensor // Initialize the flat input _input_flattened.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_flatten)); - // If we need to flatten the input, we can use NEFlattenKernel or NEReshapeKernel - // If the number of reduced axes is 3 (max dimension), which means collapsing all axes except the batch axis, we use NEFlattenKernel. - // In all other cases we have to use NEReshapeKernel // Note that the "other cases" include both: // 1. first_n_reduce_axes < 3: Reduce the first 1 (no need to reduce) or 2 dimensions (inclusive) // 2. first_n_reduce_axes == 4: Reduce all 4 dimensions. This can only be handled by NEReshapeKernel instead of NEFlattenKernel. if(first_n_reduce_axes == 3) { - auto flatten_kernel_ptr = support::cpp14::make_unique<NEFlattenLayerKernel>(); + auto flatten_kernel_ptr = support::cpp14::make_unique<NEFlattenLayer>(); flatten_kernel_ptr->configure(input, &_input_flattened); - _flat_or_reshape_kernel_ptr = std::move(flatten_kernel_ptr); + _flat_or_reshape_ptr = std::move(flatten_kernel_ptr); } else { - auto reshape_kernel_ptr = support::cpp14::make_unique<NEReshapeLayerKernel>(); + auto reshape_kernel_ptr = support::cpp14::make_unique<NEReshapeLayer>(); reshape_kernel_ptr->configure(input, &_input_flattened); - _flat_or_reshape_kernel_ptr = std::move(reshape_kernel_ptr); + _flat_or_reshape_ptr = std::move(reshape_kernel_ptr); } // We need to init the output tensor here. Indeed, the reshape kernel expects @@ -127,7 +124,7 @@ void NESoftmaxLayerGeneric<IS_LOG>::configure(ITensor *input, ITensor *output, f _input_flattened.allocator()->allocate(); // Reshape the flat output into the requested (4D) output - _reshape_kernel.configure(&_output_flattened, output); + _reshape.configure(&_output_flattened, output); // Allocate the intermediate flat tensors _output_flattened.allocator()->allocate(); @@ -174,11 +171,11 @@ Status NESoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const I if(first_n_reduce_axes == 3) { - ARM_COMPUTE_RETURN_ON_ERROR(NEFlattenLayerKernel::validate(input, &tensor_info_flat)); + ARM_COMPUTE_RETURN_ON_ERROR(NEFlattenLayer::validate(input, &tensor_info_flat)); } else { - ARM_COMPUTE_RETURN_ON_ERROR(NEReshapeLayerKernel::validate(input, &tensor_info_flat)); + ARM_COMPUTE_RETURN_ON_ERROR(NEReshapeLayer::validate(input, &tensor_info_flat)); } } @@ -195,7 +192,7 @@ void NESoftmaxLayerGeneric<IS_LOG>::run() if(_needs_flattening) { - NEScheduler::get().schedule(_flat_or_reshape_kernel_ptr.get(), Window::DimY); + _flat_or_reshape_ptr->run(); } NEScheduler::get().schedule(&_fill_border_kernel, Window::DimY); @@ -204,11 +201,11 @@ void NESoftmaxLayerGeneric<IS_LOG>::run() if(_needs_flattening) { - NEScheduler::get().schedule(&_reshape_kernel, Window::DimY); + _reshape.run(); } } template class NESoftmaxLayerGeneric<false>; template class NESoftmaxLayerGeneric<true>; -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute |