From f0dea703ce3e2b465e79298bca95c4952a60f608 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 3 Jul 2017 18:17:28 +0100 Subject: COMPMID-417: Auto configuration for Add/Sub/Mul Neon/CL. Change-Id: I3580de76bc53d342b53443d1077b1407d75a672a Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79570 Tested-by: Kaizen Reviewed-by: Michele DiGiorgio Reviewed-by: Moritz Pflanzer --- src/core/CL/kernels/CLArithmeticAdditionKernel.cpp | 29 ++++++++++++++++----- .../CL/kernels/CLArithmeticSubtractionKernel.cpp | 30 ++++++++++++++-------- .../CL/kernels/CLPixelWiseMultiplicationKernel.cpp | 17 ++++++++++++ 3 files changed, 60 insertions(+), 16 deletions(-) (limited to 'src/core/CL') diff --git a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp index aaa62d0268..0cb0847784 100644 --- a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp +++ b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp @@ -48,9 +48,32 @@ CLArithmeticAdditionKernel::CLArithmeticAdditionKernel() void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, ConvertPolicy policy) { + ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); + + // Auto initialize output if not initialized + { + set_shape_if_empty(*output->info(), input1->info()->tensor_shape()); + + if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16) + { + set_format_if_unknown(*output->info(), Format::S16); + } + else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32) + { + set_format_if_unknown(*output->info(), Format::F32); + } + else if(input1->info()->data_type() == DataType::F16 && input2->info()->data_type() == DataType::F16) + { + set_format_if_unknown(*output->info(), Format::F16); + } + } + + ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); _input1 = input1; _input2 = input2; @@ -58,12 +81,6 @@ void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTen const bool has_float_out = is_data_type_float(output->info()->data_type()); - // Check for invalid combination - if(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8)) - { - ARM_COMPUTE_ERROR("You called with the wrong data types."); - } - // Set kernel build options std::set build_opts; build_opts.emplace((policy == ConvertPolicy::WRAP || has_float_out) ? "-DWRAP" : "-DSATURATE"); diff --git a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp index 4c847276da..69f9ff17d3 100644 --- a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp +++ b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp @@ -45,19 +45,29 @@ CLArithmeticSubtractionKernel::CLArithmeticSubtractionKernel() void CLArithmeticSubtractionKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, ConvertPolicy policy) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); - // Check for invalid combination - if(output->info()->data_type() == DataType::U8) - { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8); - } - else + ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); + + // Auto initialize output if not initialized { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); + set_shape_if_empty(*output->info(), input1->info()->tensor_shape()); + + if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16) + { + set_format_if_unknown(*output->info(), Format::S16); + } + else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32) + { + set_format_if_unknown(*output->info(), Format::F32); + } } + ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + _input1 = input1; _input2 = input2; _output = output; diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp index 84eb434bc9..da417a9020 100644 --- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp +++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp @@ -48,6 +48,23 @@ CLPixelWiseMultiplicationKernel::CLPixelWiseMultiplicationKernel() void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) { + ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); + + // Auto initialize output if not initialized + { + set_shape_if_empty(*output->info(), input1->info()->tensor_shape()); + + if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16) + { + set_format_if_unknown(*output->info(), Format::S16); + } + else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32) + { + set_format_if_unknown(*output->info(), Format::F32); + } + } + + ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); -- cgit v1.2.1