diff options
Diffstat (limited to 'src/core/CL/kernels/CLArithmeticAdditionKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLArithmeticAdditionKernel.cpp | 33 |
1 files changed, 22 insertions, 11 deletions
diff --git a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp index 65422c2bbf..a7625f4303 100644 --- a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp +++ b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp @@ -68,17 +68,7 @@ void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTen } } - ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type())) - { - // Check that all data types are the same and all fixed-point positions are the same - ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output); - } + ARM_COMPUTE_ERROR_THROW_ON(CLArithmeticAdditionKernel::validate(input1->info(), input2->info(), output->info(), policy)); _input1 = input1; _input2 = input2; @@ -119,6 +109,27 @@ void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTen ICLKernel::configure(win); } +Error CLArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + ARM_COMPUTE_UNUSED(policy); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); + + // Validate in case of configured output + if((output != nullptr) && (output->total_size() != 0)) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); + } + + return Error{}; +} + void CLArithmeticAdditionKernel::run(const Window &window, cl::CommandQueue &queue) { ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); |