aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels')
-rw-r--r--src/core/CL/kernels/CLArithmeticAdditionKernel.cpp29
-rw-r--r--src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp30
-rw-r--r--src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp17
3 files changed, 60 insertions, 16 deletions
diff --git a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
index aaa62d0268..0cb0847784 100644
--- a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
+++ b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp
@@ -48,9 +48,32 @@ CLArithmeticAdditionKernel::CLArithmeticAdditionKernel()
void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, ConvertPolicy policy)
{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+
+ // Auto initialize output if not initialized
+ {
+ set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
+
+ if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
+ {
+ set_format_if_unknown(*output->info(), Format::S16);
+ }
+ else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
+ {
+ set_format_if_unknown(*output->info(), Format::F32);
+ }
+ else if(input1->info()->data_type() == DataType::F16 && input2->info()->data_type() == DataType::F16)
+ {
+ set_format_if_unknown(*output->info(), Format::F16);
+ }
+ }
+
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
+ "Output can only be U8 if both inputs are U8");
_input1 = input1;
_input2 = input2;
@@ -58,12 +81,6 @@ void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTen
const bool has_float_out = is_data_type_float(output->info()->data_type());
- // Check for invalid combination
- if(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8))
- {
- ARM_COMPUTE_ERROR("You called with the wrong data types.");
- }
-
// Set kernel build options
std::set<std::string> build_opts;
build_opts.emplace((policy == ConvertPolicy::WRAP || has_float_out) ? "-DWRAP" : "-DSATURATE");
diff --git a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
index 4c847276da..69f9ff17d3 100644
--- a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
+++ b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp
@@ -45,19 +45,29 @@ CLArithmeticSubtractionKernel::CLArithmeticSubtractionKernel()
void CLArithmeticSubtractionKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, ConvertPolicy policy)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
- // Check for invalid combination
- if(output->info()->data_type() == DataType::U8)
- {
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8);
- }
- else
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+
+ // Auto initialize output if not initialized
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
+
+ if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
+ {
+ set_format_if_unknown(*output->info(), Format::S16);
+ }
+ else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
+ {
+ set_format_if_unknown(*output->info(), Format::F32);
+ }
}
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
+ "Output can only be U8 if both inputs are U8");
+
_input1 = input1;
_input2 = input2;
_output = output;
diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
index 84eb434bc9..da417a9020 100644
--- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
+++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
@@ -48,6 +48,23 @@ CLPixelWiseMultiplicationKernel::CLPixelWiseMultiplicationKernel()
void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale,
ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+
+ // Auto initialize output if not initialized
+ {
+ set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
+
+ if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
+ {
+ set_format_if_unknown(*output->info(), Format::S16);
+ }
+ else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
+ {
+ set_format_if_unknown(*output->info(), Format::F32);
+ }
+ }
+
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);