diff options
Diffstat (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index cd1c4b28cc..4b2352f4c2 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -62,17 +62,18 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i if(output->total_size() > 0) { - if(is_data_type_quantized(output->data_type())) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2, output); - } - const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), "Output can only be U8 if both inputs are U8"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QASYMM8 && (input1->data_type() != DataType::QASYMM8 || input2->data_type() != DataType::QASYMM8), + "Output can only be QASYMM8 if both inputs are QASYMM8"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QASYMM8_SIGNED && (input1->data_type() != DataType::QASYMM8_SIGNED || input2->data_type() != DataType::QASYMM8_SIGNED), + "Output can only be QASYMM8 if both inputs are QASYMM8"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::QSYMM16 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16), + "Output can only be QSYMM16 if both inputs are QSYMM16"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16), "Output can only be S32 if both inputs are QSYMM16"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output"); |