diff options
Diffstat (limited to 'src/runtime/GLES_COMPUTE/functions/GCSoftmaxLayer.cpp')
-rw-r--r-- | src/runtime/GLES_COMPUTE/functions/GCSoftmaxLayer.cpp | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/src/runtime/GLES_COMPUTE/functions/GCSoftmaxLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCSoftmaxLayer.cpp index 5221c5cc5d..1748a5952b 100644 --- a/src/runtime/GLES_COMPUTE/functions/GCSoftmaxLayer.cpp +++ b/src/runtime/GLES_COMPUTE/functions/GCSoftmaxLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -29,8 +29,8 @@ using namespace arm_compute; -GCSoftmaxLayer::GCSoftmaxLayer() - : _max_kernel(), _shift_exp_sum_kernel(), _norm_kernel(), _max(), _sum(), _tmp() +GCSoftmaxLayer::GCSoftmaxLayer(std::shared_ptr<IMemoryManager> memory_manager) + : _memory_group(std::move(memory_manager)), _max_kernel(), _shift_exp_sum_kernel(), _norm_kernel(), _max(), _sum(), _tmp() { } @@ -50,6 +50,11 @@ void GCSoftmaxLayer::configure(const IGCTensor *input, IGCTensor *output, float _max.allocator()->init(tensor_info_max_sum); _sum.allocator()->init(tensor_info_max_sum); + // Manage intermediate buffers + _memory_group.manage(&_tmp); + _memory_group.manage(&_max); + _memory_group.manage(&_sum); + // Configure Kernels _max_kernel.configure(input, &_max); _shift_exp_sum_kernel.configure(input, &_max, &_tmp, &_sum); @@ -63,9 +68,13 @@ void GCSoftmaxLayer::configure(const IGCTensor *input, IGCTensor *output, float void GCSoftmaxLayer::run() { + _memory_group.acquire(); + GCScheduler::get().dispatch(_max_kernel, false); GCScheduler::get().memory_barrier(); GCScheduler::get().dispatch(_shift_exp_sum_kernel, false); GCScheduler::get().memory_barrier(); GCScheduler::get().dispatch(_norm_kernel); + + _memory_group.release(); } |