From d7d7e9035ca28b1b5200b20a73825397d46830fc Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 18 Dec 2019 15:40:54 +0000 Subject: COMPMID-2819: Update data type auto_init in Elementwise kernels. Auto initialization functionality is updated in elementwise kernels to check both data types in order to reason for the output data type configuration. Signed-off-by: Georgios Pinitas Change-Id: Ic08b5567d08a3aaca00942acbdbc8aee19686617 Reviewed-on: https://review.mlplatform.org/c/2495 Tested-by: Arm Jenkins Reviewed-by: Giorgio Arena Comments-Addressed: Arm Jenkins --- src/core/CL/kernels/CLElementwiseOperationKernel.cpp | 2 +- src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp | 10 +++++----- src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp | 10 +++++++++- src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp | 11 +++++------ 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/core/CL/kernels/CLElementwiseOperationKernel.cpp b/src/core/CL/kernels/CLElementwiseOperationKernel.cpp index a69e84a16d..66f7ee0a56 100644 --- a/src/core/CL/kernels/CLElementwiseOperationKernel.cpp +++ b/src/core/CL/kernels/CLElementwiseOperationKernel.cpp @@ -195,7 +195,7 @@ std::pair validate_and_configure_window_for_arithmetic_operators { set_format_if_unknown(output, Format::S16); } - else if(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16) + else if(input1.data_type() == DataType::F16 || input2.data_type() == DataType::F16) { set_format_if_unknown(output, Format::F16); } diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp index 947be18b80..3532526eb8 100644 --- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp +++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp @@ -403,7 +403,7 @@ void add_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED(const ITensor *in1, const int x = window_start_x; for(; x <= (window_end_x - window_step_x); x += window_step_x) { - const int8x16_t a = vld1q_s8(non_broadcast_input_ptr + x); + const int8x16_t a = vld1q_s8(non_broadcast_input_ptr + x); const float32x4x4_t af = { { @@ -875,7 +875,7 @@ std::pair validate_and_configure_window(ITensorInfo &input1, ITe { set_format_if_unknown(output, Format::S16); } - else if(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16) + else if(input1.data_type() == DataType::F16 || input2.data_type() == DataType::F16) { set_format_if_unknown(output, Format::F16); } @@ -883,15 +883,15 @@ std::pair validate_and_configure_window(ITensorInfo &input1, ITe { set_format_if_unknown(output, Format::F32); } - else if(input1.data_type() == DataType::QASYMM8) + else if(input1.data_type() == DataType::QASYMM8 || input2.data_type() == DataType::QASYMM8) { set_data_type_if_unknown(output, DataType::QASYMM8); } - else if(input1.data_type() == DataType::QASYMM8_SIGNED) + else if(input1.data_type() == DataType::QASYMM8_SIGNED || input2.data_type() == DataType::QASYMM8_SIGNED) { set_data_type_if_unknown(output, DataType::QASYMM8_SIGNED); } - else if(input1.data_type() == DataType::QSYMM16) + else if(input1.data_type() == DataType::QSYMM16 || input2.data_type() == DataType::QSYMM16) { set_data_type_if_unknown(output, DataType::QSYMM16); } diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp index 7a2601be26..0695c94927 100644 --- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp +++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp @@ -448,7 +448,7 @@ inline std::pair validate_and_configure_window(ITensorInfo &inpu { set_format_if_unknown(output, Format::S16); } - else if(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16) + else if(input1.data_type() == DataType::F16 || input2.data_type() == DataType::F16) { set_format_if_unknown(output, Format::F16); } @@ -456,6 +456,14 @@ inline std::pair validate_and_configure_window(ITensorInfo &inpu { set_format_if_unknown(output, Format::F32); } + else if(input1.data_type() == DataType::QASYMM8 || input2.data_type() == DataType::QASYMM8) + { + set_data_type_if_unknown(output, DataType::QASYMM8); + } + else if(input1.data_type() == DataType::QASYMM8_SIGNED || input2.data_type() == DataType::QASYMM8_SIGNED) + { + set_data_type_if_unknown(output, DataType::QASYMM8_SIGNED); + } } Window win = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration)); diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index 7ec52f788b..a87588dbb3 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -70,11 +70,10 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), "Output can only be U8 if both inputs are U8"); - if(is_data_type_quantized(input1->data_type())|| - is_data_type_quantized(input2->data_type())) + if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type())) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP,"ConvertPolicy cannot be WRAP if datatype is quantized"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP, "ConvertPolicy cannot be WRAP if datatype is quantized"); } if(output->total_size() > 0) @@ -130,15 +129,15 @@ inline std::pair validate_and_configure_window(ITensorInfo *inpu { set_format_if_unknown(*output, Format::F16); } - else if(input1->data_type() == DataType::QASYMM8) + else if(input1->data_type() == DataType::QASYMM8 || input2->data_type() == DataType::QASYMM8) { set_data_type_if_unknown(*output, DataType::QASYMM8); } - else if(input1->data_type() == DataType::QASYMM8_SIGNED) + else if(input1->data_type() == DataType::QASYMM8_SIGNED || input2->data_type() == DataType::QASYMM8_SIGNED) { set_data_type_if_unknown(*output, DataType::QASYMM8_SIGNED); } - else if(input1->data_type() == DataType::QSYMM16) + else if(input1->data_type() == DataType::QSYMM16 || input2->data_type() == DataType::QSYMM16) { set_data_type_if_unknown(*output, DataType::QSYMM16); } -- cgit v1.2.1