aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp33
1 files changed, 17 insertions, 16 deletions
diff --git a/src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp b/src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp
index 78ab813f67..87fafd340c 100644
--- a/src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp
+++ b/src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp
@@ -26,10 +26,11 @@
#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLValidate.h"
#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/utils/misc/Cast.h"
#include "support/StringSupport.h"
-using namespace arm_compute;
-
+namespace arm_compute
+{
namespace
{
Status validate_arguments(const ITensorInfo &input, const ITensorInfo &output)
@@ -50,26 +51,22 @@ Status validate_arguments(const ITensorInfo &input, const ITensorInfo &output)
}
} // namespace
-void CLElementWiseUnaryLayerKernel::configure(const ICLTensor *input, ICLTensor *output, const ElementWiseUnary &op)
+void CLElementWiseUnaryLayerKernel::configure(const ITensorInfo *input, ITensorInfo *output, const ElementWiseUnary &op)
{
configure(CLKernelLibrary::get().get_compile_context(), input, output, op);
}
-void CLElementWiseUnaryLayerKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output, const ElementWiseUnary &op)
+void CLElementWiseUnaryLayerKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *input, ITensorInfo *output, const ElementWiseUnary &op)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input->info(), *output->info()));
-
- // Configure kernel window
- _input = input;
- _output = output;
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input, *output));
const std::string kernel_name = "elementwise_unary";
- const int vec_size_x = 16 / output->info()->element_size();
- const int output_width_x = output->info()->tensor_shape().x();
+ const int vec_size_x = 16 / output->element_size();
+ const int output_width_x = output->tensor_shape().x();
const bool multi_access_x = (output_width_x / vec_size_x > 0);
- Window win = calculate_max_window(*output->info());
+ Window win = calculate_max_window(*output);
if(multi_access_x)
{
win.set(Window::DimX,
@@ -79,7 +76,7 @@ void CLElementWiseUnaryLayerKernel::configure(const CLCompileContext &compile_co
// Set kernel build options
CLBuildOptions build_opts;
- build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->data_type()));
build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x));
build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(output_width_x - vec_size_x, 0)));
switch(op)
@@ -122,7 +119,7 @@ Status CLElementWiseUnaryLayerKernel::validate(const ITensorInfo *input, const I
return Status{};
}
-void CLElementWiseUnaryLayerKernel::run(const Window &window, cl::CommandQueue &queue)
+void CLElementWiseUnaryLayerKernel::run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, cl::CommandQueue &queue)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
@@ -130,12 +127,16 @@ void CLElementWiseUnaryLayerKernel::run(const Window &window, cl::CommandQueue &
Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
Window slice = collapsed.first_slice_window_3D();
+ const auto src = utils::cast::polymorphic_downcast<const ICLTensor *>(inputs.at(TensorType::ACL_SRC));
+ auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(outputs.at(TensorType::ACL_DST));
+
do
{
unsigned int idx = 0;
- add_3D_tensor_argument(idx, _input, slice);
- add_3D_tensor_argument(idx, _output, slice);
+ add_3D_tensor_argument(idx, src, slice);
+ add_3D_tensor_argument(idx, dst, slice);
enqueue(queue, *this, slice, lws_hint());
}
while(collapsed.slide_window_slice_3D(slice));
}
+} // namespace arm_compute