diff options
Diffstat (limited to 'src/core/CL/kernels/CLWidthConcatenateLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLWidthConcatenateLayerKernel.cpp | 39 |
1 files changed, 18 insertions, 21 deletions
diff --git a/src/core/CL/kernels/CLWidthConcatenateLayerKernel.cpp b/src/core/CL/kernels/CLWidthConcatenateLayerKernel.cpp index 20cb962b7e..591c26f877 100644 --- a/src/core/CL/kernels/CLWidthConcatenateLayerKernel.cpp +++ b/src/core/CL/kernels/CLWidthConcatenateLayerKernel.cpp @@ -28,9 +28,9 @@ #include "arm_compute/core/CL/CLValidate.h" #include "arm_compute/core/CL/ICLTensor.h" #include "arm_compute/core/Helpers.h" -#include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Window.h" +#include "arm_compute/core/utils/misc/Cast.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "support/StringSupport.h" @@ -74,7 +74,7 @@ Status validate_arguments(const ITensorInfo *input, unsigned int width_offset, c } // namespace CLWidthConcatenateLayerKernel::CLWidthConcatenateLayerKernel() - : _input(nullptr), _output(nullptr), _width_offset(0) + : _width_offset(0) { } @@ -85,31 +85,24 @@ Status CLWidthConcatenateLayerKernel::validate(const ITensorInfo *input, unsigne return Status{}; } -void CLWidthConcatenateLayerKernel::configure(const ICLTensor *input, unsigned int width_offset, ICLTensor *output) -{ - configure(CLKernelLibrary::get().get_compile_context(), input, width_offset, output); -} - -void CLWidthConcatenateLayerKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, unsigned int width_offset, ICLTensor *output) +void CLWidthConcatenateLayerKernel::configure(const CLCompileContext &compile_context, ITensorInfo *input, unsigned int width_offset, ITensorInfo *output) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), width_offset, output->info())); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input, width_offset, output)); - _input = input; - _output = output; _width_offset = width_offset; // Add build options CLBuildOptions build_opts; - build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())); + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->data_type())); build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration)); build_opts.add_option("-DWIDTH_OFFSET=" + support::cpp11::to_string(_width_offset)); - build_opts.add_option("-DDEPTH=" + support::cpp11::to_string(input->info()->dimension(2))); + build_opts.add_option("-DDEPTH=" + support::cpp11::to_string(input->dimension(2))); - if(is_data_type_quantized_asymmetric(input->info()->data_type()) && input->info()->quantization_info() != output->info()->quantization_info()) + if(is_data_type_quantized_asymmetric(input->data_type()) && input->quantization_info() != output->quantization_info()) { - const UniformQuantizationInfo iqinfo = input->info()->quantization_info().uniform(); - const UniformQuantizationInfo oqinfo = output->info()->quantization_info().uniform(); + const UniformQuantizationInfo iqinfo = input->quantization_info().uniform(); + const UniformQuantizationInfo oqinfo = output->quantization_info().uniform(); build_opts.add_option("-DOFFSET_IN1=" + float_to_string_with_full_precision(iqinfo.offset)); build_opts.add_option("-DOFFSET_OUT=" + float_to_string_with_full_precision(oqinfo.offset)); @@ -120,23 +113,27 @@ void CLWidthConcatenateLayerKernel::configure(const CLCompileContext &compile_co // Create kernel _kernel = create_kernel(compile_context, "concatenate_width", build_opts.options()); // Configure kernel window - auto win_config = validate_and_configure_window(input->info(), width_offset, output->info()); + auto win_config = validate_and_configure_window(input, width_offset, output); ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config)); ICLKernel::configure_internal(std::get<1>(win_config)); // Set output valid region - output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape())); + output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape())); } -void CLWidthConcatenateLayerKernel::run(const Window &window, cl::CommandQueue &queue) +void CLWidthConcatenateLayerKernel::run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, + const Window &window, cl::CommandQueue &queue) { ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); + const auto src = utils::cast::polymorphic_downcast<const ICLTensor *>(inputs.at(TensorType::ACL_SRC)); + auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(outputs.at(TensorType::ACL_DST)); + unsigned int idx = 0; - add_4D_tensor_argument(idx, _input, window); - add_4D_tensor_argument(idx, _output, window); + add_4D_tensor_argument(idx, src, window); + add_4D_tensor_argument(idx, dst, window); enqueue(queue, *this, window, lws_hint()); } } // namespace arm_compute |