diff options
Diffstat (limited to 'src/core/CL')
-rw-r--r-- | src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp | 33 |
1 files changed, 17 insertions, 16 deletions
diff --git a/src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp b/src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp index 78ab813f67..87fafd340c 100644 --- a/src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp +++ b/src/core/CL/kernels/CLElementWiseUnaryLayerKernel.cpp @@ -26,10 +26,11 @@ #include "arm_compute/core/CL/CLHelpers.h" #include "arm_compute/core/CL/CLValidate.h" #include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/utils/misc/Cast.h" #include "support/StringSupport.h" -using namespace arm_compute; - +namespace arm_compute +{ namespace { Status validate_arguments(const ITensorInfo &input, const ITensorInfo &output) @@ -50,26 +51,22 @@ Status validate_arguments(const ITensorInfo &input, const ITensorInfo &output) } } // namespace -void CLElementWiseUnaryLayerKernel::configure(const ICLTensor *input, ICLTensor *output, const ElementWiseUnary &op) +void CLElementWiseUnaryLayerKernel::configure(const ITensorInfo *input, ITensorInfo *output, const ElementWiseUnary &op) { configure(CLKernelLibrary::get().get_compile_context(), input, output, op); } -void CLElementWiseUnaryLayerKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output, const ElementWiseUnary &op) +void CLElementWiseUnaryLayerKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *input, ITensorInfo *output, const ElementWiseUnary &op) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input->info(), *output->info())); - - // Configure kernel window - _input = input; - _output = output; + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input, *output)); const std::string kernel_name = "elementwise_unary"; - const int vec_size_x = 16 / output->info()->element_size(); - const int output_width_x = output->info()->tensor_shape().x(); + const int vec_size_x = 16 / output->element_size(); + const int output_width_x = output->tensor_shape().x(); const bool multi_access_x = (output_width_x / vec_size_x > 0); - Window win = calculate_max_window(*output->info()); + Window win = calculate_max_window(*output); if(multi_access_x) { win.set(Window::DimX, @@ -79,7 +76,7 @@ void CLElementWiseUnaryLayerKernel::configure(const CLCompileContext &compile_co // Set kernel build options CLBuildOptions build_opts; - build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())); + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->data_type())); build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x)); build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(output_width_x - vec_size_x, 0))); switch(op) @@ -122,7 +119,7 @@ Status CLElementWiseUnaryLayerKernel::validate(const ITensorInfo *input, const I return Status{}; } -void CLElementWiseUnaryLayerKernel::run(const Window &window, cl::CommandQueue &queue) +void CLElementWiseUnaryLayerKernel::run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, cl::CommandQueue &queue) { ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); @@ -130,12 +127,16 @@ void CLElementWiseUnaryLayerKernel::run(const Window &window, cl::CommandQueue & Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); Window slice = collapsed.first_slice_window_3D(); + const auto src = utils::cast::polymorphic_downcast<const ICLTensor *>(inputs.at(TensorType::ACL_SRC)); + auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(outputs.at(TensorType::ACL_DST)); + do { unsigned int idx = 0; - add_3D_tensor_argument(idx, _input, slice); - add_3D_tensor_argument(idx, _output, slice); + add_3D_tensor_argument(idx, src, slice); + add_3D_tensor_argument(idx, dst, slice); enqueue(queue, *this, slice, lws_hint()); } while(collapsed.slide_window_slice_3D(slice)); } +} // namespace arm_compute |