diff options
Diffstat (limited to 'src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp index 84eb434bc9..da417a9020 100644 --- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp +++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp @@ -48,6 +48,23 @@ CLPixelWiseMultiplicationKernel::CLPixelWiseMultiplicationKernel() void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) { + ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); + + // Auto initialize output if not initialized + { + set_shape_if_empty(*output->info(), input1->info()->tensor_shape()); + + if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16) + { + set_format_if_unknown(*output->info(), Format::S16); + } + else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32) + { + set_format_if_unknown(*output->info(), Format::F32); + } + } + + ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32); |