aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2020-07-02 12:43:53 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-07-03 17:15:00 +0000
commit2aad21a900a21f467b3ec6b37420f892f0d80221 (patch)
tree7973bbf13d2bc7ea88ab0bf9d7c51e6b2d3e6907 /src
parentd13931d05b0d5ccea4265c342c6a3bf40a3b85cc (diff)
downloadComputeLibrary-2aad21a900a21f467b3ec6b37420f892f0d80221.tar.gz
COMPMID-3388: Async support to CLReshapeLayerKernel kernels/functions
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com> Change-Id: I141a943dfd691069317860e852ecdd0ba7391604 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3501 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/core/CL/kernels/CLReshapeLayerKernel.cpp46
-rw-r--r--src/runtime/CL/CLOperator.cpp53
-rw-r--r--src/runtime/CL/CLScheduler.cpp23
-rw-r--r--src/runtime/CL/functions/CLArgMinMaxLayer.cpp8
-rw-r--r--src/runtime/CL/functions/CLGenerateProposalsLayer.cpp20
-rw-r--r--src/runtime/CL/functions/CLReductionOperation.cpp8
-rw-r--r--src/runtime/CL/functions/CLReshapeLayer.cpp59
-rw-r--r--src/runtime/CL/functions/CLSoftmaxLayer.cpp24
8 files changed, 175 insertions, 66 deletions
diff --git a/src/core/CL/kernels/CLReshapeLayerKernel.cpp b/src/core/CL/kernels/CLReshapeLayerKernel.cpp
index ce792489c5..97fde8645e 100644
--- a/src/core/CL/kernels/CLReshapeLayerKernel.cpp
+++ b/src/core/CL/kernels/CLReshapeLayerKernel.cpp
@@ -38,8 +38,8 @@
#include <string>
/** [CLReshapeLayerKernel Kernel] **/
-using namespace arm_compute;
-
+namespace arm_compute
+{
namespace
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
@@ -54,44 +54,30 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
return Status{};
}
-
} // namespace
-CLReshapeLayerKernel::CLReshapeLayerKernel()
- : _input(nullptr), _output(nullptr)
-{
-}
-
-void CLReshapeLayerKernel::configure(const ICLTensor *input, ICLTensor *output)
-{
- configure(CLKernelLibrary::get().get_compile_context(), input, output);
-}
-
-void CLReshapeLayerKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output)
+void CLReshapeLayerKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *input, ITensorInfo *output)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info()));
-
- _input = input;
- _output = output;
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input, output));
// Create kernel
- std::set<std::string> build_opts = { "-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(input->info()->element_size()) };
+ std::set<std::string> build_opts = { "-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(input->element_size()) };
_kernel = create_kernel(compile_context, "reshape_layer", build_opts);
// Add static arguments
const cl_int2 input_shape =
{
{
- static_cast<cl_int>(_input->info()->tensor_shape()[0]),
- static_cast<cl_int>(_input->info()->tensor_shape()[1])
+ static_cast<cl_int>(input->tensor_shape()[0]),
+ static_cast<cl_int>(input->tensor_shape()[1])
}
};
const cl_int2 output_shape =
{
{
- static_cast<cl_int>(_output->info()->tensor_shape()[0]),
- static_cast<cl_int>(_output->info()->tensor_shape()[1])
+ static_cast<cl_int>(output->tensor_shape()[0]),
+ static_cast<cl_int>(output->tensor_shape()[1])
}
};
unsigned int idx = 2 * num_arguments_per_3D_tensor(); // Skip the input and output parameters
@@ -99,10 +85,10 @@ void CLReshapeLayerKernel::configure(const CLCompileContext &compile_context, co
_kernel.setArg<cl_int2>(idx++, output_shape);
// Configure kernel window
- Window win = calculate_max_window(*input->info());
+ Window win = calculate_max_window(*input);
// Set the output valid region
- output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
+ output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
ICLKernel::configure_internal(win);
}
@@ -113,7 +99,7 @@ Status CLReshapeLayerKernel::validate(const ITensorInfo *input, const ITensorInf
return Status{};
}
-void CLReshapeLayerKernel::run(const Window &window, cl::CommandQueue &queue)
+void CLReshapeLayerKernel::run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, cl::CommandQueue &queue)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
@@ -121,10 +107,14 @@ void CLReshapeLayerKernel::run(const Window &window, cl::CommandQueue &queue)
Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
Window slice = window_collapsed.first_slice_window_3D();
+ const auto src = dynamic_cast<const ICLTensor *>(inputs.at(TensorType::ACL_SRC));
+ auto dst = dynamic_cast<ICLTensor *>(outputs.at(TensorType::ACL_DST));
+
// Set inputs
unsigned int idx = 0;
- add_3D_tensor_argument(idx, _input, window_collapsed);
- add_3D_tensor_argument(idx, _output, window_collapsed);
+ add_3D_tensor_argument(idx, src, window_collapsed);
+ add_3D_tensor_argument(idx, dst, window_collapsed);
enqueue(queue, *this, slice, lws_hint());
}
+} // namespace arm_compute
/** [CLReshapeLayerKernel Kernel] **/
diff --git a/src/runtime/CL/CLOperator.cpp b/src/runtime/CL/CLOperator.cpp
new file mode 100644
index 0000000000..8dc05d8380
--- /dev/null
+++ b/src/runtime/CL/CLOperator.cpp
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2020 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "arm_compute/runtime/CL/ICLOperator.h"
+
+namespace arm_compute
+{
+namespace experimental
+{
+ICLOperator::ICLOperator(IRuntimeContext *ctx)
+ : _kernel(), _ctx(ctx), _workspace()
+{
+}
+
+void ICLOperator::run(InputTensorMap inputs, OutputTensorMap outputs, OperatorTensorMap workspace)
+{
+ ARM_COMPUTE_UNUSED(workspace);
+
+ if(inputs.empty() || outputs.empty())
+ {
+ ARM_COMPUTE_ERROR("No inputs provided");
+ }
+
+ CLScheduler::get().enqueue_op(*_kernel.get(), inputs, outputs, false);
+}
+
+void ICLOperator::prepare(OperatorTensorMap constants)
+{
+ ARM_COMPUTE_UNUSED(constants);
+}
+} // namespace experimental
+} // namespace arm_compute
diff --git a/src/runtime/CL/CLScheduler.cpp b/src/runtime/CL/CLScheduler.cpp
index e78eaa482f..2c1024fcc7 100644
--- a/src/runtime/CL/CLScheduler.cpp
+++ b/src/runtime/CL/CLScheduler.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2019 ARM Limited.
+ * Copyright (c) 2016-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -151,7 +151,7 @@ void CLScheduler::init(cl::Context context, cl::CommandQueue queue, const cl::De
_cl_tuner = cl_tuner;
}
-void CLScheduler::enqueue(ICLKernel &kernel, bool flush)
+void CLScheduler::enqueue_common(ICLKernel &kernel, const InputTensorMap &inputs, const OutputTensorMap &outputs, bool flush)
{
ARM_COMPUTE_ERROR_ON_MSG(!_is_initialised,
"The CLScheduler is not initialised yet! Please call the CLScheduler::get().default_init(), \
@@ -165,11 +165,28 @@ void CLScheduler::enqueue(ICLKernel &kernel, bool flush)
}
// Run kernel
- kernel.run(kernel.window(), _queue);
+ if(inputs.empty())
+ {
+ kernel.run(kernel.window(), _queue);
+ }
+ else
+ {
+ kernel.run_op(inputs, outputs, kernel.window(), _queue);
+ }
if(flush)
{
_queue.flush();
}
}
+
+void CLScheduler::enqueue(ICLKernel &kernel, bool flush)
+{
+ enqueue_common(kernel, {}, {}, flush);
+}
+
+void CLScheduler::enqueue_op(ICLKernel &kernel, const InputTensorMap &inputs, const OutputTensorMap &outputs, bool flush)
+{
+ enqueue_common(kernel, inputs, outputs, flush);
+}
} // namespace arm_compute
diff --git a/src/runtime/CL/functions/CLArgMinMaxLayer.cpp b/src/runtime/CL/functions/CLArgMinMaxLayer.cpp
index cb2b290adf..8fcd04f2fa 100644
--- a/src/runtime/CL/functions/CLArgMinMaxLayer.cpp
+++ b/src/runtime/CL/functions/CLArgMinMaxLayer.cpp
@@ -35,7 +35,7 @@
namespace arm_compute
{
CLArgMinMaxLayer::CLArgMinMaxLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _results_vector(), _not_reshaped_output(), _reduction_kernels_vector(), _reshape_kernel(), _num_of_stages(), _reduction_axis()
+ : _memory_group(std::move(memory_manager)), _results_vector(), _not_reshaped_output(), _reduction_kernels_vector(), _reshape(), _num_of_stages(), _reduction_axis()
{
}
@@ -103,7 +103,7 @@ Status CLArgMinMaxLayer::validate(const ITensorInfo *input, int axis, const ITen
const unsigned int last_stage = num_of_stages - 1;
ARM_COMPUTE_RETURN_ON_ERROR(CLArgMinMaxLayerKernel::validate(input, &sums_vector[last_stage - 1], &not_reshaped_output, axis, op));
}
- ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayerKernel::validate(&not_reshaped_output, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayer::validate(&not_reshaped_output, output));
return Status{};
}
@@ -158,7 +158,7 @@ void CLArgMinMaxLayer::configure(const CLCompileContext &compile_context, const
_reduction_kernels_vector[last_stage].configure(compile_context, input, &_results_vector[last_stage - 1], &_not_reshaped_output, axis, op);
_results_vector[last_stage - 1].allocator()->allocate();
}
- _reshape_kernel.configure(compile_context, &_not_reshaped_output, output);
+ _reshape.configure(compile_context, &_not_reshaped_output, output);
_not_reshaped_output.allocator()->allocate();
}
@@ -170,6 +170,6 @@ void CLArgMinMaxLayer::run()
{
CLScheduler::get().enqueue(_reduction_kernels_vector[i], false);
}
- CLScheduler::get().enqueue(_reshape_kernel, false);
+ _reshape.run();
}
} // namespace arm_compute \ No newline at end of file
diff --git a/src/runtime/CL/functions/CLGenerateProposalsLayer.cpp b/src/runtime/CL/functions/CLGenerateProposalsLayer.cpp
index 7f037fc51f..1b89bb4cfe 100644
--- a/src/runtime/CL/functions/CLGenerateProposalsLayer.cpp
+++ b/src/runtime/CL/functions/CLGenerateProposalsLayer.cpp
@@ -31,9 +31,9 @@ namespace arm_compute
CLGenerateProposalsLayer::CLGenerateProposalsLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_group(memory_manager),
_permute_deltas_kernel(),
- _flatten_deltas_kernel(),
+ _flatten_deltas(),
_permute_scores_kernel(),
- _flatten_scores_kernel(),
+ _flatten_scores(),
_compute_anchors_kernel(),
_bounding_box_kernel(),
_pad_kernel(),
@@ -102,12 +102,12 @@ void CLGenerateProposalsLayer::configure(const CLCompileContext &compile_context
{
_memory_group.manage(&_deltas_permuted);
_permute_deltas_kernel.configure(compile_context, deltas, &_deltas_permuted, PermutationVector{ 2, 0, 1 });
- _flatten_deltas_kernel.configure(compile_context, &_deltas_permuted, &_deltas_flattened);
+ _flatten_deltas.configure(compile_context, &_deltas_permuted, &_deltas_flattened);
_deltas_permuted.allocator()->allocate();
}
else
{
- _flatten_deltas_kernel.configure(compile_context, deltas, &_deltas_flattened);
+ _flatten_deltas.configure(compile_context, deltas, &_deltas_flattened);
}
const TensorShape flatten_shape_scores(1, total_num_anchors);
@@ -119,12 +119,12 @@ void CLGenerateProposalsLayer::configure(const CLCompileContext &compile_context
{
_memory_group.manage(&_scores_permuted);
_permute_scores_kernel.configure(compile_context, scores, &_scores_permuted, PermutationVector{ 2, 0, 1 });
- _flatten_scores_kernel.configure(compile_context, &_scores_permuted, &_scores_flattened);
+ _flatten_scores.configure(compile_context, &_scores_permuted, &_scores_flattened);
_scores_permuted.allocator()->allocate();
}
else
{
- _flatten_scores_kernel.configure(compile_context, scores, &_scores_flattened);
+ _flatten_scores.configure(compile_context, scores, &_scores_flattened);
}
CLTensor *anchors_to_use = &_all_anchors;
@@ -240,12 +240,12 @@ Status CLGenerateProposalsLayer::validate(const ITensorInfo *scores, const ITens
}
TensorInfo deltas_flattened_info(deltas->clone()->set_tensor_shape(TensorShape(values_per_roi, total_num_anchors)).set_is_resizable(true));
- ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayerKernel::validate(&deltas_permuted_info, &deltas_flattened_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayer::validate(&deltas_permuted_info, &deltas_flattened_info));
TensorInfo scores_flattened_info(scores->clone()->set_tensor_shape(TensorShape(1, total_num_anchors)).set_is_resizable(true));
TensorInfo proposals_4_roi_values(deltas->clone()->set_tensor_shape(TensorShape(values_per_roi, total_num_anchors)).set_is_resizable(true));
- ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayerKernel::validate(&scores_permuted_info, &scores_flattened_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayer::validate(&scores_permuted_info, &scores_flattened_info));
TensorInfo *proposals_4_roi_values_to_use = &proposals_4_roi_values;
TensorInfo proposals_4_roi_values_quantized(deltas->clone()->set_tensor_shape(TensorShape(values_per_roi, total_num_anchors)).set_is_resizable(true));
@@ -350,8 +350,8 @@ void CLGenerateProposalsLayer::run()
CLScheduler::get().enqueue(_permute_deltas_kernel, false);
CLScheduler::get().enqueue(_permute_scores_kernel, false);
}
- CLScheduler::get().enqueue(_flatten_deltas_kernel, false);
- CLScheduler::get().enqueue(_flatten_scores_kernel, false);
+ _flatten_deltas.run();
+ _flatten_scores.run();
if(_is_qasymm8)
{
diff --git a/src/runtime/CL/functions/CLReductionOperation.cpp b/src/runtime/CL/functions/CLReductionOperation.cpp
index b659ecfaf6..2d7db0a20a 100644
--- a/src/runtime/CL/functions/CLReductionOperation.cpp
+++ b/src/runtime/CL/functions/CLReductionOperation.cpp
@@ -39,7 +39,7 @@
namespace arm_compute
{
CLReductionOperation::CLReductionOperation(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _results_vector(), _reduction_kernels_vector(), _border_handlers_vector(), _reshape_kernel(), _op(), _num_of_stages(), _reduction_axis(), _is_serial(),
+ : _memory_group(std::move(memory_manager)), _results_vector(), _reduction_kernels_vector(), _border_handlers_vector(), _reshape(), _op(), _num_of_stages(), _reduction_axis(), _is_serial(),
_is_reshape_required(false)
{
}
@@ -152,7 +152,7 @@ Status CLReductionOperation::validate(const ITensorInfo *input, const ITensorInf
if(is_reshape_required)
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayerKernel::validate(output_internal, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayer::validate(output_internal, output));
}
return Status{};
@@ -351,7 +351,7 @@ void CLReductionOperation::configure(const CLCompileContext &compile_context, IC
if(_is_reshape_required)
{
- _reshape_kernel.configure(compile_context, &_results_vector.back(), output);
+ _reshape.configure(compile_context, &_results_vector.back(), output);
_results_vector.back().allocator()->allocate();
}
}
@@ -375,7 +375,7 @@ void CLReductionOperation::run()
if(_is_reshape_required)
{
- CLScheduler::get().enqueue(_reshape_kernel, false);
+ _reshape.run();
}
}
} // namespace arm_compute
diff --git a/src/runtime/CL/functions/CLReshapeLayer.cpp b/src/runtime/CL/functions/CLReshapeLayer.cpp
index 13baedb3f9..6fc8608552 100644
--- a/src/runtime/CL/functions/CLReshapeLayer.cpp
+++ b/src/runtime/CL/functions/CLReshapeLayer.cpp
@@ -28,7 +28,43 @@
#include "support/MemorySupport.h"
/** [CLReshapeLayer snippet] **/
-using namespace arm_compute;
+namespace arm_compute
+{
+namespace experimental
+{
+void CLReshapeLayer::configure(const CLCompileContext &compile_context, const ITensorInfo *input, ITensorInfo *output)
+{
+ auto k = arm_compute::support::cpp14::make_unique<CLReshapeLayerKernel>();
+ k->configure(compile_context, input, output);
+ _kernel = std::move(k);
+}
+
+Status CLReshapeLayer::validate(const ITensorInfo *input, const ITensorInfo *output)
+{
+ return arm_compute::CLReshapeLayerKernel::validate(input, output);
+}
+
+MemoryRequirements CLReshapeLayer::workspace() const
+{
+ return MemoryRequirements{};
+}
+} // namespace experimental
+
+struct CLReshapeLayer::Impl
+{
+ const ICLTensor *src{ nullptr };
+ ICLTensor *dst{ nullptr };
+ std::unique_ptr<experimental::CLReshapeLayer> op{ nullptr };
+};
+
+CLReshapeLayer::CLReshapeLayer()
+ : _impl(support::cpp14::make_unique<Impl>())
+{
+}
+
+CLReshapeLayer::CLReshapeLayer(CLReshapeLayer &&) = default;
+CLReshapeLayer &CLReshapeLayer::operator=(CLReshapeLayer &&) = default;
+CLReshapeLayer::~CLReshapeLayer() = default;
void CLReshapeLayer::configure(const ICLTensor *input, ICLTensor *output)
{
@@ -37,13 +73,26 @@ void CLReshapeLayer::configure(const ICLTensor *input, ICLTensor *output)
void CLReshapeLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output)
{
- auto k = arm_compute::support::cpp14::make_unique<CLReshapeLayerKernel>();
- k->configure(compile_context, input, output);
- _kernel = std::move(k);
+ _impl->src = input;
+ _impl->dst = output;
+ _impl->op = arm_compute::support::cpp14::make_unique<experimental::CLReshapeLayer>();
+ _impl->op->configure(compile_context, input->info(), output->info());
}
Status CLReshapeLayer::validate(const ITensorInfo *input, const ITensorInfo *output)
{
- return CLReshapeLayerKernel::validate(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_RETURN_ON_ERROR(experimental::CLReshapeLayer::validate(input, output));
+
+ return Status{};
+}
+
+void CLReshapeLayer::run()
+{
+ const InputTensorMap src{ { TensorType::ACL_SRC, _impl->src } };
+ const OutputTensorMap dst{ { TensorType::ACL_DST, _impl->dst } };
+
+ _impl->op->run(src, dst, {});
}
+} // namespace arm_compute
/** [CLReshapeLayer snippet] **/
diff --git a/src/runtime/CL/functions/CLSoftmaxLayer.cpp b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
index 71ccf9fa01..52fac4f846 100644
--- a/src/runtime/CL/functions/CLSoftmaxLayer.cpp
+++ b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
@@ -36,7 +36,7 @@ namespace arm_compute
{
template <bool IS_LOG>
CLSoftmaxLayerGeneric<IS_LOG>::CLSoftmaxLayerGeneric(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _max_shift_exp_sum_kernel(), _norm_kernel(), _flatten_kernel_ptr(), _reshape_kernel(), _max(), _sum(), _tmp(), _input_flattened(), _output_flattened(),
+ : _memory_group(std::move(memory_manager)), _max_shift_exp_sum_kernel(), _norm_kernel(), _flatten_ptr(), _reshape(), _max(), _sum(), _tmp(), _input_flattened(), _output_flattened(),
_needs_flattening(false)
{
}
@@ -64,15 +64,15 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure_reshape_input_kernel(const CLCompi
// 2. first_n_reduce_axes == 4: Reduce all 4 dimensions. This can only be handled by CLReshapeKernel instead of CLFlattenKernel.
if(first_n_reduce_axes == 3)
{
- auto flatten_kernel_ptr = support::cpp14::make_unique<CLFlattenLayerKernel>();
- flatten_kernel_ptr->configure(compile_context, input, &_input_flattened);
- _flatten_kernel_ptr = std::move(flatten_kernel_ptr);
+ auto flatten = support::cpp14::make_unique<CLFlattenLayer>();
+ flatten->configure(compile_context, input, &_input_flattened);
+ _flatten_ptr = std::move(flatten);
}
else
{
- auto reshape_kernel_ptr = support::cpp14::make_unique<CLReshapeLayerKernel>();
- reshape_kernel_ptr->configure(compile_context, input, &_input_flattened);
- _flatten_kernel_ptr = std::move(reshape_kernel_ptr);
+ auto reshape_ptr = support::cpp14::make_unique<CLReshapeLayer>();
+ reshape_ptr->configure(compile_context, input, &_input_flattened);
+ _flatten_ptr = std::move(reshape_ptr);
}
// We need to init the output tensor here. Indeed, the reshape kernel expects
@@ -152,7 +152,7 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure(const CLCompileContext &compile_co
_norm_kernel.configure(compile_context, &_tmp, &_sum, &_output_flattened, softmax_info);
// Reshape the flat output into a the requested (4D) output
- _reshape_kernel.configure(compile_context, &_output_flattened, output);
+ _reshape.configure(compile_context, &_output_flattened, output);
// Allocate the intermediate flat tensors
_input_flattened.allocator()->allocate();
@@ -199,11 +199,11 @@ Status CLSoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const I
if(first_n_reduce_axes == 3)
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLFlattenLayerKernel::validate(input, &tensor_info_flat));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLFlattenLayer::validate(input, &tensor_info_flat));
}
else
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayerKernel::validate(input, &tensor_info_flat));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayer::validate(input, &tensor_info_flat));
}
}
@@ -231,7 +231,7 @@ void CLSoftmaxLayerGeneric<IS_LOG>::run()
if(_needs_flattening)
{
- CLScheduler::get().enqueue(*_flatten_kernel_ptr, false);
+ _flatten_ptr->run();
}
CLScheduler::get().enqueue(_max_shift_exp_sum_kernel, false);
@@ -239,7 +239,7 @@ void CLSoftmaxLayerGeneric<IS_LOG>::run()
if(_needs_flattening)
{
- CLScheduler::get().enqueue(_reshape_kernel, true);
+ _reshape.run();
}
}