aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/detail
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2020-10-02 16:38:59 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-10-07 09:54:17 +0000
commitddb93bbf12fc9d685e7ddbef703a886d67cbda9b (patch)
tree6dc7bba4a3ffaa527f4972d85c951a012cce5231 /src/core/NEON/kernels/detail
parent4d91dc68adf8a4cc07285fe781469231230df3b9 (diff)
downloadComputeLibrary-ddb93bbf12fc9d685e7ddbef703a886d67cbda9b.tar.gz
COMPMID-3637: Move wrapper to src
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: I524b0c4b49c7a7035b7d078b9585d77b0d438e10 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4083 Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/detail')
-rw-r--r--src/core/NEON/kernels/detail/NEActivationFunctionDetail.h315
-rw-r--r--src/core/NEON/kernels/detail/NEColorConvertHelper.inl1045
-rw-r--r--src/core/NEON/kernels/detail/NEDirectConvolution3x3.h170
-rw-r--r--src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h965
4 files changed, 2495 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/detail/NEActivationFunctionDetail.h b/src/core/NEON/kernels/detail/NEActivationFunctionDetail.h
new file mode 100644
index 0000000000..eef1be06eb
--- /dev/null
+++ b/src/core/NEON/kernels/detail/NEActivationFunctionDetail.h
@@ -0,0 +1,315 @@
+/*
+ * Copyright (c) 2018-2020 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_DETAIL_NEACTIVATION_FUNCTION_DETAIL_H
+#define ARM_COMPUTE_DETAIL_NEACTIVATION_FUNCTION_DETAIL_H
+
+#include "src/core/NEON/wrapper/wrapper.h"
+
+namespace arm_compute
+{
+namespace detail
+{
+/** Dummy activation object */
+template <typename T, int S>
+struct dummy
+{
+ /** NEON vector type. */
+ using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
+
+ /** Construct a dummy activation object.
+ *
+ * @param[in] act_info Activation layer information.
+ */
+ explicit dummy(ActivationLayerInfo act_info)
+ {
+ ARM_COMPUTE_UNUSED(act_info);
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] vval Vector of values.
+ */
+ void operator()(ExactType &vval)
+ {
+ ARM_COMPUTE_UNUSED(vval);
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] val Scalar value.
+ */
+ void operator()(T &val)
+ {
+ ARM_COMPUTE_UNUSED(val);
+ }
+};
+/** Linear activation object */
+template <typename T, int S>
+struct linear
+{
+ /** NEON vector type. */
+ using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ /** Construct a Linear activation object.
+ *
+ * @param[in] act_info Activation layer information.
+ */
+ explicit linear(ActivationLayerInfo act_info)
+ : alpha(act_info.a()),
+ beta(act_info.b()),
+ valpha(wrapper::vdup_n(static_cast<T>(alpha), ExactTagType{})),
+ vbeta(wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}))
+ {
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] vval Vector of values.
+ */
+ void operator()(ExactType &vval)
+ {
+ vval = wrapper::vmla(vbeta, vval, valpha);
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] val Scalar value.
+ */
+ void operator()(T &val)
+ {
+ val = alpha * val + beta;
+ }
+
+ const T alpha; /**< Scalar alpha */
+ const T beta; /**< Scalar alpha */
+ const ExactType valpha; /**< Vector of alphas. */
+ const ExactType vbeta; /**< Vector of betas. */
+};
+/** Square activation object */
+template <typename T, int S>
+struct square
+{
+ /** NEON vector type. */
+ using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ /** Construct a Square activation object.
+ *
+ * @param[in] act_info Activation layer information.
+ */
+ explicit square(ActivationLayerInfo act_info)
+ {
+ ARM_COMPUTE_UNUSED(act_info);
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] vval Vector of values.
+ */
+ void operator()(ExactType &vval)
+ {
+ vval = wrapper::vmul(vval, vval);
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] val Scalar value.
+ */
+ void operator()(T &val)
+ {
+ val = val * val;
+ }
+};
+/** Logistic activation object */
+template <typename T, int S>
+struct logistic
+{
+ /** NEON vector type. */
+ using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ /** Construct a Logistic activation object.
+ *
+ * @param[in] act_info Activation layer information.
+ */
+ explicit logistic(ActivationLayerInfo act_info)
+ : vone(wrapper::vdup_n(static_cast<T>(1), ExactTagType{}))
+ {
+ ARM_COMPUTE_UNUSED(act_info);
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] vval Vector of values.
+ */
+ void operator()(ExactType &vval)
+ {
+ vval = wrapper::vinv(wrapper::vadd(vone, wrapper::vexpq(wrapper::vneg(vval))));
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] val Scalar value.
+ */
+ void operator()(T &val)
+ {
+ val = 1 / (1 + std::exp(-val));
+ }
+
+ /** Vector of ones. */
+ const ExactType vone;
+};
+/** RELU activation object */
+template <typename T, int S>
+struct relu
+{
+ /** NEON vector type. */
+ using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ /** Construct a RELU activation object.
+ *
+ * @param[in] act_info Activation layer information.
+ */
+ explicit relu(ActivationLayerInfo act_info)
+ : vzero(wrapper::vdup_n(static_cast<T>(0), ExactTagType{}))
+ {
+ ARM_COMPUTE_UNUSED(act_info);
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] vval Vector of values.
+ */
+ void operator()(ExactType &vval)
+ {
+ vval = wrapper::vmax(vzero, vval);
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] val Scalar value.
+ */
+ void operator()(T &val)
+ {
+ val = std::max(static_cast<T>(0), val);
+ }
+
+ /** Vector of zeroes. */
+ const ExactType vzero;
+};
+/** Bounded RELU activation object */
+template <typename T, int S>
+struct brelu
+{
+ /** NEON vector type. */
+ using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ /** Construct a bounded RELU activation object.
+ *
+ * @param[in] act_info Activation layer information.
+ */
+ explicit brelu(ActivationLayerInfo act_info)
+ : alpha(act_info.a()),
+ vzero(wrapper::vdup_n(static_cast<T>(0), ExactTagType{})),
+ valpha(wrapper::vdup_n(static_cast<T>(act_info.a()), ExactTagType{}))
+ {
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] vval Vector of values.
+ */
+ void operator()(ExactType &vval)
+ {
+ vval = wrapper::vmin(valpha, wrapper::vmax(vzero, vval));
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] val Scalar value.
+ */
+ void operator()(T &val)
+ {
+ val = std::min(alpha, std::max(static_cast<T>(0), val));
+ }
+
+ const T alpha; /** Scalar alpha */
+ const ExactType vzero; /** Vector of zeroes. */
+ const ExactType valpha; /** Vector of alphas. */
+};
+/** Lower-Upper Bounded RELU activation object */
+template <typename T, int S>
+struct lubrelu
+{
+ /** NEON vector type. */
+ using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+ /** Construct a lower-upper bounded RELU activation object.
+ *
+ * @param[in] act_info Activation layer information.
+ */
+ explicit lubrelu(ActivationLayerInfo act_info)
+ : alpha(act_info.a()),
+ beta(act_info.b()),
+ valpha(wrapper::vdup_n(static_cast<T>(act_info.a()), ExactTagType{})),
+ vbeta(wrapper::vdup_n(static_cast<T>(act_info.b()), ExactTagType{}))
+ {
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] vval Vector of values.
+ */
+ void operator()(ExactType &vval)
+ {
+ vval = wrapper::vmin(valpha, wrapper::vmax(vbeta, vval));
+ }
+
+ /** Run activation function.
+ *
+ * @param[in] val Scalar value.
+ */
+ void operator()(T &val)
+ {
+ val = std::min(alpha, std::max(beta, val));
+ }
+
+ const T alpha; /**< Scalar alpha */
+ const T beta; /**< Scalar alpha */
+ const ExactType valpha; /** Vector of alphas. */
+ const ExactType vbeta; /** Vector of betas. */
+};
+} // namespace detail
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_DETAIL_NEACTIVATION_FUNCTION_DETAIL_H */
diff --git a/src/core/NEON/kernels/detail/NEColorConvertHelper.inl b/src/core/NEON/kernels/detail/NEColorConvertHelper.inl
new file mode 100644
index 0000000000..ac196d9dbb
--- /dev/null
+++ b/src/core/NEON/kernels/detail/NEColorConvertHelper.inl
@@ -0,0 +1,1045 @@
+/*
+ * Copyright (c) 2016-2020 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/IMultiImage.h"
+#include "arm_compute/core/Utils.h"
+#include "src/core/NEON/NEMath.h"
+
+#include <arm_neon.h>
+
+namespace
+{
+#ifndef DOXYGEN_SKIP_THIS
+constexpr float red_coef_bt709 = 1.5748F;
+constexpr float green_coef_bt709 = -0.1873f;
+constexpr float green_coef2_bt709 = -0.4681f;
+constexpr float blue_coef_bt709 = 1.8556f;
+
+constexpr float rgb2yuv_bt709_kr = 0.2126f;
+constexpr float rgb2yuv_bt709_kb = 0.0722f;
+// K_g = 1 - K_r - K_b
+constexpr float rgb2yuv_bt709_kg = 0.7152f;
+// C_u = 1 / (2 * (1 - K_b))
+constexpr float rgb2yuv_bt709_cu = 0.5389f;
+// C_v = 1 / (2 * (1 - K_r))
+constexpr float rgb2yuv_bt709_cv = 0.6350f;
+
+constexpr float rgb2u8_red_coef = 0.2126f;
+constexpr float rgb2u8_green_coef = 0.7152f;
+constexpr float rgb2u8_blue_coef = 0.0722f;
+
+inline float32x4_t rgb_to_greyscale_calculation(const float32x4_t &rcolor, const float32x4_t &gcolor, const float32x4_t &bcolor,
+ const float rcoef, const float gcoef, const float bcoef)
+{
+ float32x4_t greyscale = vmulq_n_f32(rcolor, rcoef);
+ greyscale = vmlaq_n_f32(greyscale, gcolor, gcoef);
+ greyscale = vmlaq_n_f32(greyscale, bcolor, bcoef);
+ return greyscale;
+}
+
+inline void rgb_to_u8_conversion(const uint8x16x3_t &in, uint8x16_t &out)
+{
+ float32x4x4_t out_float32;
+
+ //Conversion from 3(RGB) 4 uint8s to 3(RGB) 4 floats
+ const float32x4x4_t r_float32 = arm_compute::convert_uint8x16_to_float32x4x4(in.val[0]);
+ const float32x4x4_t g_float32 = arm_compute::convert_uint8x16_to_float32x4x4(in.val[1]);
+ const float32x4x4_t b_float32 = arm_compute::convert_uint8x16_to_float32x4x4(in.val[2]);
+
+ //New grayscale image = ( (RED_COEFF * R) + (GREEN_COEFF * G) + (BLUE_COEFF * B) )
+ //Computation of 1(Greyscale) 4 uint8 using 3(RGB) 4 uint8s float
+ out_float32.val[0] = rgb_to_greyscale_calculation(r_float32.val[0], g_float32.val[0], b_float32.val[0],
+ rgb2u8_red_coef, rgb2u8_green_coef, rgb2u8_blue_coef);
+
+ out_float32.val[1] = rgb_to_greyscale_calculation(r_float32.val[1], g_float32.val[1], b_float32.val[1],
+ rgb2u8_red_coef, rgb2u8_green_coef, rgb2u8_blue_coef);
+
+ out_float32.val[2] = rgb_to_greyscale_calculation(r_float32.val[2], g_float32.val[2], b_float32.val[2],
+ rgb2u8_red_coef, rgb2u8_green_coef, rgb2u8_blue_coef);
+
+ out_float32.val[3] = rgb_to_greyscale_calculation(r_float32.val[3], g_float32.val[3], b_float32.val[3],
+ rgb2u8_red_coef, rgb2u8_green_coef, rgb2u8_blue_coef);
+
+ //Conversion from 1(Greyscale) 4 floats to 1(Greyscale) 4 uint8s
+ arm_compute::convert_float32x4x4_to_uint8x16(out_float32, out);
+}
+
+inline void rgb_to_yuv_calculation(const float32x4_t &rvec, const float32x4_t &gvec, const float32x4_t &bvec,
+ float32x4_t &yvec, float32x4_t &uvec, float32x4_t &vvec)
+{
+ /*
+ Y'= 0.2126*R' + 0.7152*G' + 0.0722*B'
+ U'=-0.1146*R' - 0.3854*G' + 0.5000*B'
+ V'= 0.5000*R' - 0.4542*G' - 0.0458*B'
+ */
+ const auto c128 = vdupq_n_f32(128.f);
+
+ // Y = R * K_r + G * (1 - K_r - K_b) * B * K_b
+ yvec = vmulq_n_f32(rvec, rgb2yuv_bt709_kr);
+ yvec = vmlaq_n_f32(yvec, gvec, rgb2yuv_bt709_kg);
+ yvec = vmlaq_n_f32(yvec, bvec, rgb2yuv_bt709_kb);
+
+ // U = (B - Y) / (2 * (1 - K_b))
+ uvec = vsubq_f32(bvec, yvec);
+ uvec = vmlaq_n_f32(c128, uvec, rgb2yuv_bt709_cu);
+
+ // V = (R - Y) / (2 * (1 - K_r))
+ vvec = vsubq_f32(rvec, yvec);
+ vvec = vmlaq_n_f32(c128, vvec, rgb2yuv_bt709_cv);
+}
+
+inline void yuyv_to_rgb_calculation(const float32x4_t &yvec_val, float32x4_t uvec_val, const float32x4_t &yyvec_val,
+ float32x4_t vvec_val, unsigned char *output_ptr, const bool alpha)
+{
+ float32x4x3_t rgb1, rgb2;
+
+ // Compute: cb - 128 and cr - 128;
+ const auto c128 = vdupq_n_f32(128.f);
+ uvec_val = vsubq_f32(uvec_val, c128);
+ vvec_val = vsubq_f32(vvec_val, c128);
+
+ // Compute:
+ // r = 0.0000f*f_u + 1.5748f*f_v;
+ // g = 0.1873f*f_u - 0.4681f*f_v;
+ // b = 1.8556f*f_u + 0.0000f*f_v;
+ const auto red = vmulq_n_f32(vvec_val, red_coef_bt709);
+ const auto blue = vmulq_n_f32(uvec_val, blue_coef_bt709);
+ const auto green = vaddq_f32(vmulq_n_f32(uvec_val, green_coef_bt709),
+ vmulq_n_f32(vvec_val, green_coef2_bt709));
+
+ // Compute the final r,g,b values using y1 for the first texel and y2 for the second one.
+ // the result is stored in two float32x4x3_t which then are converted to one uint8x8x3_t
+ // and written back to memory using vst3 instruction
+
+ rgb1.val[0] = vaddq_f32(yvec_val, red);
+ rgb1.val[1] = vaddq_f32(yvec_val, green);
+ rgb1.val[2] = vaddq_f32(yvec_val, blue);
+
+ rgb2.val[0] = vaddq_f32(yyvec_val, red);
+ rgb2.val[1] = vaddq_f32(yyvec_val, green);
+ rgb2.val[2] = vaddq_f32(yyvec_val, blue);
+
+ uint8x8x3_t u8_rgb;
+ arm_compute::convert_float32x4x3_to_uint8x8x3(rgb1, rgb2, u8_rgb);
+
+ if(!alpha)
+ {
+ vst3_lane_u8(&output_ptr[0], u8_rgb, 0);
+ vst3_lane_u8(&output_ptr[3], u8_rgb, 4);
+ vst3_lane_u8(&output_ptr[6], u8_rgb, 1);
+ vst3_lane_u8(&output_ptr[9], u8_rgb, 5);
+ vst3_lane_u8(&output_ptr[12], u8_rgb, 2);
+ vst3_lane_u8(&output_ptr[15], u8_rgb, 6);
+ vst3_lane_u8(&output_ptr[18], u8_rgb, 3);
+ vst3_lane_u8(&output_ptr[21], u8_rgb, 7);
+ }
+ else
+ {
+ uint8x8x4_t u8_rgba;
+ u8_rgba.val[0] = u8_rgb.val[0];
+ u8_rgba.val[1] = u8_rgb.val[1];
+ u8_rgba.val[2] = u8_rgb.val[2];
+ u8_rgba.val[3] = vdup_n_u8(255);
+ vst4_lane_u8(&output_ptr[0], u8_rgba, 0);
+ vst4_lane_u8(&output_ptr[4], u8_rgba, 4);
+ vst4_lane_u8(&output_ptr[8], u8_rgba, 1);
+ vst4_lane_u8(&output_ptr[12], u8_rgba, 5);
+ vst4_lane_u8(&output_ptr[16], u8_rgba, 2);
+ vst4_lane_u8(&output_ptr[20], u8_rgba, 6);
+ vst4_lane_u8(&output_ptr[24], u8_rgba, 3);
+ vst4_lane_u8(&output_ptr[28], u8_rgba, 7);
+ }
+}
+
+inline uint8x16x3_t load_rgb(const unsigned char *const ptr, const bool alpha)
+{
+ uint8x16x3_t rgb;
+
+ if(alpha)
+ {
+ const auto tmp = vld4q_u8(ptr);
+ rgb.val[0] = tmp.val[0];
+ rgb.val[1] = tmp.val[1];
+ rgb.val[2] = tmp.val[2];
+ }
+ else
+ {
+ rgb = vld3q_u8(ptr);
+ }
+
+ return rgb;
+}
+
+inline void rgb_to_yuv_conversion(uint8x16x3_t &vec_top, uint8x16x3_t &vec_bottom)
+{
+ // Convert the uint8x16_t to float32x4x4_t
+ const float32x4x4_t frvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vec_top.val[0]);
+ const float32x4x4_t fgvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vec_top.val[1]);
+ const float32x4x4_t fbvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vec_top.val[2]);
+
+ const float32x4x4_t frvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vec_bottom.val[0]);
+ const float32x4x4_t fgvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vec_bottom.val[1]);
+ const float32x4x4_t fbvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vec_bottom.val[2]);
+
+ float32x4x4_t fyvec_top, fuvec_top, fvvec_top;
+ float32x4x4_t fyvec_bottom, fuvec_bottom, fvvec_bottom;
+
+ for(auto i = 0; i < 4; ++i)
+ {
+ rgb_to_yuv_calculation(frvec_top.val[i], fgvec_top.val[i], fbvec_top.val[i],
+ fyvec_top.val[i], fuvec_top.val[i], fvvec_top.val[i]);
+ rgb_to_yuv_calculation(frvec_bottom.val[i], fgvec_bottom.val[i], fbvec_bottom.val[i],
+ fyvec_bottom.val[i], fuvec_bottom.val[i], fvvec_bottom.val[i]);
+ }
+
+ arm_compute::convert_float32x4x4_to_uint8x16(fyvec_top, vec_top.val[0]);
+ arm_compute::convert_float32x4x4_to_uint8x16(fuvec_top, vec_top.val[1]);
+ arm_compute::convert_float32x4x4_to_uint8x16(fvvec_top, vec_top.val[2]);
+ arm_compute::convert_float32x4x4_to_uint8x16(fyvec_bottom, vec_bottom.val[0]);
+ arm_compute::convert_float32x4x4_to_uint8x16(fuvec_bottom, vec_bottom.val[1]);
+ arm_compute::convert_float32x4x4_to_uint8x16(fvvec_bottom, vec_bottom.val[2]);
+}
+
+inline void store_rgb_to_nv12(const uint8x16_t &rvec_top, const uint8x16_t &gvec_top, const uint8x16_t &bvec_top,
+ const uint8x16_t &rvec_bottom, const uint8x16_t &gvec_bottom, const uint8x16_t &bvec_bottom,
+ unsigned char *const __restrict out_y_top, unsigned char *const __restrict out_y_bottom,
+ unsigned char *const __restrict out_uv)
+{
+ uint8x16x3_t vec_top, vec_bottom;
+ vec_top.val[0] = rvec_top;
+ vec_top.val[1] = gvec_top;
+ vec_top.val[2] = bvec_top;
+ vec_bottom.val[0] = rvec_bottom;
+ vec_bottom.val[1] = gvec_bottom;
+ vec_bottom.val[2] = bvec_bottom;
+
+ rgb_to_yuv_conversion(vec_top, vec_bottom);
+
+ vst1q_u8(out_y_top, vec_top.val[0]);
+ vst1q_u8(out_y_bottom, vec_bottom.val[0]);
+
+ const auto uvec = vuzpq_u8(vec_top.val[1], vec_bottom.val[1]);
+ const auto vvec = vuzpq_u8(vec_top.val[2], vec_bottom.val[2]);
+ const auto utmp = vrhaddq_u8(uvec.val[0], uvec.val[1]);
+ const auto vtmp = vrhaddq_u8(vvec.val[0], vvec.val[1]);
+
+ uint8x8x2_t uvvec;
+ uvvec.val[0] = vhadd_u8(vget_low_u8(utmp), vget_high_u8(utmp));
+ uvvec.val[1] = vhadd_u8(vget_low_u8(vtmp), vget_high_u8(vtmp));
+
+ vst2_u8(out_uv, uvvec);
+}
+
+inline void store_rgb_to_iyuv(const uint8x16_t &rvec_top, const uint8x16_t &gvec_top, const uint8x16_t &bvec_top,
+ const uint8x16_t &rvec_bottom, const uint8x16_t &gvec_bottom, const uint8x16_t &bvec_bottom,
+ unsigned char *const __restrict out_y_top, unsigned char *const __restrict out_y_bottom,
+ unsigned char *const __restrict out_u,
+ unsigned char *const __restrict out_v)
+{
+ uint8x16x3_t vec_top, vec_bottom;
+ vec_top.val[0] = rvec_top;
+ vec_top.val[1] = gvec_top;
+ vec_top.val[2] = bvec_top;
+ vec_bottom.val[0] = rvec_bottom;
+ vec_bottom.val[1] = gvec_bottom;
+ vec_bottom.val[2] = bvec_bottom;
+
+ rgb_to_yuv_conversion(vec_top, vec_bottom);
+
+ vst1q_u8(out_y_top, vec_top.val[0]);
+ vst1q_u8(out_y_bottom, vec_bottom.val[0]);
+
+ const auto uvvec_top = vuzpq_u8(vec_top.val[1], vec_top.val[2]);
+ const auto uvvec_bottom = vuzpq_u8(vec_bottom.val[1], vec_bottom.val[2]);
+ const auto uvvec = vhaddq_u8(vrhaddq_u8(uvvec_top.val[0], uvvec_top.val[1]),
+ vrhaddq_u8(uvvec_bottom.val[0], uvvec_bottom.val[1]));
+
+ vst1_u8(out_u, vget_low_u8(uvvec));
+ vst1_u8(out_v, vget_high_u8(uvvec));
+}
+
+inline void store_rgb_to_yuv4(const uint8x16_t &rvec, const uint8x16_t &gvec, const uint8x16_t &bvec,
+ unsigned char *const __restrict out_y,
+ unsigned char *const __restrict out_u,
+ unsigned char *const __restrict out_v)
+{
+ // Convert the uint8x16_t to float32x4x4_t
+ const float32x4x4_t frvec = arm_compute::convert_uint8x16_to_float32x4x4(rvec);
+ const float32x4x4_t fgvec = arm_compute::convert_uint8x16_to_float32x4x4(gvec);
+ const float32x4x4_t fbvec = arm_compute::convert_uint8x16_to_float32x4x4(bvec);
+
+ float32x4x4_t fyvec, fuvec, fvvec;
+ for(auto i = 0; i < 4; ++i)
+ {
+ rgb_to_yuv_calculation(frvec.val[i], fgvec.val[i], fbvec.val[i],
+ fyvec.val[i], fuvec.val[i], fvvec.val[i]);
+ }
+
+ uint8x16_t yvec, uvec, vvec;
+ arm_compute::convert_float32x4x4_to_uint8x16(fyvec, yvec);
+ arm_compute::convert_float32x4x4_to_uint8x16(fuvec, uvec);
+ arm_compute::convert_float32x4x4_to_uint8x16(fvvec, vvec);
+
+ vst1q_u8(out_y, yvec);
+ vst1q_u8(out_u, uvec);
+ vst1q_u8(out_v, vvec);
+}
+#endif /* DOXYGEN_SKIP_THIS */
+}
+
+namespace arm_compute
+{
+/** Convert RGB to RGBX.
+ *
+ * @param[in] input Input RGB data buffer.
+ * @param[out] output Output RGBX buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+void colorconvert_rgb_to_rgbx(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+
+ const auto input_ptr = static_cast<const IImage *__restrict>(input);
+ const auto output_ptr = static_cast<IImage *__restrict>(output);
+
+ Iterator in(input_ptr, win);
+ Iterator out(output_ptr, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta1 = vld3q_u8(in.ptr());
+ uint8x16x4_t ta2;
+ ta2.val[0] = ta1.val[0];
+ ta2.val[1] = ta1.val[1];
+ ta2.val[2] = ta1.val[2];
+ ta2.val[3] = vdupq_n_u8(255);
+ vst4q_u8(out.ptr(), ta2);
+ },
+ in, out);
+}
+
+/** Convert RGB to U8.
+ *
+ * @param[in] input Input RGB data buffer.
+ * @param[out] output Output U8 buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+void colorconvert_rgb_to_u8(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+
+ const auto input_ptr = static_cast<const IImage *__restrict>(input);
+ const auto output_ptr = static_cast<IImage *__restrict>(output);
+
+ Iterator in(input_ptr, win);
+ Iterator out(output_ptr, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta1 = vld3q_u8(in.ptr());
+ uint8x16_t ta2;
+ rgb_to_u8_conversion(ta1, ta2);
+ vst1q_u8(out.ptr(), ta2);
+ },
+ in, out);
+}
+
+/** Convert RGBX to RGB.
+ *
+ * @param[in] input Input RGBX data buffer.
+ * @param[out] output Output RGB buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+void colorconvert_rgbx_to_rgb(const void *input, void *output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+
+ const auto input_ptr = static_cast<const IImage *__restrict>(input);
+ const auto output_ptr = static_cast<IImage *__restrict>(output);
+
+ Iterator in(input_ptr, win);
+ Iterator out(output_ptr, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta1 = vld4q_u8(in.ptr());
+ uint8x16x3_t ta2;
+ ta2.val[0] = ta1.val[0];
+ ta2.val[1] = ta1.val[1];
+ ta2.val[2] = ta1.val[2];
+ vst3q_u8(out.ptr(), ta2);
+ },
+ in, out);
+}
+
+/** Convert YUYV to RGB.
+ *
+ * @param[in] input Input YUYV data buffer.
+ * @param[out] output Output RGB buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+template <bool yuyv, bool alpha>
+void colorconvert_yuyv_to_rgb(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+
+ const auto input_ptr = static_cast<const IImage *__restrict>(input);
+ const auto output_ptr = static_cast<IImage *__restrict>(output);
+
+ constexpr auto element_size = alpha ? 32 : 24;
+ constexpr auto shift = yuyv ? 0 : 1;
+
+ Iterator in(input_ptr, win);
+ Iterator out(output_ptr, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta = vld4q_u8(in.ptr());
+ //ta.val[0] = Y0 Y2 Y4 Y6 ...
+ //ta.val[1] = U0 U2 U4 U6 ...
+ //ta.val[2] = Y1 Y3 Y5 Y7 ...
+ //ta.val[3] = V0 V2 V4 V7 ...
+
+ // Convert the uint8x16x4_t to float32x4x4_t
+ const float32x4x4_t yvec = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[0 + shift]);
+ const float32x4x4_t uvec = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[1 - shift]);
+ const float32x4x4_t yyvec = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[2 + shift]);
+ const float32x4x4_t vvec = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[3 - shift]);
+
+ yuyv_to_rgb_calculation(yvec.val[0], uvec.val[0], yyvec.val[0], vvec.val[0], out.ptr() + 0 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec.val[1], uvec.val[1], yyvec.val[1], vvec.val[1], out.ptr() + 1 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec.val[2], uvec.val[2], yyvec.val[2], vvec.val[2], out.ptr() + 2 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec.val[3], uvec.val[3], yyvec.val[3], vvec.val[3], out.ptr() + 3 * element_size, alpha);
+ },
+ in, out);
+}
+
+/** Convert NV12 to RGB.
+ *
+ * @param[in] input Input NV12 data buffer.
+ * @param[out] output Output RGB buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+template <bool uv, bool alpha>
+void colorconvert_nv12_to_rgb(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IMultiImage *__restrict>(input);
+ const auto output_ptr = static_cast<IImage *__restrict>(output);
+
+ constexpr auto element_size = alpha ? 32 : 24;
+ const auto out_stride = output_ptr->info()->strides_in_bytes().y();
+ constexpr auto shift = uv ? 0 : 1;
+
+ // UV's width and height are subsampled
+ Window win_uv(win);
+ win_uv.set(Window::DimX, Window::Dimension(win_uv.x().start() / 2, win_uv.x().end() / 2, win.x().step() / 2));
+ win_uv.set(Window::DimY, Window::Dimension(win_uv.y().start() / 2, win_uv.y().end() / 2, 1));
+ win_uv.validate();
+
+ Iterator in_y(input_ptr->plane(0), win);
+ Iterator in_uv(input_ptr->plane(1), win_uv);
+ Iterator out(output_ptr, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta_y_top = vld2q_u8(in_y.ptr());
+ const auto ta_y_bottom = vld2q_u8(in_y.ptr() + input_ptr->plane(0)->info()->strides_in_bytes().y());
+ const auto ta_uv = vld2q_u8(in_uv.ptr());
+ //ta_y.val[0] = Y0 Y2 Y4 Y6 ...
+ //ta_y.val[1] = Y1 Y3 Y5 Y7 ...
+ //ta_uv.val[0] = U0 U2 U4 U6 ...
+ //ta_uv.val[1] = V0 V2 V4 V6 ...
+
+ // Convert the uint8x16x4_t to float32x4x4_t
+ float32x4x4_t yvec_top = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[0]);
+ float32x4x4_t yyvec_top = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[1]);
+ float32x4x4_t yvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[0]);
+ float32x4x4_t yyvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[1]);
+ float32x4x4_t uvec = arm_compute::convert_uint8x16_to_float32x4x4(ta_uv.val[0 + shift]);
+ float32x4x4_t vvec = arm_compute::convert_uint8x16_to_float32x4x4(ta_uv.val[1 - shift]);
+
+ yuyv_to_rgb_calculation(yvec_top.val[0], uvec.val[0], yyvec_top.val[0], vvec.val[0], out.ptr() + 0 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_top.val[1], uvec.val[1], yyvec_top.val[1], vvec.val[1], out.ptr() + 1 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_top.val[2], uvec.val[2], yyvec_top.val[2], vvec.val[2], out.ptr() + 2 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_top.val[3], uvec.val[3], yyvec_top.val[3], vvec.val[3], out.ptr() + 3 * element_size, alpha);
+
+ yuyv_to_rgb_calculation(yvec_bottom.val[0], uvec.val[0], yyvec_bottom.val[0], vvec.val[0], out.ptr() + out_stride + 0 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_bottom.val[1], uvec.val[1], yyvec_bottom.val[1], vvec.val[1], out.ptr() + out_stride + 1 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_bottom.val[2], uvec.val[2], yyvec_bottom.val[2], vvec.val[2], out.ptr() + out_stride + 2 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_bottom.val[3], uvec.val[3], yyvec_bottom.val[3], vvec.val[3], out.ptr() + out_stride + 3 * element_size, alpha);
+ },
+ in_y, in_uv, out);
+}
+
+/** Convert IYUV to RGB.
+ *
+ * @param[in] input Input IYUV data buffer.
+ * @param[out] output Output RGB buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+template <bool alpha>
+void colorconvert_iyuv_to_rgb(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IMultiImage *__restrict>(input);
+ const auto output_ptr = static_cast<IImage *__restrict>(output);
+
+ constexpr auto element_size = alpha ? 32 : 24;
+ const auto out_stride = output_ptr->info()->strides_in_bytes().y();
+
+ // UV's width and height are subsampled
+ Window win_uv(win);
+ win_uv.set(Window::DimX, Window::Dimension(win_uv.x().start() / 2, win_uv.x().end() / 2, win_uv.x().step() / 2));
+ win_uv.set(Window::DimY, Window::Dimension(win_uv.y().start() / 2, win_uv.y().end() / 2, 1));
+ win_uv.validate();
+
+ Iterator in_y(input_ptr->plane(0), win);
+ Iterator in_u(input_ptr->plane(1), win_uv);
+ Iterator in_v(input_ptr->plane(2), win_uv);
+ Iterator out(output_ptr, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto *y_top_ptr = in_y.ptr();
+ const auto *y_bottom_ptr = in_y.ptr() + input_ptr->plane(0)->info()->strides_in_bytes().y();
+ const auto *u_ptr = in_u.ptr();
+ const auto *v_ptr = in_v.ptr();
+
+ // Work-around issue in gcc 9(>=) where vld2q might cause issues with register allocation
+#if defined(__arch64__)
+ const auto ta0_y_top = vld1q_u8(y_top_ptr);
+ const auto ta1_y_top = vld1q_u8(y_top_ptr + 16);
+ const auto ta0_y_bottom = vld1q_u8(y_bottom_ptr);
+ const auto ta1_y_bottom = vld1q_u8(y_bottom_ptr + 16);
+ const auto ta_u = vld1q_u8(u_ptr);
+ const auto ta_v = vld1q_u8(v_ptr);
+
+ // Convert the uint8x16x4_t to float32x4x4_t
+ float32x4x4_t yvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vuzp1q_u8(ta0_y_top, ta1_y_top));
+ float32x4x4_t yyvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vuzp2q_u8(ta0_y_top, ta1_y_top));
+ float32x4x4_t yvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vuzp1q_u8(ta0_y_bottom, ta1_y_bottom));
+ float32x4x4_t yyvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vuzp2q_u8(ta0_y_bottom, ta1_y_bottom));
+ float32x4x4_t uvec = arm_compute::convert_uint8x16_to_float32x4x4(ta_u);
+ float32x4x4_t vvec = arm_compute::convert_uint8x16_to_float32x4x4(ta_v);
+#else /* defined(__arch64__) */
+ const auto ta_y_top = vld2q_u8(y_top_ptr);
+ const auto ta_y_bottom = vld2q_u8(y_bottom_ptr);
+ const auto ta_u = vld1q_u8(u_ptr);
+ const auto ta_v = vld1q_u8(v_ptr);
+ //ta_y.val[0] = Y0 Y2 Y4 Y6 ...
+ //ta_y.val[1] = Y1 Y3 Y5 Y7 ...
+ //ta_u.val[0] = U0 U2 U4 U6 ...
+ //ta_v.val[0] = V0 V2 V4 V6 ...
+
+ // Convert the uint8x16x4_t to float32x4x4_t
+ float32x4x4_t yvec_top = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[0]);
+ float32x4x4_t yyvec_top = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[1]);
+ float32x4x4_t yvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[0]);
+ float32x4x4_t yyvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[1]);
+ float32x4x4_t uvec = arm_compute::convert_uint8x16_to_float32x4x4(ta_u);
+ float32x4x4_t vvec = arm_compute::convert_uint8x16_to_float32x4x4(ta_v);
+#endif /* defined(__arch64__) */
+
+ yuyv_to_rgb_calculation(yvec_top.val[0], uvec.val[0], yyvec_top.val[0], vvec.val[0], out.ptr() + 0 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_top.val[1], uvec.val[1], yyvec_top.val[1], vvec.val[1], out.ptr() + 1 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_top.val[2], uvec.val[2], yyvec_top.val[2], vvec.val[2], out.ptr() + 2 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_top.val[3], uvec.val[3], yyvec_top.val[3], vvec.val[3], out.ptr() + 3 * element_size, alpha);
+
+ yuyv_to_rgb_calculation(yvec_bottom.val[0], uvec.val[0], yyvec_bottom.val[0], vvec.val[0], out.ptr() + out_stride + 0 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_bottom.val[1], uvec.val[1], yyvec_bottom.val[1], vvec.val[1], out.ptr() + out_stride + 1 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_bottom.val[2], uvec.val[2], yyvec_bottom.val[2], vvec.val[2], out.ptr() + out_stride + 2 * element_size, alpha);
+ yuyv_to_rgb_calculation(yvec_bottom.val[3], uvec.val[3], yyvec_bottom.val[3], vvec.val[3], out.ptr() + out_stride + 3 * element_size, alpha);
+ },
+ in_y, in_u, in_v, out);
+}
+
+/** Convert YUYV to NV12.
+ *
+ * @param[in] input Input YUYV data buffer.
+ * @param[out] output Output NV12 buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+template <bool yuyv>
+void colorconvert_yuyv_to_nv12(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IImage *__restrict>(input);
+ const auto output_ptr = static_cast<IMultiImage *__restrict>(output);
+
+ constexpr auto shift = yuyv ? 0 : 1;
+
+ // NV12's UV's width and height are subsampled
+ Window win_uv(win);
+ win_uv.set(Window::DimX, Window::Dimension(win_uv.x().start() / 2, win_uv.x().end() / 2, win_uv.x().step() / 2));
+ win_uv.set(Window::DimY, Window::Dimension(win_uv.y().start() / 2, win_uv.y().end() / 2, 1));
+ win_uv.validate();
+
+ Iterator in(input_ptr, win);
+ Iterator out_y(output_ptr->plane(0), win);
+ Iterator out_uv(output_ptr->plane(1), win_uv);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta_top = vld4q_u8(in.ptr());
+ const auto ta_bottom = vld4q_u8(in.ptr() + input_ptr->info()->strides_in_bytes().y());
+ //ta.val[0] = Y0 Y2 Y4 Y6 ...
+ //ta.val[1] = U0 U2 U4 U6 ...
+ //ta.val[2] = Y1 Y3 Y5 Y7 ...
+ //ta.val[3] = V0 V2 V4 V7 ...
+
+ uint8x16x2_t yvec;
+ yvec.val[0] = ta_top.val[0 + shift];
+ yvec.val[1] = ta_top.val[2 + shift];
+ vst2q_u8(out_y.ptr(), yvec);
+
+ uint8x16x2_t yyvec;
+ yyvec.val[0] = ta_bottom.val[0 + shift];
+ yyvec.val[1] = ta_bottom.val[2 + shift];
+ vst2q_u8(out_y.ptr() + output_ptr->plane(0)->info()->strides_in_bytes().y(), yyvec);
+
+ uint8x16x2_t uvvec;
+ uvvec.val[0] = vhaddq_u8(ta_top.val[1 - shift], ta_bottom.val[1 - shift]);
+ uvvec.val[1] = vhaddq_u8(ta_top.val[3 - shift], ta_bottom.val[3 - shift]);
+ vst2q_u8(out_uv.ptr(), uvvec);
+ },
+ in, out_y, out_uv);
+}
+
+/** Convert IYUV to NV12.
+ *
+ * @param[in] input Input IYUV data buffer.
+ * @param[out] output Output NV12 buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+void colorconvert_iyuv_to_nv12(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IMultiImage *__restrict>(input);
+ const auto output_ptr = static_cast<IMultiImage *__restrict>(output);
+
+ // UV's width and height are subsampled
+ Window win_uv(win);
+ win_uv.set(Window::DimX, Window::Dimension(win_uv.x().start() / 2, win_uv.x().end() / 2, win_uv.x().step() / 2));
+ win_uv.set(Window::DimY, Window::Dimension(win_uv.y().start() / 2, win_uv.y().end() / 2, 1));
+ win_uv.validate();
+
+ Iterator in_y(input_ptr->plane(0), win);
+ Iterator in_u(input_ptr->plane(1), win_uv);
+ Iterator in_v(input_ptr->plane(2), win_uv);
+ Iterator out_y(output_ptr->plane(0), win);
+ Iterator out_uv(output_ptr->plane(1), win_uv);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta_y_top = vld2q_u8(in_y.ptr());
+ const auto ta_y_bottom = vld2q_u8(in_y.ptr() + input_ptr->plane(0)->info()->strides_in_bytes().y());
+ uint8x16x2_t ta_uv;
+ ta_uv.val[0] = vld1q_u8(in_u.ptr());
+ ta_uv.val[1] = vld1q_u8(in_v.ptr());
+ //ta_y.val[0] = Y0 Y2 Y4 Y6 ...
+ //ta_y.val[1] = Y1 Y3 Y5 Y7 ...
+ //ta_uv.val[0] = U0 U2 U4 U6 ...
+ //ta_uv.val[1] = V0 V2 V4 V6 ...
+
+ vst2q_u8(out_y.ptr(), ta_y_top);
+ vst2q_u8(out_y.ptr() + output_ptr->plane(0)->info()->strides_in_bytes().y(), ta_y_bottom);
+ vst2q_u8(out_uv.ptr(), ta_uv);
+ },
+ in_y, in_u, in_v, out_y, out_uv);
+}
+
+/** Convert NV12 to IYUV.
+ *
+ * @param[in] input Input NV12 data buffer.
+ * @param[out] output Output IYUV buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+template <bool uv>
+void colorconvert_nv12_to_iyuv(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IMultiImage *__restrict>(input);
+ const auto output_ptr = static_cast<IMultiImage *__restrict>(output);
+
+ constexpr auto shift = uv ? 0 : 1;
+
+ // UV's width and height are subsampled
+ Window win_uv(win);
+ win_uv.set(Window::DimX, Window::Dimension(win_uv.x().start() / 2, win_uv.x().end() / 2, win_uv.x().step() / 2));
+ win_uv.set(Window::DimY, Window::Dimension(win_uv.y().start() / 2, win_uv.y().end() / 2, 1));
+ win_uv.validate();
+
+ Iterator in_y(input_ptr->plane(0), win);
+ Iterator in_uv(input_ptr->plane(1), win_uv);
+ Iterator out_y(output_ptr->plane(0), win);
+ Iterator out_u(output_ptr->plane(1), win_uv);
+ Iterator out_v(output_ptr->plane(2), win_uv);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta_y_top = vld2q_u8(in_y.ptr());
+ const auto ta_y_bottom = vld2q_u8(in_y.ptr() + input_ptr->plane(0)->info()->strides_in_bytes().y());
+ const auto ta_uv = vld2q_u8(in_uv.ptr());
+ //ta_y.val[0] = Y0 Y2 Y4 Y6 ...
+ //ta_y.val[1] = Y1 Y3 Y5 Y7 ...
+ //ta_uv.val[0] = U0 U2 U4 U6 ...
+ //ta_uv.val[1] = V0 V2 V4 V6 ...
+
+ vst2q_u8(out_y.ptr(), ta_y_top);
+ vst2q_u8(out_y.ptr() + output_ptr->plane(0)->info()->strides_in_bytes().y(), ta_y_bottom);
+ vst1q_u8(out_u.ptr(), ta_uv.val[0 + shift]);
+ vst1q_u8(out_v.ptr(), ta_uv.val[1 - shift]);
+ },
+ in_y, in_uv, out_y, out_u, out_v);
+}
+
+/** Convert YUYV to IYUV.
+ *
+ * @param[in] input Input YUYV data buffer.
+ * @param[out] output Output IYUV buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+template <bool yuyv>
+void colorconvert_yuyv_to_iyuv(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IImage *__restrict>(input);
+ const auto output_ptr = static_cast<IMultiImage *__restrict>(output);
+
+ constexpr auto shift = yuyv ? 0 : 1;
+
+ // Destination's UV's width and height are subsampled
+ Window win_uv(win);
+ win_uv.set(Window::DimX, Window::Dimension(win_uv.x().start() / 2, win_uv.x().end() / 2, win_uv.x().step() / 2));
+ win_uv.set(Window::DimY, Window::Dimension(win_uv.y().start() / 2, win_uv.y().end() / 2, 1));
+ win_uv.validate();
+
+ Iterator in(input_ptr, win);
+ Iterator out_y(output_ptr->plane(0), win);
+ Iterator out_u(output_ptr->plane(1), win_uv);
+ Iterator out_v(output_ptr->plane(2), win_uv);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta_top = vld4q_u8(in.ptr());
+ const auto ta_bottom = vld4q_u8(in.ptr() + input_ptr->info()->strides_in_bytes().y());
+ //ta.val[0] = Y0 Y2 Y4 Y6 ...
+ //ta.val[1] = U0 U2 U4 U6 ...
+ //ta.val[2] = Y1 Y3 Y5 Y7 ...
+ //ta.val[3] = V0 V2 V4 V7 ...
+
+ uint8x16x2_t yvec;
+ yvec.val[0] = ta_top.val[0 + shift];
+ yvec.val[1] = ta_top.val[2 + shift];
+ vst2q_u8(out_y.ptr(), yvec);
+
+ uint8x16x2_t yyvec;
+ yyvec.val[0] = ta_bottom.val[0 + shift];
+ yyvec.val[1] = ta_bottom.val[2 + shift];
+ vst2q_u8(out_y.ptr() + output_ptr->plane(0)->info()->strides_in_bytes().y(), yyvec);
+
+ uint8x16_t uvec;
+ uvec = vhaddq_u8(ta_top.val[1 - shift], ta_bottom.val[1 - shift]);
+ vst1q_u8(out_u.ptr(), uvec);
+
+ uint8x16_t vvec;
+ vvec = vhaddq_u8(ta_top.val[3 - shift], ta_bottom.val[3 - shift]);
+ vst1q_u8(out_v.ptr(), vvec);
+ },
+ in, out_y, out_u, out_v);
+}
+
+/** Convert NV12 to YUV4.
+ *
+ * @param[in] input Input NV12 data buffer.
+ * @param[out] output Output YUV4 buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+template <bool uv>
+void colorconvert_nv12_to_yuv4(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IMultiImage *__restrict>(input);
+ const auto output_ptr = static_cast<IMultiImage *__restrict>(output);
+
+ constexpr auto shift = uv ? 0 : 1;
+
+ // UV's width and height are subsampled
+ Window win_uv(win);
+ win_uv.set(Window::DimX, Window::Dimension(win_uv.x().start() / 2, win_uv.x().end() / 2, win_uv.x().step() / 2));
+ win_uv.set(Window::DimY, Window::Dimension(win_uv.y().start() / 2, win_uv.y().end() / 2, 1));
+ win_uv.validate();
+
+ Iterator in_y(input_ptr->plane(0), win);
+ Iterator in_uv(input_ptr->plane(1), win_uv);
+ Iterator out_y(output_ptr->plane(0), win);
+ Iterator out_u(output_ptr->plane(1), win);
+ Iterator out_v(output_ptr->plane(2), win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta_y_top = vld2q_u8(in_y.ptr());
+ const auto ta_y_bottom = vld2q_u8(in_y.ptr() + input_ptr->plane(0)->info()->strides_in_bytes().y());
+ const auto ta_uv = vld2q_u8(in_uv.ptr());
+ //ta_y.val[0] = Y0 Y2 Y4 Y6 ...
+ //ta_y.val[1] = Y1 Y3 Y5 Y7 ...
+ //ta_uv.val[0] = U0 U2 U4 U6 ...
+ //ta_uv.val[1] = V0 V2 V4 V6 ...
+
+ vst2q_u8(out_y.ptr(), ta_y_top);
+ vst2q_u8(out_y.ptr() + output_ptr->plane(0)->info()->strides_in_bytes().y(), ta_y_bottom);
+
+ uint8x16x2_t uvec;
+ uvec.val[0] = ta_uv.val[0 + shift];
+ uvec.val[1] = ta_uv.val[0 + shift];
+ vst2q_u8(out_u.ptr(), uvec);
+ vst2q_u8(out_u.ptr() + output_ptr->plane(1)->info()->strides_in_bytes().y(), uvec);
+
+ uint8x16x2_t vvec;
+ vvec.val[0] = ta_uv.val[1 - shift];
+ vvec.val[1] = ta_uv.val[1 - shift];
+ vst2q_u8(out_v.ptr(), vvec);
+ vst2q_u8(out_v.ptr() + output_ptr->plane(2)->info()->strides_in_bytes().y(), vvec);
+ },
+ in_y, in_uv, out_y, out_u, out_v);
+}
+
+/** Convert IYUV to YUV4.
+ *
+ * @param[in] input Input IYUV data buffer.
+ * @param[out] output Output YUV4 buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+void colorconvert_iyuv_to_yuv4(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IMultiImage *__restrict>(input);
+ const auto output_ptr = static_cast<IMultiImage *__restrict>(output);
+
+ // UV's width and height are subsampled
+ Window win_uv(win);
+ win_uv.set(Window::DimX, Window::Dimension(win_uv.x().start() / 2, win_uv.x().end() / 2, win_uv.x().step() / 2));
+ win_uv.set(Window::DimY, Window::Dimension(win_uv.y().start() / 2, win_uv.y().end() / 2, 1));
+ win_uv.validate();
+
+ Iterator in_y(input_ptr->plane(0), win);
+ Iterator in_u(input_ptr->plane(1), win_uv);
+ Iterator in_v(input_ptr->plane(2), win_uv);
+ Iterator out_y(output_ptr->plane(0), win);
+ Iterator out_u(output_ptr->plane(1), win);
+ Iterator out_v(output_ptr->plane(2), win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta_y_top = vld2q_u8(in_y.ptr());
+ const auto ta_y_bottom = vld2q_u8(in_y.ptr() + input_ptr->plane(0)->info()->strides_in_bytes().y());
+ const auto ta_u = vld1q_u8(in_u.ptr());
+ const auto ta_v = vld1q_u8(in_v.ptr());
+ //ta_y.val[0] = Y0 Y2 Y4 Y6 ...
+ //ta_y.val[1] = Y1 Y3 Y5 Y7 ...
+ //ta_u = U0 U2 U4 U6 ...
+ //ta_v = V0 V2 V4 V6 ...
+
+ vst2q_u8(out_y.ptr(), ta_y_top);
+ vst2q_u8(out_y.ptr() + output_ptr->plane(0)->info()->strides_in_bytes().y(), ta_y_bottom);
+
+ uint8x16x2_t uvec;
+ uvec.val[0] = ta_u;
+ uvec.val[1] = ta_u;
+ vst2q_u8(out_u.ptr(), uvec);
+ vst2q_u8(out_u.ptr() + output_ptr->plane(1)->info()->strides_in_bytes().y(), uvec);
+
+ uint8x16x2_t vvec;
+ vvec.val[0] = ta_v;
+ vvec.val[1] = ta_v;
+ vst2q_u8(out_v.ptr(), vvec);
+ vst2q_u8(out_v.ptr() + output_ptr->plane(2)->info()->strides_in_bytes().y(), vvec);
+ },
+ in_y, in_u, in_v, out_y, out_u, out_v);
+}
+
+/** Convert RGB to NV12.
+ *
+ * @param[in] input Input RGB data buffer.
+ * @param[out] output Output NV12 buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+template <bool alpha>
+void colorconvert_rgb_to_nv12(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IImage *__restrict>(input);
+ const auto output_ptr = static_cast<IMultiImage *__restrict>(output);
+
+ // UV's width and height are subsampled
+ Window win_uv(win);
+ win_uv.set(Window::DimX, Window::Dimension(win_uv.x().start() / 2, win_uv.x().end() / 2, win_uv.x().step() / 2));
+ win_uv.set(Window::DimY, Window::Dimension(win_uv.y().start() / 2, win_uv.y().end() / 2, 1));
+ win_uv.validate();
+
+ Iterator in(input_ptr, win);
+ Iterator out_y(output_ptr->plane(0), win);
+ Iterator out_uv(output_ptr->plane(1), win_uv);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta_rgb_top = load_rgb(in.ptr(), alpha);
+ const auto ta_rgb_bottom = load_rgb(in.ptr() + input_ptr->info()->strides_in_bytes().y(), alpha);
+ //ta_rgb.val[0] = R0 R1 R2 R3 ...
+ //ta_rgb.val[1] = G0 G1 G2 G3 ...
+ //ta_rgb.val[2] = B0 B1 B2 B3 ...
+
+ store_rgb_to_nv12(ta_rgb_top.val[0], ta_rgb_top.val[1], ta_rgb_top.val[2],
+ ta_rgb_bottom.val[0], ta_rgb_bottom.val[1], ta_rgb_bottom.val[2],
+ out_y.ptr(), out_y.ptr() + output_ptr->plane(0)->info()->strides_in_bytes().y(),
+ out_uv.ptr());
+ },
+ in, out_y, out_uv);
+}
+
+/** Convert RGB to IYUV.
+ *
+ * @param[in] input Input RGB data buffer.
+ * @param[out] output Output IYUV buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+template <bool alpha>
+void colorconvert_rgb_to_iyuv(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IImage *__restrict>(input);
+ const auto output_ptr = static_cast<IMultiImage *__restrict>(output);
+
+ // UV's width and height are subsampled
+ Window win_uv(win);
+ win_uv.set(Window::DimX, Window::Dimension(win_uv.x().start() / 2, win_uv.x().end() / 2, win_uv.x().step() / 2));
+ win_uv.set(Window::DimY, Window::Dimension(win_uv.y().start() / 2, win_uv.y().end() / 2, 1));
+ win_uv.validate();
+
+ Iterator in(input_ptr, win);
+ Iterator out_y(output_ptr->plane(0), win);
+ Iterator out_u(output_ptr->plane(1), win_uv);
+ Iterator out_v(output_ptr->plane(2), win_uv);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta_rgb_top = load_rgb(in.ptr(), alpha);
+ const auto ta_rgb_bottom = load_rgb(in.ptr() + input_ptr->info()->strides_in_bytes().y(), alpha);
+ //ta_rgb.val[0] = R0 R1 R2 R3 ...
+ //ta_rgb.val[1] = G0 G1 G2 G3 ...
+ //ta_rgb.val[2] = B0 B1 B2 B3 ...
+
+ store_rgb_to_iyuv(ta_rgb_top.val[0], ta_rgb_top.val[1], ta_rgb_top.val[2],
+ ta_rgb_bottom.val[0], ta_rgb_bottom.val[1], ta_rgb_bottom.val[2],
+ out_y.ptr(), out_y.ptr() + output_ptr->plane(0)->info()->strides_in_bytes().y(),
+ out_u.ptr(), out_v.ptr());
+ },
+ in, out_y, out_u, out_v);
+}
+
+/** Convert RGB to YUV4.
+ *
+ * @param[in] input Input RGB data buffer.
+ * @param[out] output Output YUV4 buffer.
+ * @param[in] win Window for iterating the buffers.
+ *
+ */
+template <bool alpha>
+void colorconvert_rgb_to_yuv4(const void *__restrict input, void *__restrict output, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(nullptr == input);
+ ARM_COMPUTE_ERROR_ON(nullptr == output);
+ win.validate();
+
+ const auto input_ptr = static_cast<const IImage *__restrict>(input);
+ const auto output_ptr = static_cast<IMultiImage *__restrict>(output);
+
+ Iterator in(input_ptr, win);
+ Iterator out_y(output_ptr->plane(0), win);
+ Iterator out_u(output_ptr->plane(1), win);
+ Iterator out_v(output_ptr->plane(2), win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto ta_rgb = load_rgb(in.ptr(), alpha);
+ //ta_rgb.val[0] = R0 R1 R2 R3 ...
+ //ta_rgb.val[1] = G0 G1 G2 G3 ...
+ //ta_rgb.val[2] = B0 B1 B2 B3 ...
+
+ store_rgb_to_yuv4(ta_rgb.val[0], ta_rgb.val[1], ta_rgb.val[2],
+ out_y.ptr(), out_u.ptr(), out_v.ptr());
+ },
+ in, out_y, out_u, out_v);
+}
+} // namespace arm_compute
diff --git a/src/core/NEON/kernels/detail/NEDirectConvolution3x3.h b/src/core/NEON/kernels/detail/NEDirectConvolution3x3.h
new file mode 100644
index 0000000000..96defbc9c9
--- /dev/null
+++ b/src/core/NEON/kernels/detail/NEDirectConvolution3x3.h
@@ -0,0 +1,170 @@
+/*
+ * Copyright (c) 2017-2020 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H
+#define ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H
+
+#include <arm_neon.h>
+
+namespace arm_compute
+{
+namespace detail
+{
+inline float32x4x3_t load_matrix_row(const float *ptr)
+{
+ const float32x4x3_t r =
+ {
+ {
+ vld1q_dup_f32(ptr),
+ vld1q_dup_f32(1 + ptr),
+ vld1q_dup_f32(2 + ptr)
+ }
+ };
+ return r;
+}
+
+template <unsigned int stridex>
+float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2);
+
+template <>
+inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2)
+{
+ const float32x4x3_t vtop =
+ {
+ {
+ vld1q_f32(in_top),
+ vld1q_f32(in_top + 4),
+ vld1q_f32(in_top + 8)
+ }
+ };
+ const float32x4x3_t vmid =
+ {
+ {
+ vld1q_f32(in_mid),
+ vld1q_f32(in_mid + 4),
+ vld1q_f32(in_mid + 8)
+ }
+ };
+ const float32x4x3_t vlow =
+ {
+ {
+ vld1q_f32(in_low),
+ vld1q_f32(in_low + 4),
+ vld1q_f32(in_low + 8)
+ }
+ };
+ float32x4x2_t out =
+ {
+ {
+ vmulq_f32(vtop.val[0], m0.val[0]),
+ vmulq_f32(vtop.val[1], m0.val[0])
+ }
+ };
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
+ return out;
+}
+
+template <>
+inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2)
+{
+ float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
+ return out;
+}
+
+template <>
+inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2)
+{
+ float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
+ return out;
+}
+
+template <unsigned int stridex>
+void store_results(float *buffer, const float32x4x2_t &values);
+
+template <>
+void store_results<1>(float *buffer, const float32x4x2_t &values)
+{
+ vst1q_f32(buffer, values.val[0]);
+ vst1q_f32(buffer + 4, values.val[1]);
+}
+
+template <>
+void store_results<2>(float *buffer, const float32x4x2_t &values)
+{
+ vst1q_f32(buffer, values.val[0]);
+}
+
+template <>
+void store_results<3>(float *buffer, const float32x4x2_t &values)
+{
+ vst1_f32(buffer, vget_low_f32(values.val[0]));
+}
+
+template <unsigned int stridex>
+int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);
+
+template <>
+int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
+{
+ return num_elems_written_per_iteration;
+}
+
+template <>
+int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
+{
+ return num_elems_written_per_iteration << 1;
+}
+
+template <>
+int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
+{
+ return num_elems_written_per_iteration * 3;
+}
+}
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H */ \ No newline at end of file
diff --git a/src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h b/src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h
new file mode 100644
index 0000000000..d7ee70a1cd
--- /dev/null
+++ b/src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h
@@ -0,0 +1,965 @@
+/*
+ * Copyright (c) 2017-2020 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
+#define ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
+
+#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/utils/misc/Requires.h"
+#include "src/core/NEON/NEFixedPoint.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+
+#include <arm_neon.h>
+
+namespace arm_compute
+{
+namespace detail
+{
+/** Loads a 3x3 matrix as a row (float).
+ *
+ * @param[in] ptr Pointer to a float 3x3 matrix.
+ * @param[in] weights_offset (Optional) Weights quantization offset.
+ *
+ * @return The loaded matrix.
+ */
+inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0)
+{
+ ARM_COMPUTE_UNUSED(weights_offset);
+ const float32x4x3_t r =
+ {
+ {
+ vld1q_dup_f32(ptr),
+ vld1q_dup_f32(1 + ptr),
+ vld1q_dup_f32(2 + ptr)
+ }
+ };
+ return r;
+}
+
+/** Loads a 3x3 matrix as a row (uint8_t/int8_t).
+ *
+ * @param[in] ptr Pointer to a uint8_t/int8_t 3x3 matrix.
+ * @param[in] weights_offset (Optional) Weights quantization offset.
+ *
+ * @return The loaded matrix.
+ */
+template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
+inline int32x4x3_t load_matrix_row(const T *ptr, int weights_offset = 0)
+{
+ const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset);
+
+ /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
+ r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
+ int32x4x3_t r =
+ {
+ {
+ vaddq_s32(v_weights_offset, vdupq_n_s32(*ptr)),
+ vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 1))),
+ vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 2)))
+ }
+ };
+ return r;
+}
+
+/** Stores a float32x4x2_t array into a memory location.
+ *
+ * @param[in] buffer Pointer to the memory location where the values will be stored.
+ * @param[in] values Values that will be stored.
+ *
+ */
+template <unsigned int stridex>
+void store_results(float *buffer, const float32x4x2_t &values);
+
+template <>
+inline void store_results<1>(float *buffer, const float32x4x2_t &values)
+{
+ vst1q_f32(buffer, values.val[0]);
+ vst1q_f32(buffer + 4, values.val[1]);
+}
+
+template <>
+inline void store_results<2>(float *buffer, const float32x4x2_t &values)
+{
+ vst1q_f32(buffer, values.val[0]);
+}
+
+template <>
+inline void store_results<3>(float *buffer, const float32x4x2_t &values)
+{
+ vst1_f32(buffer, vget_low_f32(values.val[0]));
+}
+
+/** Stores a uint32_t array into a memory location.
+ *
+ * @param[in] buffer Pointer to the memory location where the values will be stored.
+ * @param[in] values Values that will be stored.
+ *
+ */
+template <unsigned int stridex>
+void store_results(int32_t *buffer, const int32x4x2_t &values);
+
+template <>
+inline void store_results<1>(int32_t *buffer, const int32x4x2_t &values)
+{
+ vst1q_s32(buffer, values.val[0]);
+ vst1q_s32(buffer + 4, values.val[1]);
+}
+
+template <>
+inline void store_results<2>(int32_t *buffer, const int32x4x2_t &values)
+{
+ vst1q_s32(buffer, values.val[0]);
+}
+
+template <>
+inline void store_results<3>(int32_t *buffer, const int32x4x2_t &values)
+{
+ vst1_s32(buffer, vget_low_s32(values.val[0]));
+}
+
+template <unsigned int stridex>
+inline void accumulate_results(float *buffer, const float32x4x2_t &values);
+
+template <>
+inline void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
+{
+ vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
+ vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
+}
+
+template <>
+inline void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
+{
+ vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
+}
+
+template <>
+inline void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
+{
+ vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
+}
+
+template <unsigned int stridex>
+void accumulate_results(int32_t *buffer, const int32x4x2_t &values);
+
+template <>
+inline void accumulate_results<1>(int32_t *buffer, const int32x4x2_t &values)
+{
+ vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
+ vst1q_s32(buffer + 4, vaddq_s32(vld1q_s32(buffer + 4), values.val[1]));
+}
+
+template <>
+inline void accumulate_results<2>(int32_t *buffer, const int32x4x2_t &values)
+{
+ vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
+}
+
+template <>
+inline void accumulate_results<3>(int32_t *buffer, const int32x4x2_t &values)
+{
+ vst1_s32(buffer, vadd_s32(vld1_s32(buffer), vget_low_s32(values.val[0])));
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+/** Stores a float16x8x2_t array into a memory location.
+ *
+ * @param[in] buffer Pointer to the memory location where the values will be stored.
+ * @param[in] values Values that will be stored.
+ *
+ */
+template <unsigned int stridex>
+void store_results(float16_t *buffer, const float16x8x2_t &values);
+
+template <>
+inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
+{
+ vst1q_f16(buffer, values.val[0]);
+ vst1q_f16(buffer + 8, values.val[1]);
+}
+
+template <>
+inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
+{
+ vst1q_f16(buffer, values.val[0]);
+}
+
+template <>
+inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
+{
+ vst1_f16(buffer, vget_low_f16(values.val[0]));
+}
+
+template <unsigned int stridex>
+inline void accumulate_results(float16_t *buffer, const float16x8x2_t &values);
+
+template <>
+inline void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
+{
+ vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
+ vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
+}
+
+template <>
+inline void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
+{
+ vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
+}
+
+template <>
+inline void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
+{
+ vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
+}
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+
+/** Perform a 3x3 convolution for 4 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
+ *
+ * @param[in] in_top Pointer to the first row of the input.
+ * @param[in] in_mid Pointer to the second row of the input.
+ * @param[in] in_low Pointer to the third row of the input.
+ * @param[in] m0 First row of the filter.
+ * @param[in] m1 Second row of the filter.
+ * @param[in] m2 Third row of the filter.
+ * @param[in] dilation_x Dilation, in elements across x.
+ * @param[in] input_offset (Optional) Input quantization offset.
+ *
+ */
+inline float32x4_t single_convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
+ const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
+ const size_t dilation_x, int input_offset)
+{
+ ARM_COMPUTE_UNUSED(input_offset);
+
+ const float32x4x3_t vtop =
+ {
+ {
+ vld1q_f32(in_top),
+ vld1q_f32(in_top + dilation_x),
+ vld1q_f32(in_top + 2 * dilation_x)
+ }
+ };
+ const float32x4x3_t vmid =
+ {
+ {
+ vld1q_f32(in_mid),
+ vld1q_f32(in_mid + dilation_x),
+ vld1q_f32(in_mid + 2 * dilation_x)
+ }
+ };
+ const float32x4x3_t vlow =
+ {
+ {
+ vld1q_f32(in_low),
+ vld1q_f32(in_low + dilation_x),
+ vld1q_f32(in_low + 2 * dilation_x)
+ }
+ };
+ float32x4_t out = vmulq_f32(vtop.val[0], m0.val[0]);
+ out = vmlaq_f32(out, vtop.val[1], m0.val[1]);
+ out = vmlaq_f32(out, vtop.val[2], m0.val[2]);
+
+ out = vmlaq_f32(out, vmid.val[0], m1.val[0]);
+ out = vmlaq_f32(out, vmid.val[1], m1.val[1]);
+ out = vmlaq_f32(out, vmid.val[2], m1.val[2]);
+
+ out = vmlaq_f32(out, vlow.val[0], m2.val[0]);
+ out = vmlaq_f32(out, vlow.val[1], m2.val[1]);
+ out = vmlaq_f32(out, vlow.val[2], m2.val[2]);
+
+ return out;
+}
+
+/** Perform a 3x3 convolution for 8 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
+ *
+ * @param[in] in_top Pointer to the first row of the input.
+ * @param[in] in_mid Pointer to the second row of the input.
+ * @param[in] in_low Pointer to the third row of the input.
+ * @param[in] m0 First row of the filter.
+ * @param[in] m1 Second row of the filter.
+ * @param[in] m2 Third row of the filter.
+ * @param[in] dilation_x Dilation, in elements across x.
+ * @param[in] stridex Stride value in elements across x.
+ * @param[in] input_offset (Optional) Input quantization offset.
+ *
+ */
+inline float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
+ const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
+ const size_t dilation_x, unsigned int stridex, int input_offset = 0)
+{
+ ARM_COMPUTE_ERROR_ON(stridex > 3);
+ float32x4x2_t out =
+ {
+ {
+ single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
+ single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
+ }
+ };
+
+ if(stridex == 2)
+ {
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
+ }
+ else if(stridex == 3)
+ {
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
+ }
+
+ return out;
+}
+
+/** Perform a convolve3x3 on float32.
+ *
+ * @param[in] in_top Pointer to the first row of the input.
+ * @param[in] in_mid Pointer to the second row of the input.
+ * @param[in] in_low Pointer to the third row of the input.
+ * @param[out] out_ptr Pointer to the output.
+ * @param[in] m0 First row of the filter.
+ * @param[in] m1 Second row of the filter.
+ * @param[in] m2 Third row of the filter.
+ * @param[in] stridex Stride value in elements across x.
+ * @param[in] input_offset (Optional) Input quantization offset.
+ *
+ */
+template <bool accumulate>
+void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
+ const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
+ unsigned int stridex, int input_offset = 0);
+
+template <bool accumulate>
+inline void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
+ const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
+ unsigned int stridex, int input_offset)
+{
+ ARM_COMPUTE_UNUSED(input_offset);
+ ARM_COMPUTE_ERROR_ON(stridex > 3);
+
+ float32x4x2_t out =
+ {
+ {
+ vdupq_n_f32(0.f),
+ vdupq_n_f32(0.f)
+ }
+ };
+ if(stridex == 2)
+ {
+ const float32x4x2_t vtop = vld2q_f32(in_top);
+ const float32x4x2_t vmid = vld2q_f32(in_mid);
+ const float32x4x2_t vlow = vld2q_f32(in_low);
+ const float32x4_t vtop_end = vld1q_f32(in_top + 8);
+ const float32x4_t vmid_end = vld1q_f32(in_mid + 8);
+ const float32x4_t vlow_end = vld1q_f32(in_low + 8);
+
+ out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]);
+
+ accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
+ }
+ else
+ {
+ const float32x4x3_t vtop =
+ {
+ {
+ vld1q_f32(in_top),
+ vld1q_f32(in_top + 4),
+ vld1q_f32(in_top + 8)
+ }
+ };
+ const float32x4x3_t vmid =
+ {
+ {
+ vld1q_f32(in_mid),
+ vld1q_f32(in_mid + 4),
+ vld1q_f32(in_mid + 8)
+ }
+ };
+ const float32x4x3_t vlow =
+ {
+ {
+ vld1q_f32(in_low),
+ vld1q_f32(in_low + 4),
+ vld1q_f32(in_low + 8)
+ }
+ };
+ out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);
+ out.val[1] = vmulq_f32(vtop.val[1], m0.val[0]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);
+
+ out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
+ out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
+
+ if(stridex == 3)
+ {
+ out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
+ accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
+ }
+ else
+ {
+ accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
+ }
+ }
+}
+
+/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
+ *
+ * @param[in] in_top Pointer to the first row of the input.
+ * @param[in] in_mid Pointer to the second row of the input.
+ * @param[in] in_low Pointer to the third row of the input.
+ * @param[in] m0 First row of the filter.
+ * @param[in] m1 Second row of the filter.
+ * @param[in] m2 Third row of the filter.
+ * @param[in] dilation_x Dilation, in elements across x.
+ * @param[in] input_offset Input quantization offset.
+ *
+ */
+template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
+inline int32x4_t single_convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low,
+ const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
+ size_t dilation_x, int32_t input_offset)
+{
+ using VectorType = typename std::conditional<std::is_same<T, uint8_t>::value, uint8x8x3_t, int8x8x3_t>::type;
+ using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;
+
+ const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});
+
+ const VectorType vtop =
+ {
+ {
+ wrapper::vload(in_top),
+ wrapper::vload(in_top + dilation_x),
+ wrapper::vload(in_top + 2 * dilation_x)
+ }
+ };
+ const VectorType vmid =
+ {
+ {
+ wrapper::vload(in_mid),
+ wrapper::vload(in_mid + dilation_x),
+ wrapper::vload(in_mid + 2 * dilation_x)
+ }
+ };
+ const VectorType vlow =
+ {
+ {
+ wrapper::vload(in_low),
+ wrapper::vload(in_low + dilation_x),
+ wrapper::vload(in_low + 2 * dilation_x)
+ }
+ };
+
+ const int32x4x3_t vtop_s32 =
+ {
+ {
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[2])))),
+ }
+ };
+ const int32x4x3_t vmid_s32 =
+ {
+ {
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[2])))),
+ }
+ };
+ const int32x4x3_t vlow_s32 =
+ {
+ {
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[2])))),
+ }
+ };
+
+ int32x4_t out = wrapper::vmul(vtop_s32.val[0], m0.val[0]);
+ out = wrapper::vmla(out, vtop_s32.val[1], m0.val[1]);
+ out = wrapper::vmla(out, vtop_s32.val[2], m0.val[2]);
+
+ out = wrapper::vmla(out, vmid_s32.val[0], m1.val[0]);
+ out = wrapper::vmla(out, vmid_s32.val[1], m1.val[1]);
+ out = wrapper::vmla(out, vmid_s32.val[2], m1.val[2]);
+
+ out = wrapper::vmla(out, vlow_s32.val[0], m2.val[0]);
+ out = wrapper::vmla(out, vlow_s32.val[1], m2.val[1]);
+ out = wrapper::vmla(out, vlow_s32.val[2], m2.val[2]);
+
+ return out;
+}
+
+/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
+ *
+ * @param[in] in_top Pointer to the first row of the input.
+ * @param[in] in_mid Pointer to the second row of the input.
+ * @param[in] in_low Pointer to the third row of the input.
+ * @param[in] m0 First row of the filter.
+ * @param[in] m1 Second row of the filter.
+ * @param[in] m2 Third row of the filter.
+ * @param[in] dilation_x Dilation, in elements across x.
+ * @param[in] stridex Stride value in elements across x.
+ * @param[in] input_offset Input quantization offset.
+ *
+ */
+template < typename T, REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
+inline int32x4x2_t convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
+ const size_t dilation_x, unsigned int stridex, int input_offset)
+{
+ ARM_COMPUTE_ERROR_ON(stridex > 3);
+ int32x4x2_t out =
+ {
+ {
+ single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
+ single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
+ }
+ };
+
+ if(stridex == 2)
+ {
+ out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
+ out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
+ out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);
+ }
+ else if(stridex == 3)
+ {
+ out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
+ }
+ return out;
+}
+
+/** Perform a convolve3x3 on 8-bit elements
+ *
+ * @param[in] in_top Pointer to the first row of the input.
+ * @param[in] in_mid Pointer to the second row of the input.
+ * @param[in] in_low Pointer to the third row of the input.
+ * @param[out] out_ptr Pointer to the output.
+ * @param[in] m0 First row of the filter.
+ * @param[in] m1 Second row of the filter.
+ * @param[in] m2 Third row of the filter.
+ * @param[in] stridex Stride value in elements across x.
+ * @param[in] input_offset Input quantization offset.
+ *
+ */
+template < bool accumulate, typename T1, typename T2, REQUIRES_TA(std::is_same<T1, uint8_t>::value || std::is_same<T1, int8_t>::value) >
+void convolve_3x3(const T1 *in_top, const T1 *in_mid, const T1 *in_low, T2 *out_ptr,
+ const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
+ unsigned int stridex, int32_t input_offset)
+{
+ ARM_COMPUTE_ERROR_ON(stridex > 3);
+ using VectorType = typename std::conditional<std::is_same<T1, uint8_t>::value, uint8x8x2_t, int8x8x2_t>::type;
+ using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;
+
+ const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});
+
+ const VectorType vtop =
+ {
+ {
+ wrapper::vload(in_top),
+ wrapper::vload(in_top + 8)
+ }
+ };
+ const VectorType vmid =
+ {
+ {
+ wrapper::vload(in_mid),
+ wrapper::vload(in_mid + 8)
+ }
+ };
+ const VectorType vlow =
+ {
+ {
+ wrapper::vload(in_low),
+ wrapper::vload(in_low + 8)
+ }
+ };
+
+ const int32x4x3_t vtop_s32 =
+ {
+ {
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vtop.val[0])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
+ }
+ };
+ const int32x4x3_t vmid_s32 =
+ {
+ {
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vmid.val[0])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
+ }
+ };
+ const int32x4x3_t vlow_s32 =
+ {
+ {
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vlow.val[0])))),
+ wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
+ }
+ };
+
+ int32x4x2_t out
+ {
+ {
+ wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
+ wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
+ }
+ };
+
+ // 0
+ out.val[0] = wrapper::vmla(out.val[0], vtop_s32.val[0], m0.val[0]);
+ out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vtop_s32.val[0], vtop_s32.val[1]), m0.val[1]);
+ out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vtop_s32.val[0], vtop_s32.val[1]), m0.val[2]);
+
+ out.val[0] = wrapper::vmla(out.val[0], vmid_s32.val[0], m1.val[0]);
+ out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vmid_s32.val[0], vmid_s32.val[1]), m1.val[1]);
+ out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vmid_s32.val[0], vmid_s32.val[1]), m1.val[2]);
+
+ out.val[0] = wrapper::vmla(out.val[0], vlow_s32.val[0], m2.val[0]);
+ out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vlow_s32.val[0], vlow_s32.val[1]), m2.val[1]);
+ out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vlow_s32.val[0], vlow_s32.val[1]), m2.val[2]);
+
+ // 1
+ out.val[1] = wrapper::vmla(out.val[1], vtop_s32.val[1], m0.val[0]);
+ out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vtop_s32.val[1], vtop_s32.val[2]), m0.val[1]);
+ out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vtop_s32.val[1], vtop_s32.val[2]), m0.val[2]);
+
+ out.val[1] = wrapper::vmla(out.val[1], vmid_s32.val[1], m1.val[0]);
+ out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vmid_s32.val[1], vmid_s32.val[2]), m1.val[1]);
+ out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vmid_s32.val[1], vmid_s32.val[2]), m1.val[2]);
+
+ out.val[1] = wrapper::vmla(out.val[1], vlow_s32.val[1], m2.val[0]);
+ out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vlow_s32.val[1], vlow_s32.val[2]), m2.val[1]);
+ out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vlow_s32.val[1], vlow_s32.val[2]), m2.val[2]);
+
+ if(stridex == 1)
+ {
+ accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
+ }
+ else if(stridex == 2)
+ {
+ out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
+ out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
+ out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);
+
+ accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
+ }
+ else if(stridex == 3)
+ {
+ out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
+ accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
+ }
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+/** Loads a 3x3 matrix as a row (float16_t).
+ *
+ * @param[in] ptr Pointer to a float 3x3 matrix.
+ *
+ * @return The loaded matrix.
+ */
+inline float16x8x3_t load_matrix_row(const float16_t *ptr, int weights_offset = 0)
+{
+ ARM_COMPUTE_UNUSED(weights_offset);
+ /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
+ r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
+ const float16x8x3_t r =
+ {
+ {
+ vld1q_dup_f16(ptr),
+ vld1q_dup_f16(1 + ptr),
+ vld1q_dup_f16(2 + ptr)
+ }
+ };
+ return r;
+}
+
+/** Perform a 3x3 convolution for 8 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
+ *
+ * @param[in] in_top Pointer to the first row of the input.
+ * @param[in] in_mid Pointer to the second row of the input.
+ * @param[in] in_low Pointer to the third row of the input.
+ * @param[in] m0 First row of the filter.
+ * @param[in] m1 Second row of the filter.
+ * @param[in] m2 Third row of the filter.
+ * @param[in] dilation_x Dilation, in elements across x.
+ * @param[in] input_offset (Optional)Input quantization offset.
+ *
+ */
+inline float16x8_t single_convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
+ const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+ const size_t dilation_x, int input_offset = 0)
+{
+ ARM_COMPUTE_UNUSED(input_offset);
+ const float16x8x3_t vtop =
+ {
+ {
+ vld1q_f16(in_top),
+ vld1q_f16(in_top + dilation_x),
+ vld1q_f16(in_top + 2 * dilation_x)
+ }
+ };
+ const float16x8x3_t vmid =
+ {
+ {
+ vld1q_f16(in_mid),
+ vld1q_f16(in_mid + dilation_x),
+ vld1q_f16(in_mid + 2 * dilation_x)
+ }
+ };
+ const float16x8x3_t vlow =
+ {
+ {
+ vld1q_f16(in_low),
+ vld1q_f16(in_low + dilation_x),
+ vld1q_f16(in_low + 2 * dilation_x)
+ }
+ };
+ float16x8_t out = vmulq_f16(vtop.val[0], m0.val[0]);
+ out = vaddq_f16(out, vmulq_f16(vtop.val[1], m0.val[1]));
+ out = vaddq_f16(out, vmulq_f16(vtop.val[2], m0.val[2]));
+
+ out = vaddq_f16(out, vmulq_f16(vmid.val[0], m1.val[0]));
+ out = vaddq_f16(out, vmulq_f16(vmid.val[1], m1.val[1]));
+ out = vaddq_f16(out, vmulq_f16(vmid.val[2], m1.val[2]));
+
+ out = vaddq_f16(out, vmulq_f16(vlow.val[0], m2.val[0]));
+ out = vaddq_f16(out, vmulq_f16(vlow.val[1], m2.val[1]));
+ out = vaddq_f16(out, vmulq_f16(vlow.val[2], m2.val[2]));
+
+ return out;
+}
+
+/** Perform a 3x3 convolution for 16 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
+ *
+ * @param[in] in_top Pointer to the first row of the input.
+ * @param[in] in_mid Pointer to the second row of the input.
+ * @param[in] in_low Pointer to the third row of the input.
+ * @param[in] m0 First row of the filter.
+ * @param[in] m1 Second row of the filter.
+ * @param[in] m2 Third row of the filter.
+ * @param[in] dilation_x Dilation, in elements across x.
+ * @param[in] stridex Stride value in elements across x.
+ * @param[in] input_offset (Optional) Input quantization offset.
+ *
+ */
+inline float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
+ const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+ const size_t dilation_x, unsigned int stridex, int input_offset = 0)
+{
+ float16x8x2_t out =
+ {
+ {
+ single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
+ single_convolve_3x3_dilation(in_top + 8, in_mid + 8, in_low + 8, m0, m1, m2, dilation_x, input_offset)
+ }
+ };
+
+ if(stridex == 2)
+ {
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
+ }
+ else if(stridex == 3)
+ {
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
+ }
+
+ return out;
+}
+
+/** Perform a convolve3x3 on float16.
+ *
+ * @param[in] in_top Pointer to the first row of the input.
+ * @param[in] in_mid Pointer to the second row of the input.
+ * @param[in] in_low Pointer to the third row of the input.
+ * @param[out] out_ptr Pointer to the output.
+ * @param[in] m0 First row of the filter.
+ * @param[in] m1 Second row of the filter.
+ * @param[in] m2 Third row of the filter.
+ * @param[in] stridex Stride value in elements across x.
+ * @param[in] input_offset (Optional) Input quantization offset.
+ *
+ */
+template <bool accumulate>
+inline void convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, float16_t *out_ptr,
+ const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
+ unsigned int stridex, int input_offset = 0)
+{
+ ARM_COMPUTE_UNUSED(input_offset);
+
+ float16x8x2_t out =
+ {
+ {
+ vdupq_n_f16(0),
+ vdupq_n_f16(0)
+ }
+ };
+ if(stridex == 2)
+ {
+ const float16x8x2_t vtop = vld2q_f16(in_top);
+ const float16x8x2_t vmid = vld2q_f16(in_mid);
+ const float16x8x2_t vlow = vld2q_f16(in_low);
+ const float16x8_t vtop_end = vld1q_f16(in_top + 16);
+ const float16x8_t vmid_end = vld1q_f16(in_mid + 16);
+ const float16x8_t vlow_end = vld1q_f16(in_low + 16);
+
+ out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);
+
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2]));
+
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2]));
+
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2]));
+
+ accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
+ }
+ else
+ {
+ const float16x8x3_t vtop =
+ {
+ {
+ vld1q_f16(in_top),
+ vld1q_f16(in_top + 8),
+ vld1q_f16(in_top + 16)
+ }
+ };
+ const float16x8x3_t vmid =
+ {
+ {
+ vld1q_f16(in_mid),
+ vld1q_f16(in_mid + 8),
+ vld1q_f16(in_mid + 16)
+ }
+ };
+ const float16x8x3_t vlow =
+ {
+ {
+ vld1q_f16(in_low),
+ vld1q_f16(in_low + 8),
+ vld1q_f16(in_low + 16)
+ }
+ };
+ out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);
+ out.val[1] = vmulq_f16(vtop.val[1], m0.val[0]);
+
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
+ out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
+ out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
+ out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
+ out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
+ out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
+ out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
+ out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
+ out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
+
+ if(stridex == 3)
+ {
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
+ out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
+
+ accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
+ }
+ else
+ {
+ accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
+ }
+ }
+}
+#endif /** __ARM_FEATURE_FP16_VECTOR_ARITHMETIC **/
+
+/** Get the number of elements processed on 3x3 convolution.
+ *
+ * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution.
+ * @param[in] stridex Stride value in elements across x.
+ *
+ * @return The number of elements processed.
+ */
+inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
+{
+ switch(stridex)
+ {
+ case 1:
+ return num_elems_written_per_iteration;
+ case 2:
+ return num_elems_written_per_iteration << 1;
+ case 3:
+ return num_elems_written_per_iteration * 3;
+ default:
+ ARM_COMPUTE_ERROR("stridex not supported");
+ return 0;
+ }
+}
+}
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H */