diff options
Diffstat (limited to 'src/cpu/kernels/elementwise_binary/generic/neon/impl.h')
-rw-r--r-- | src/cpu/kernels/elementwise_binary/generic/neon/impl.h | 852 |
1 files changed, 532 insertions, 320 deletions
diff --git a/src/cpu/kernels/elementwise_binary/generic/neon/impl.h b/src/cpu/kernels/elementwise_binary/generic/neon/impl.h index 98b154e8fd..98f7e8b949 100644 --- a/src/cpu/kernels/elementwise_binary/generic/neon/impl.h +++ b/src/cpu/kernels/elementwise_binary/generic/neon/impl.h @@ -39,7 +39,7 @@ typename VectorType::type elementwise_arithm_op(const typename VectorType::type vec_type res = wrapper::vdup_n(static_cast<scalar_type>(0), tag_type{}); - switch(op) + switch (op) { case ArithmeticOperation::MAX: res = wrapper::vmax(a, b); @@ -71,7 +71,9 @@ typename VectorType::type elementwise_arithm_op(const typename VectorType::type } template <ArithmeticOperation op, typename ScalarType, typename VectorType> -typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, const ScalarType &broadcast_value, const bool reorder) +typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorType::type &a, + const ScalarType &broadcast_value, + const bool reorder) { using tag_type = typename VectorType::tag_type; using vec_type = typename VectorType::type; @@ -81,10 +83,15 @@ typename VectorType::type elementwise_arithm_op_broadcast(const typename VectorT } template <typename InputScalarType, typename OutputScalarType, typename InputVectorType> -void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, - OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &), - int (*broadcast_func)(int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool), - int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *)) +void elementwise_op( + const ITensor *in1, + const ITensor *in2, + ITensor *out, + const Window &window, + OutputScalarType (*scalar_func)(const InputScalarType &, const InputScalarType &), + int (*broadcast_func)( + int, int, int, const InputScalarType *, const InputScalarType &, OutputScalarType *, const bool), + int (*neon_func)(int, int, int, const InputScalarType *, const InputScalarType *, OutputScalarType *)) { // Create input windows Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()); @@ -99,7 +106,7 @@ void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const const auto window_end_x = static_cast<int>(window.x().end()); const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x(); - if(is_broadcast_across_x) + if (is_broadcast_across_x) { const bool is_broadcast_input_2 = input2_win.x().step() == 0; Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win; @@ -114,20 +121,26 @@ void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr()); - const auto non_broadcast_input_ptr = reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr()); - const InputScalarType broadcast_value = *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr()); - - int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_value, output_ptr, !is_broadcast_input_2); - for(; x < window_end_x; ++x) + execute_window_loop( + win, + [&](const Coordinates &) { - const auto a = *(non_broadcast_input_ptr + x); - *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, !is_broadcast_input_2 ? a : broadcast_value); - } - }, - broadcast_input, non_broadcast_input, output); + auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr()); + const auto non_broadcast_input_ptr = + reinterpret_cast<const InputScalarType *>(non_broadcast_input.ptr()); + const InputScalarType broadcast_value = + *reinterpret_cast<const InputScalarType *>(broadcast_input.ptr()); + + int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, + broadcast_value, output_ptr, !is_broadcast_input_2); + for (; x < window_end_x; ++x) + { + const auto a = *(non_broadcast_input_ptr + x); + *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? broadcast_value : a, + !is_broadcast_input_2 ? a : broadcast_value); + } + }, + broadcast_input, non_broadcast_input, output); } else { @@ -139,21 +152,23 @@ void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Iterator input2(in2, input2_win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr()); - const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr()); - const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr()); - - int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr); - for(; x < window_end_x; ++x) + execute_window_loop( + win, + [&](const Coordinates &) { - const auto a = *(input1_ptr + x); - const auto b = *(input2_ptr + x); - *(output_ptr + x) = (*scalar_func)(a, b); - } - }, - input1, input2, output); + auto output_ptr = reinterpret_cast<OutputScalarType *>(output.ptr()); + const auto input1_ptr = reinterpret_cast<const InputScalarType *>(input1.ptr()); + const auto input2_ptr = reinterpret_cast<const InputScalarType *>(input2.ptr()); + + int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr); + for (; x < window_end_x; ++x) + { + const auto a = *(input1_ptr + x); + const auto b = *(input2_ptr + x); + *(output_ptr + x) = (*scalar_func)(a, b); + } + }, + input1, input2, output); } } @@ -162,7 +177,7 @@ inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const Scalar { auto res = ScalarType(0); - switch(op) + switch (op) { case ArithmeticOperation::MAX: res = std::max(a, b); @@ -183,10 +198,10 @@ inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const Scalar case ArithmeticOperation::DIV: { res = a / b; - if(std::is_integral<ScalarType>::value) + if (std::is_integral<ScalarType>::value) { res = (b == 0) ? 0 : res; - if(static_cast<int32_t>(a) % static_cast<int32_t>(b) != 0 && ((a < 0) != (b < 0))) + if (static_cast<int32_t>(a) % static_cast<int32_t>(b) != 0 && ((a < 0) != (b < 0))) { --res; } @@ -205,43 +220,56 @@ inline ScalarType elementwise_arithm_op_scalar(const ScalarType &a, const Scalar } template <> -inline int32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<int32_t, 4>>(const int32x4_t &a, const int32x4_t &b) +inline int32x4_t +elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<int32_t, 4>>(const int32x4_t &a, + const int32x4_t &b) { return vcvtq_s32_f32(vfloorq_f32(wrapper::vdiv(vcvtq_f32_s32(a), vcvtq_f32_s32(b)))); } template <> -inline float32x4_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b) +inline float32x4_t +elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, + const float32x4_t &b) { return wrapper::vdiv(a, b); } template <> -inline float32x4_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, const float32x4_t &b) +inline float32x4_t +elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float, 4>>(const float32x4_t &a, + const float32x4_t &b) { return wrapper::vpow(a, b); } #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC template <> -inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b) +inline float16x8_t elementwise_arithm_op<ArithmeticOperation::DIV, typename wrapper::traits::neon_vector<float16_t, 8>>( + const float16x8_t &a, const float16x8_t &b) { return wrapper::vdiv(a, b); } template <> -inline float16x8_t elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>(const float16x8_t &a, const float16x8_t &b) +inline float16x8_t +elementwise_arithm_op<ArithmeticOperation::POWER, typename wrapper::traits::neon_vector<float16_t, 8>>( + const float16x8_t &a, const float16x8_t &b) { return wrapper::vpow(a, b); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC template <ArithmeticOperation op, typename ScalarType, typename VectorType> -inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int window_step_x, - const ScalarType *input1_ptr, const ScalarType *input2_ptr, ScalarType *output_ptr) +inline int elementwise_arithm_op_loop(int window_start_x, + int window_end_x, + int window_step_x, + const ScalarType *input1_ptr, + const ScalarType *input2_ptr, + ScalarType *output_ptr) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { const auto a = wrapper::vloadq(input1_ptr + x); const auto b = wrapper::vloadq(input2_ptr + x); @@ -251,14 +279,20 @@ inline int elementwise_arithm_op_loop(int window_start_x, int window_end_x, int } template <ArithmeticOperation op, typename ScalarType, typename VectorType> -inline int elementwise_arithm_op_broadcast_loop(int window_start_x, int window_end_x, int window_step_x, - const ScalarType *non_broadcast_input_ptr, const ScalarType &broadcast_value, ScalarType *output_ptr, const bool reorder) +inline int elementwise_arithm_op_broadcast_loop(int window_start_x, + int window_end_x, + int window_step_x, + const ScalarType *non_broadcast_input_ptr, + const ScalarType &broadcast_value, + ScalarType *output_ptr, + const bool reorder) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { const auto a = wrapper::vloadq((non_broadcast_input_ptr + x)); - wrapper::vstore(output_ptr + x, elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder)); + wrapper::vstore(output_ptr + x, + elementwise_arithm_op_broadcast<op, ScalarType, VectorType>(a, broadcast_value, reorder)); } return x; } @@ -268,10 +302,10 @@ void elementwise_arithm_op(const ITensor *in1, const ITensor *in2, ITensor *out, { using scalar_type = typename VectorType::scalar_type; - elementwise_op<scalar_type, scalar_type, VectorType>(in1, in2, out, window, - &elementwise_arithm_op_scalar<op, scalar_type>, - &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>, - &elementwise_arithm_op_loop<op, scalar_type, VectorType>); + elementwise_op<scalar_type, scalar_type, VectorType>( + in1, in2, out, window, &elementwise_arithm_op_scalar<op, scalar_type>, + &elementwise_arithm_op_broadcast_loop<op, scalar_type, VectorType>, + &elementwise_arithm_op_loop<op, scalar_type, VectorType>); } template <ComparisonOperation op, typename InputScalarType> @@ -279,7 +313,7 @@ inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputS { bool res = false; - switch(op) + switch (op) { case ComparisonOperation::Equal: res = (a == b); @@ -308,9 +342,9 @@ inline uint8_t elementwise_comp_op_scalar(const InputScalarType &a, const InputS template <ComparisonOperation op, typename InputVectorType, typename OutputVectorType> inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const InputVectorType &b) { - OutputVectorType res = { 0, 0, 0, 0 }; + OutputVectorType res = {0, 0, 0, 0}; - switch(op) + switch (op) { case ComparisonOperation::Equal: res = wrapper::vceq(a, b); @@ -338,53 +372,75 @@ inline OutputVectorType elementwise_comp_op(const InputVectorType &a, const Inpu } template <ComparisonOperation op, typename InputScalarType, typename InputVectorType, typename OutputVectorType> -inline OutputVectorType elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder) +inline OutputVectorType +elementwise_comp_op_broadcast(const InputVectorType &a, const InputScalarType &broadcast_value, const bool reorder) { InputVectorType broadcast_vector = wrapper::vdup_n(broadcast_value, wrapper::traits::vector_128_tag()); - return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a, reorder ? a : broadcast_vector); + return elementwise_comp_op<op, InputVectorType, OutputVectorType>(reorder ? broadcast_vector : a, + reorder ? a : broadcast_vector); } template <ComparisonOperation op, typename InputScalarType, typename InputVectorType> -inline int elementwise_comp_op_broadcast_8_loop(int window_start_x, int window_end_x, int window_step_x, - const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder) +inline int elementwise_comp_op_broadcast_8_loop(int window_start_x, + int window_end_x, + int window_step_x, + const InputScalarType *non_broadcast_input_ptr, + const InputScalarType &broadcast_value, + uint8_t *output_ptr, + const bool reorder) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { - const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint8x16_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder); + const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint8x16_t>( + wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder); wrapper::vstore(output_ptr + x, a); } return x; } template <ComparisonOperation op, typename InputScalarType, typename InputVectorType> -inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, int window_end_x, int window_step_x, - const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder) +inline int elementwise_comp_op_broadcast_16_loop(int window_start_x, + int window_end_x, + int window_step_x, + const InputScalarType *non_broadcast_input_ptr, + const InputScalarType &broadcast_value, + uint8_t *output_ptr, + const bool reorder) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { - const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder); + const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint16x8_t>( + wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder); wrapper::vstore(output_ptr + x, wrapper::vmovn(a)); } return x; } template <ComparisonOperation op, typename InputScalarType, typename InputVectorType> -inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_end_x, int window_step_x, - const InputScalarType *non_broadcast_input_ptr, const InputScalarType &broadcast_value, uint8_t *output_ptr, const bool reorder) +inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, + int window_end_x, + int window_step_x, + const InputScalarType *non_broadcast_input_ptr, + const InputScalarType &broadcast_value, + uint8_t *output_ptr, + const bool reorder) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { - const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder); - const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder); + const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>( + wrapper::vloadq(non_broadcast_input_ptr + x), broadcast_value, reorder); + const auto b = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>( + wrapper::vloadq(non_broadcast_input_ptr + x + 4), broadcast_value, reorder); wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(a), wrapper::vmovn(b)))); } - if(x <= window_end_x - 4) + if (x <= window_end_x - 4) { - const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>(wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder); - for(int i = 0; i < 4; i++) + const auto a = elementwise_comp_op_broadcast<op, InputScalarType, InputVectorType, uint32x4_t>( + wrapper::vloadq((non_broadcast_input_ptr + x)), broadcast_value, reorder); + for (int i = 0; i < 4; i++) { *(output_ptr + x + i) = wrapper::vgetlane(a, i); } @@ -394,11 +450,15 @@ inline int elementwise_comp_op_broadcast_32_loop(int window_start_x, int window_ } template <ComparisonOperation op, typename InputScalarType, typename InputVectorType> -inline int elementwise_comp_op_8_loop(int window_start_x, int window_end_x, int window_step_x, - const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr) +inline int elementwise_comp_op_8_loop(int window_start_x, + int window_end_x, + int window_step_x, + const InputScalarType *input1_ptr, + const InputScalarType *input2_ptr, + uint8_t *output_ptr) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { const auto a = wrapper::vloadq(input1_ptr + x); const auto b = wrapper::vloadq(input2_ptr + x); @@ -409,11 +469,15 @@ inline int elementwise_comp_op_8_loop(int window_start_x, int window_end_x, int } template <ComparisonOperation op, typename InputScalarType, typename InputVectorType> -inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int window_step_x, - const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr) +inline int elementwise_comp_op_16_loop(int window_start_x, + int window_end_x, + int window_step_x, + const InputScalarType *input1_ptr, + const InputScalarType *input2_ptr, + uint8_t *output_ptr) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { const auto a = wrapper::vloadq(input1_ptr + x); const auto b = wrapper::vloadq(input2_ptr + x); @@ -424,11 +488,15 @@ inline int elementwise_comp_op_16_loop(int window_start_x, int window_end_x, int } template <ComparisonOperation op, typename InputScalarType, typename InputVectorType> -inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int window_step_x, - const InputScalarType *input1_ptr, const InputScalarType *input2_ptr, uint8_t *output_ptr) +inline int elementwise_comp_op_32_loop(int window_start_x, + int window_end_x, + int window_step_x, + const InputScalarType *input1_ptr, + const InputScalarType *input2_ptr, + uint8_t *output_ptr) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { auto a = wrapper::vloadq(input1_ptr + x); auto b = wrapper::vloadq(input2_ptr + x); @@ -438,12 +506,12 @@ inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int const auto res2 = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b); wrapper::vstore(output_ptr + x, wrapper::vmovn(wrapper::vcombine(wrapper::vmovn(res), wrapper::vmovn(res2)))); } - if(x <= window_end_x - 4) + if (x <= window_end_x - 4) { const auto a = wrapper::vloadq(input1_ptr + x); const auto b = wrapper::vloadq(input2_ptr + x); const auto res = elementwise_comp_op<op, InputVectorType, uint32x4_t>(a, b); - for(int i = 0; i < 4; i++) + for (int i = 0; i < 4; i++) { *(output_ptr + x + i) = wrapper::vgetlane(res, i); } @@ -455,57 +523,59 @@ inline int elementwise_comp_op_32_loop(int window_start_x, int window_end_x, int template <ComparisonOperation op, typename InputScalarType, typename InputVectorType> void elementwise_comp_op_8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window, - &elementwise_comp_op_scalar<op, InputScalarType>, - &elementwise_comp_op_broadcast_8_loop<op, InputScalarType, InputVectorType>, - &elementwise_comp_op_8_loop<op, InputScalarType, InputVectorType>); + elementwise_op<InputScalarType, uint8_t, InputVectorType>( + in1, in2, out, window, &elementwise_comp_op_scalar<op, InputScalarType>, + &elementwise_comp_op_broadcast_8_loop<op, InputScalarType, InputVectorType>, + &elementwise_comp_op_8_loop<op, InputScalarType, InputVectorType>); } template <ComparisonOperation op, typename InputScalarType, typename InputVectorType> void elementwise_comp_op_16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window, - &elementwise_comp_op_scalar<op, InputScalarType>, - &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>, - &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>); + elementwise_op<InputScalarType, uint8_t, InputVectorType>( + in1, in2, out, window, &elementwise_comp_op_scalar<op, InputScalarType>, + &elementwise_comp_op_broadcast_16_loop<op, InputScalarType, InputVectorType>, + &elementwise_comp_op_16_loop<op, InputScalarType, InputVectorType>); } template <ComparisonOperation op, typename InputScalarType, typename InputVectorType> void elementwise_comp_op_32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - elementwise_op<InputScalarType, uint8_t, InputVectorType>(in1, in2, out, window, - &elementwise_comp_op_scalar<op, InputScalarType>, - &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>, - &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>); + elementwise_op<InputScalarType, uint8_t, InputVectorType>( + in1, in2, out, window, &elementwise_comp_op_scalar<op, InputScalarType>, + &elementwise_comp_op_broadcast_32_loop<op, InputScalarType, InputVectorType>, + &elementwise_comp_op_32_loop<op, InputScalarType, InputVectorType>); } inline float32x4x4_t load_quantized(const uint8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale) { - qasymm8x16_t x = vld1q_u8(input1_ptr); - const float32x4x4_t out = - { - { - vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale), - vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))), offset)), scale), - vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale), - vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))), offset)), scale), - } - }; + qasymm8x16_t x = vld1q_u8(input1_ptr); + const float32x4x4_t out = {{ + vmulq_f32( + vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(x))))), offset)), + scale), + vmulq_f32( + vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(x))))), offset)), + scale), + vmulq_f32( + vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(x))))), offset)), + scale), + vmulq_f32(vcvtq_f32_s32( + vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(x))))), offset)), + scale), + }}; return out; } inline float32x4x4_t load_quantized_signed(const int8_t *input1_ptr, const int32x4_t &offset, const float32x4_t &scale) { - qasymm8x16_signed_t x = vld1q_s8(input1_ptr); - const float32x4x4_t out = - { - { - vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale), - vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale), - vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale), - vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale), - } - }; + qasymm8x16_signed_t x = vld1q_s8(input1_ptr); + const float32x4x4_t out = {{ + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale), + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_low_s8(x)))), offset)), scale), + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale), + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vmovl_s16(vget_high_s16(vmovl_s8(vget_high_s8(x)))), offset)), scale), + }}; return out; } @@ -523,17 +593,15 @@ inline void store_quantized(uint8_t *output_ptr, const int32x4x4_t &out) vst1q_u8(output_ptr, vcombine_u8(pa, pb)); } -inline void store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale) +inline void +store_quantized(uint8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale) { - int32x4x4_t out = - { - { - vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)), - vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)), - vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)), - vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)), - } - }; + int32x4x4_t out = {{ + vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)), + vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)), + vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)), + vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)), + }}; store_quantized(output_ptr, out); } @@ -544,17 +612,17 @@ inline void store_quantized_signed(int8_t *output_ptr, const int32x4x4_t &out) vst1q_s8(output_ptr, vcombine_s8(pa, pb)); } -inline void store_quantized_signed(int8_t *output_ptr, const float32x4x4_t &rf, const float32x4_t &offset, const float32x4_t &invscale) +inline void store_quantized_signed(int8_t *output_ptr, + const float32x4x4_t &rf, + const float32x4_t &offset, + const float32x4_t &invscale) { - int32x4x4_t out = - { - { - vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)), - vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)), - vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)), - vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)), - } - }; + int32x4x4_t out = {{ + vcvtq_s32_f32(vmlaq_f32(offset, rf.val[0], invscale)), + vcvtq_s32_f32(vmlaq_f32(offset, rf.val[1], invscale)), + vcvtq_s32_f32(vmlaq_f32(offset, rf.val[2], invscale)), + vcvtq_s32_f32(vmlaq_f32(offset, rf.val[3], invscale)), + }}; store_quantized_signed(output_ptr, out); } @@ -565,7 +633,8 @@ inline uint8_t elementwise_arithm_op_quantized_scalar(const float &a, const floa } template <ArithmeticOperation op> -inline int8_t elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo) +inline int8_t +elementwise_arithm_op_quantized_signed_scalar(const float &a, const float &b, UniformQuantizationInfo qinfo) { return quantize_qasymm8_signed(elementwise_arithm_op_scalar<op>(a, b), qinfo); } @@ -574,15 +643,12 @@ template <ArithmeticOperation op> float32x4x4_t elementwise_arithm_op(const float32x4x4_t &a, const float32x4x4_t &b) { using neon_vector_float = wrapper::traits::neon_vector<float, 4>; - float32x4x4_t out = - { - { - elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]), - elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]), - elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]), - elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]), - } - }; + float32x4x4_t out = {{ + elementwise_arithm_op<op, neon_vector_float>(a.val[0], b.val[0]), + elementwise_arithm_op<op, neon_vector_float>(a.val[1], b.val[1]), + elementwise_arithm_op<op, neon_vector_float>(a.val[2], b.val[2]), + elementwise_arithm_op<op, neon_vector_float>(a.val[3], b.val[3]), + }}; return out; } @@ -596,26 +662,29 @@ inline uint8_t elementwise_comp_op_quantized_scalar(const float &a, const float template <ComparisonOperation op> inline uint32x4x4_t elementwise_comp_op(const float32x4x4_t &a, const float32x4x4_t &b) { - uint32x4x4_t out = - { - { - elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0], b.val[0]), - elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1], b.val[1]), - elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2], b.val[2]), - elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3], b.val[3]) - } - }; + uint32x4x4_t out = {{elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[0], b.val[0]), + elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[1], b.val[1]), + elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[2], b.val[2]), + elementwise_comp_op<op, float32x4_t, uint32x4_t>(a.val[3], b.val[3])}}; return out; } template <ArithmeticOperation op> -inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x, - const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr, - int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2, - float32x4_t voffseto, float32x4_t invvscaleo) +inline int elementwise_arithm_op_quantized_loop(int window_start_x, + int window_end_x, + int window_step_x, + const uint8_t *input1_ptr, + const uint8_t *input2_ptr, + uint8_t *output_ptr, + int32x4_t voffset1, + int32x4_t voffset2, + float32x4_t vscale1, + float32x4_t vscale2, + float32x4_t voffseto, + float32x4_t invvscaleo) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { // Get inputs and compute output const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1); @@ -627,13 +696,21 @@ inline int elementwise_arithm_op_quantized_loop(int window_start_x, int window_e } template <ArithmeticOperation op> -inline int elementwise_arithm_op_quantized_singed_loop(int window_start_x, int window_end_x, int window_step_x, - const int8_t *input1_ptr, const int8_t *input2_ptr, int8_t *output_ptr, - int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2, - float32x4_t voffseto, float32x4_t invvscaleo) +inline int elementwise_arithm_op_quantized_singed_loop(int window_start_x, + int window_end_x, + int window_step_x, + const int8_t *input1_ptr, + const int8_t *input2_ptr, + int8_t *output_ptr, + int32x4_t voffset1, + int32x4_t voffset2, + float32x4_t vscale1, + float32x4_t vscale2, + float32x4_t voffseto, + float32x4_t invvscaleo) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { // Get inputs and compute output const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1); @@ -645,45 +722,71 @@ inline int elementwise_arithm_op_quantized_singed_loop(int window_start_x, int w } template <ArithmeticOperation op> -inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x, - const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr, - int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast, - float32x4_t voffseto, float32x4_t invvscaleo, bool reorder) +inline int elementwise_arithm_op_quantized_broadcast_loop(int window_start_x, + int window_end_x, + int window_step_x, + const uint8_t *non_broadcast_input_ptr, + float32x4x4_t broadcast_vector, + uint8_t *output_ptr, + int32x4_t voffset_non_broadcast, + float32x4_t vscale_non_broadcast, + float32x4_t voffseto, + float32x4_t invvscaleo, + bool reorder) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { - const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast); - const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector); + const float32x4x4_t af = + load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast); + const float32x4x4_t rf = + elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector); store_quantized(output_ptr + x, rf, voffseto, invvscaleo); } return x; } template <ArithmeticOperation op> -inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x, - const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, int8_t *output_ptr, - int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast, - float32x4_t voffseto, float32x4_t invvscaleo, bool reorder) +inline int elementwise_arithm_op_quantized_signed_broadcast_loop(int window_start_x, + int window_end_x, + int window_step_x, + const int8_t *non_broadcast_input_ptr, + float32x4x4_t broadcast_vector, + int8_t *output_ptr, + int32x4_t voffset_non_broadcast, + float32x4_t vscale_non_broadcast, + float32x4_t voffseto, + float32x4_t invvscaleo, + bool reorder) { int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { - const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast); - const float32x4x4_t rf = elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector); + const float32x4x4_t af = + load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast); + const float32x4x4_t rf = + elementwise_arithm_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector); store_quantized_signed(output_ptr + x, rf, voffseto, invvscaleo); } return x; } template <ComparisonOperation op> -inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end_x, int window_step_x, - const uint8_t *input1_ptr, const uint8_t *input2_ptr, uint8_t *output_ptr, - int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2, - float32x4_t voffseto, float32x4_t invvscaleo) +inline int elementwise_comp_op_quantized_loop(int window_start_x, + int window_end_x, + int window_step_x, + const uint8_t *input1_ptr, + const uint8_t *input2_ptr, + uint8_t *output_ptr, + int32x4_t voffset1, + int32x4_t voffset2, + float32x4_t vscale1, + float32x4_t vscale2, + float32x4_t voffseto, + float32x4_t invvscaleo) { ARM_COMPUTE_UNUSED(voffseto, invvscaleo); int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { const float32x4x4_t af = load_quantized(input1_ptr + x, voffset1, vscale1); const float32x4x4_t bf = load_quantized(input2_ptr + x, voffset2, vscale2); @@ -694,14 +797,22 @@ inline int elementwise_comp_op_quantized_loop(int window_start_x, int window_end } template <ComparisonOperation op> -inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, int window_end_x, int window_step_x, - const int8_t *input1_ptr, const int8_t *input2_ptr, uint8_t *output_ptr, - int32x4_t voffset1, int32x4_t voffset2, float32x4_t vscale1, float32x4_t vscale2, - float32x4_t voffseto, float32x4_t invvscaleo) +inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, + int window_end_x, + int window_step_x, + const int8_t *input1_ptr, + const int8_t *input2_ptr, + uint8_t *output_ptr, + int32x4_t voffset1, + int32x4_t voffset2, + float32x4_t vscale1, + float32x4_t vscale2, + float32x4_t voffseto, + float32x4_t invvscaleo) { ARM_COMPUTE_UNUSED(voffseto, invvscaleo); int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { const float32x4x4_t af = load_quantized_signed(input1_ptr + x, voffset1, vscale1); const float32x4x4_t bf = load_quantized_signed(input2_ptr + x, voffset2, vscale2); @@ -712,46 +823,85 @@ inline int elementwise_comp_op_quantized_signed_loop(int window_start_x, int win } template <ComparisonOperation op> -inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, int window_end_x, int window_step_x, - const uint8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr, - int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast, - float32x4_t voffseto, float32x4_t invvscaleo, bool reorder) +inline int elementwise_comp_op_quantized_broadcast_loop(int window_start_x, + int window_end_x, + int window_step_x, + const uint8_t *non_broadcast_input_ptr, + float32x4x4_t broadcast_vector, + uint8_t *output_ptr, + int32x4_t voffset_non_broadcast, + float32x4_t vscale_non_broadcast, + float32x4_t voffseto, + float32x4_t invvscaleo, + bool reorder) { ARM_COMPUTE_UNUSED(voffseto, invvscaleo); int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { - const float32x4x4_t af = load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast); - const uint32x4x4_t rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector); + const float32x4x4_t af = + load_quantized(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast); + const uint32x4x4_t rf = + elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector); store_quantized(output_ptr + x, rf); } return x; } template <ComparisonOperation op> -inline int elementwise_comp_op_quantized_signed_broadcast_loop(int window_start_x, int window_end_x, int window_step_x, - const int8_t *non_broadcast_input_ptr, float32x4x4_t broadcast_vector, uint8_t *output_ptr, - int32x4_t voffset_non_broadcast, float32x4_t vscale_non_broadcast, - float32x4_t voffseto, float32x4_t invvscaleo, bool reorder) +inline int elementwise_comp_op_quantized_signed_broadcast_loop(int window_start_x, + int window_end_x, + int window_step_x, + const int8_t *non_broadcast_input_ptr, + float32x4x4_t broadcast_vector, + uint8_t *output_ptr, + int32x4_t voffset_non_broadcast, + float32x4_t vscale_non_broadcast, + float32x4_t voffseto, + float32x4_t invvscaleo, + bool reorder) { ARM_COMPUTE_UNUSED(voffseto, invvscaleo); int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) + for (; x <= (window_end_x - window_step_x); x += window_step_x) { - const float32x4x4_t af = load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast); - const uint32x4x4_t rf = elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector); + const float32x4x4_t af = + load_quantized_signed(non_broadcast_input_ptr + x, voffset_non_broadcast, vscale_non_broadcast); + const uint32x4x4_t rf = + elementwise_comp_op<op>(reorder ? broadcast_vector : af, reorder ? af : broadcast_vector); store_quantized(output_ptr + x, rf); } return x; } -inline void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, +inline void elementwise_op_quantized(const ITensor *in1, + const ITensor *in2, + ITensor *out, + const Window &window, uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo), - int (*broadcast_func)(int, int, int, const uint8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t, - float32x4_t, float32x4_t, const bool), - int (*neon_func)(int, int, int, const uint8_t *, const uint8_t *, uint8_t *, - int32x4_t, int32x4_t, float32x4_t, float32x4_t, - float32x4_t, float32x4_t)) + int (*broadcast_func)(int, + int, + int, + const uint8_t *, + float32x4x4_t, + uint8_t *, + int32x4_t, + float32x4_t, + float32x4_t, + float32x4_t, + const bool), + int (*neon_func)(int, + int, + int, + const uint8_t *, + const uint8_t *, + uint8_t *, + int32x4_t, + int32x4_t, + float32x4_t, + float32x4_t, + float32x4_t, + float32x4_t)) { // Create input windows Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()); @@ -772,7 +922,7 @@ inline void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITe const float32x4_t voffseto = vdupq_n_f32(output_qinfo.offset + 0.5f); const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale); - if(is_broadcast_across_x) + if (is_broadcast_across_x) { // Select the broadcast input on the X axis const bool is_broadcast_input_2 = input2_win.x().step() == 0; @@ -794,24 +944,28 @@ inline void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITe Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr()); - const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr()); + execute_window_loop( + win, + [&](const Coordinates &) + { + const auto non_broadcast_input_ptr = reinterpret_cast<const uint8_t *>(non_broadcast_input.ptr()); + const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr()); - const uint8_t broadcast_value = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr()); - const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_u8(broadcast_value), broadcast_qinfo); + const uint8_t broadcast_value = *reinterpret_cast<const uint8_t *>(broadcast_input.ptr()); + const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_u8(broadcast_value), broadcast_qinfo); - int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr, - voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2); - for(; x < window_end_x; ++x) - { - const float afs = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo); - const float bfs = dequantize_qasymm8(broadcast_value, broadcast_qinfo); - *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo); - } - }, - broadcast_input, non_broadcast_input, output); + int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, + broadcast_vector, output_ptr, voffset_non_broadcast, vscale_non_broadcast, + voffseto, invvscaleo, !is_broadcast_input_2); + for (; x < window_end_x; ++x) + { + const float afs = dequantize_qasymm8(*(non_broadcast_input_ptr + x), non_broadcast_qinfo); + const float bfs = dequantize_qasymm8(broadcast_value, broadcast_qinfo); + *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, + !is_broadcast_input_2 ? afs : bfs, output_qinfo); + } + }, + broadcast_input, non_broadcast_input, output); } else { @@ -834,32 +988,56 @@ inline void elementwise_op_quantized(const ITensor *in1, const ITensor *in2, ITe Iterator input2(in2, input2_win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr()); - const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr()); - const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr()); - - int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2, - vscale1, vscale2, voffseto, invvscaleo); - for(; x < window_end_x; ++x) + execute_window_loop( + win, + [&](const Coordinates &) { - const float afs = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo); - const float bfs = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo); - *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo); - } - }, - input1, input2, output); + const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr()); + const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr()); + const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr()); + + int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, + voffset1, voffset2, vscale1, vscale2, voffseto, invvscaleo); + for (; x < window_end_x; ++x) + { + const float afs = dequantize_qasymm8(*(input1_ptr + x), input1_qinfo); + const float bfs = dequantize_qasymm8(*(input2_ptr + x), input2_qinfo); + *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo); + } + }, + input1, input2, output); } } -inline void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, - uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo), - int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, uint8_t *, int32x4_t, float32x4_t, - float32x4_t, float32x4_t, const bool), - int (*neon_func)(int, int, int, const int8_t *, const int8_t *, uint8_t *, - int32x4_t, int32x4_t, float32x4_t, float32x4_t, - float32x4_t, float32x4_t)) +inline void +elementwise_comp_quantized_signed(const ITensor *in1, + const ITensor *in2, + ITensor *out, + const Window &window, + uint8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo), + int (*broadcast_func)(int, + int, + int, + const int8_t *, + float32x4x4_t, + uint8_t *, + int32x4_t, + float32x4_t, + float32x4_t, + float32x4_t, + const bool), + int (*neon_func)(int, + int, + int, + const int8_t *, + const int8_t *, + uint8_t *, + int32x4_t, + int32x4_t, + float32x4_t, + float32x4_t, + float32x4_t, + float32x4_t)) { // Create input windows Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()); @@ -879,7 +1057,7 @@ inline void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor const float32x4_t voffseto = vdupq_n_f32(output_qinfo.offset); const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale); - if(is_broadcast_across_x) + if (is_broadcast_across_x) { // Select the broadcast input on the X axis const bool is_broadcast_input_2 = input2_win.x().step() == 0; @@ -901,24 +1079,28 @@ inline void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr()); - const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr()); + execute_window_loop( + win, + [&](const Coordinates &) + { + const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr()); + const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr()); - const int8_t broadcast_value = *reinterpret_cast<const int8_t *>(broadcast_input.ptr()); - const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo); + const int8_t broadcast_value = *reinterpret_cast<const int8_t *>(broadcast_input.ptr()); + const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo); - int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr, - voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2); - for(; x < window_end_x; ++x) - { - const float afs = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo); - const float bfs = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo); - *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo); - } - }, - broadcast_input, non_broadcast_input, output); + int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, + broadcast_vector, output_ptr, voffset_non_broadcast, vscale_non_broadcast, + voffseto, invvscaleo, !is_broadcast_input_2); + for (; x < window_end_x; ++x) + { + const float afs = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo); + const float bfs = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo); + *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, + !is_broadcast_input_2 ? afs : bfs, output_qinfo); + } + }, + broadcast_input, non_broadcast_input, output); } else { @@ -941,32 +1123,56 @@ inline void elementwise_comp_quantized_signed(const ITensor *in1, const ITensor Iterator input2(in2, input2_win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr()); - const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr()); - const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr()); - - int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2, - vscale1, vscale2, voffseto, invvscaleo); - for(; x < window_end_x; ++x) + execute_window_loop( + win, + [&](const Coordinates &) { - const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo); - const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo); - *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo); - } - }, - input1, input2, output); + const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr()); + const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr()); + const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr()); + + int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, + voffset1, voffset2, vscale1, vscale2, voffseto, invvscaleo); + for (; x < window_end_x; ++x) + { + const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo); + const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo); + *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo); + } + }, + input1, input2, output); } } -inline void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, - int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo), - int (*broadcast_func)(int, int, int, const int8_t *, float32x4x4_t, int8_t *, int32x4_t, float32x4_t, - float32x4_t, float32x4_t, const bool), - int (*neon_func)(int, int, int, const int8_t *, const int8_t *, int8_t *, - int32x4_t, int32x4_t, float32x4_t, float32x4_t, - float32x4_t, float32x4_t)) +inline void +elementwise_op_quantized_signed(const ITensor *in1, + const ITensor *in2, + ITensor *out, + const Window &window, + int8_t (*scalar_func)(const float &, const float &, UniformQuantizationInfo), + int (*broadcast_func)(int, + int, + int, + const int8_t *, + float32x4x4_t, + int8_t *, + int32x4_t, + float32x4_t, + float32x4_t, + float32x4_t, + const bool), + int (*neon_func)(int, + int, + int, + const int8_t *, + const int8_t *, + int8_t *, + int32x4_t, + int32x4_t, + float32x4_t, + float32x4_t, + float32x4_t, + float32x4_t)) { // Create input windows Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()); @@ -986,7 +1192,7 @@ inline void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *i const float32x4_t voffseto = vdupq_n_f32(output_qinfo.offset); const float32x4_t invvscaleo = vdupq_n_f32(1.f / output_qinfo.scale); - if(is_broadcast_across_x) + if (is_broadcast_across_x) { // Select the broadcast input on the X axis const bool is_broadcast_input_2 = input2_win.x().step() == 0; @@ -1008,24 +1214,28 @@ inline void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *i Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr()); - const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr()); + execute_window_loop( + win, + [&](const Coordinates &) + { + const auto non_broadcast_input_ptr = reinterpret_cast<const int8_t *>(non_broadcast_input.ptr()); + const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr()); - const int8_t broadcast_value = *reinterpret_cast<const int8_t *>(broadcast_input.ptr()); - const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo); + const int8_t broadcast_value = *reinterpret_cast<const int8_t *>(broadcast_input.ptr()); + const float32x4x4_t broadcast_vector = vdequantize(vdupq_n_s8(broadcast_value), broadcast_qinfo); - int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, broadcast_vector, output_ptr, - voffset_non_broadcast, vscale_non_broadcast, voffseto, invvscaleo, !is_broadcast_input_2); - for(; x < window_end_x; ++x) - { - const float afs = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo); - const float bfs = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo); - *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, !is_broadcast_input_2 ? afs : bfs, output_qinfo); - } - }, - broadcast_input, non_broadcast_input, output); + int x = (*broadcast_func)(window_start_x, window_end_x, window_step_x, non_broadcast_input_ptr, + broadcast_vector, output_ptr, voffset_non_broadcast, vscale_non_broadcast, + voffseto, invvscaleo, !is_broadcast_input_2); + for (; x < window_end_x; ++x) + { + const float afs = dequantize_qasymm8_signed(*(non_broadcast_input_ptr + x), non_broadcast_qinfo); + const float bfs = dequantize_qasymm8_signed(broadcast_value, broadcast_qinfo); + *(output_ptr + x) = (*scalar_func)(!is_broadcast_input_2 ? bfs : afs, + !is_broadcast_input_2 ? afs : bfs, output_qinfo); + } + }, + broadcast_input, non_broadcast_input, output); } else { @@ -1048,22 +1258,24 @@ inline void elementwise_op_quantized_signed(const ITensor *in1, const ITensor *i Iterator input2(in2, input2_win); Iterator output(out, win); - execute_window_loop(win, [&](const Coordinates &) - { - const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr()); - const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr()); - const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr()); - - int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, voffset1, voffset2, - vscale1, vscale2, voffseto, invvscaleo); - for(; x < window_end_x; ++x) + execute_window_loop( + win, + [&](const Coordinates &) { - const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo); - const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo); - *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo); - } - }, - input1, input2, output); + const auto input1_ptr = reinterpret_cast<const int8_t *>(input1.ptr()); + const auto input2_ptr = reinterpret_cast<const int8_t *>(input2.ptr()); + const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr()); + + int x = (*neon_func)(window_start_x, window_end_x, window_step_x, input1_ptr, input2_ptr, output_ptr, + voffset1, voffset2, vscale1, vscale2, voffseto, invvscaleo); + for (; x < window_end_x; ++x) + { + const float afs = dequantize_qasymm8_signed(*(input1_ptr + x), input1_qinfo); + const float bfs = dequantize_qasymm8_signed(*(input2_ptr + x), input2_qinfo); + *(output_ptr + x) = (*scalar_func)(afs, bfs, output_qinfo); + } + }, + input1, input2, output); } } |