aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/convolution
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/convolution')
-rw-r--r--arm_compute/core/NEON/kernels/convolution/NEDirectConvolution3x3.h172
-rw-r--r--arm_compute/core/NEON/kernels/convolution/NEDirectConvolutionDetail.h721
-rw-r--r--arm_compute/core/NEON/kernels/convolution/common/alloc.hpp31
-rw-r--r--arm_compute/core/NEON/kernels/convolution/common/arm.hpp39
-rw-r--r--arm_compute/core/NEON/kernels/convolution/common/convolution.hpp29
-rw-r--r--arm_compute/core/NEON/kernels/convolution/common/perf.h32
-rw-r--r--arm_compute/core/NEON/kernels/convolution/common/profiler.hpp326
-rw-r--r--arm_compute/core/NEON/kernels/convolution/common/shims.hpp747
-rw-r--r--arm_compute/core/NEON/kernels/convolution/common/tensor.hpp177
-rw-r--r--arm_compute/core/NEON/kernels/convolution/common/tensor_utils.hpp43
-rw-r--r--arm_compute/core/NEON/kernels/convolution/common/utils.hpp37
-rw-r--r--arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp209
-rw-r--r--arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp348
-rw-r--r--arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp263
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp69
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp127
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp355
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp1446
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp195
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp77
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp181
-rw-r--r--arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp447
22 files changed, 5178 insertions, 893 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/NEDirectConvolution3x3.h b/arm_compute/core/NEON/kernels/convolution/NEDirectConvolution3x3.h
deleted file mode 100644
index 7f39e5ee8d..0000000000
--- a/arm_compute/core/NEON/kernels/convolution/NEDirectConvolution3x3.h
+++ /dev/null
@@ -1,172 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef __ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H__
-#define __ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H__
-
-#include <arm_neon.h>
-
-namespace arm_compute
-{
-namespace detail
-{
-inline float32x4x3_t load_matrix_row(const float *ptr)
-{
- const float32x4x3_t r =
- {
- {
- vld1q_dup_f32(ptr),
- vld1q_dup_f32(1 + ptr),
- vld1q_dup_f32(2 + ptr)
- }
- };
- return r;
-}
-
-template <unsigned int stridex>
-float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position);
-
-template <>
-inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
-{
- ARM_COMPUTE_UNUSED(fixed_point_position);
-
- const float32x4x3_t vtop =
- {
- {
- vld1q_f32(in_top),
- vld1q_f32(in_top + 4),
- vld1q_f32(in_top + 8)
- }
- };
- const float32x4x3_t vmid =
- {
- {
- vld1q_f32(in_mid),
- vld1q_f32(in_mid + 4),
- vld1q_f32(in_mid + 8)
- }
- };
- const float32x4x3_t vlow =
- {
- {
- vld1q_f32(in_low),
- vld1q_f32(in_low + 4),
- vld1q_f32(in_low + 8)
- }
- };
- float32x4x2_t out =
- {
- {
- vmulq_f32(vtop.val[0], m0.val[0]),
- vmulq_f32(vtop.val[1], m0.val[0])
- }
- };
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);
-
- out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);
-
- out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);
-
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);
-
- out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);
-
- out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
- return out;
-}
-
-template <>
-inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
-{
- float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
- return out;
-}
-
-template <>
-inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
-{
- float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
- return out;
-}
-
-template <unsigned int stridex>
-void store_results(float *buffer, const float32x4x2_t &values);
-
-template <>
-void store_results<1>(float *buffer, const float32x4x2_t &values)
-{
- vst1q_f32(buffer, values.val[0]);
- vst1q_f32(buffer + 4, values.val[1]);
-}
-
-template <>
-void store_results<2>(float *buffer, const float32x4x2_t &values)
-{
- vst1q_f32(buffer, values.val[0]);
-}
-
-template <>
-void store_results<3>(float *buffer, const float32x4x2_t &values)
-{
- vst1_f32(buffer, vget_low_f32(values.val[0]));
-}
-
-template <unsigned int stridex>
-int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);
-
-template <>
-int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
-{
- return num_elems_written_per_iteration;
-}
-
-template <>
-int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
-{
- return num_elems_written_per_iteration << 1;
-}
-
-template <>
-int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
-{
- return num_elems_written_per_iteration * 3;
-}
-}
-} // namespace arm_compute
-#endif /* __ARM_COMPUTE_NECONVOLUTIONKERNEL3x3_H__ */ \ No newline at end of file
diff --git a/arm_compute/core/NEON/kernels/convolution/NEDirectConvolutionDetail.h b/arm_compute/core/NEON/kernels/convolution/NEDirectConvolutionDetail.h
deleted file mode 100644
index 908fa13876..0000000000
--- a/arm_compute/core/NEON/kernels/convolution/NEDirectConvolutionDetail.h
+++ /dev/null
@@ -1,721 +0,0 @@
-/*
- * Copyright (c) 2017-2018 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#ifndef __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__
-#define __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__
-
-#include "arm_compute/core/AccessWindowStatic.h"
-#include "arm_compute/core/NEON/NEFixedPoint.h"
-
-#include <arm_neon.h>
-
-namespace arm_compute
-{
-namespace detail
-{
-/** Loads a 3x3 matrix as a row (float).
- *
- * @param[in] ptr Pointer to a float 3x3 matrix.
- * @param[in] weights_offset (Optional) Weights quantization offset.
- *
- * @return The loaded matrix.
- */
-inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0)
-{
- ARM_COMPUTE_UNUSED(weights_offset);
- const float32x4x3_t r =
- {
- {
- vld1q_dup_f32(ptr),
- vld1q_dup_f32(1 + ptr),
- vld1q_dup_f32(2 + ptr)
- }
- };
- return r;
-}
-
-/** Loads a 3x3 matrix as a row (qint8_t).
- *
- * @param[in] ptr Pointer to a qint8 3x3 matrix.
- * @param[in] weights_offset (Optional) Weights quantization offset.
- *
- * @return The loaded matrix.
- */
-inline qint8x8x3_t load_matrix_row(const qint8_t *ptr, int weights_offset = 0)
-{
- ARM_COMPUTE_UNUSED(weights_offset);
- /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
- r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
- const qint8x8x3_t r =
- {
- {
- vld1_dup_qs8(ptr),
- vld1_dup_qs8(1 + ptr),
- vld1_dup_qs8(2 + ptr)
- }
- };
- return r;
-}
-
-/** Loads a 3x3 matrix as a row (uint8_t).
- *
- * @param[in] ptr Pointer to a uint8_t 3x3 matrix.
- * @param[in] weights_offset (Optional) Weights quantization offset.
- *
- * @return The loaded matrix.
- */
-inline int32x4x3_t load_matrix_row(const uint8_t *ptr, int weights_offset = 0)
-{
- const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset);
-
- /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
- r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
- int32x4x3_t r =
- {
- {
- vaddq_s32(v_weights_offset, vdupq_n_s32(*ptr)),
- vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 1))),
- vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 2)))
- }
- };
- return r;
-}
-
-/** Perform a convolve3x3 on float32.
- *
- * @param[in] in_top Pointer to the first row of the input.
- * @param[in] in_mid Pointer to the second row of the input.
- * @param[in] in_low Pointer to the third row of the input.
- * @param[in] m0 First row of the filter.
- * @param[in] m1 Second row of the filter.
- * @param[in] m2 Third row of the filter.
- * @param[in] fixed_point_position (Optional) Fixed point position.
- * @param[in] input_offset (Optional) Input quantization offset.
- *
- */
-template <unsigned int stridex>
-float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low,
- const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
- int fixed_point_position, int input_offset = 0);
-
-template <>
-inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low,
- const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
- int fixed_point_position, int input_offset)
-{
- ARM_COMPUTE_UNUSED(fixed_point_position);
- ARM_COMPUTE_UNUSED(input_offset);
-
- const float32x4x3_t vtop =
- {
- {
- vld1q_f32(in_top),
- vld1q_f32(in_top + 4),
- vld1q_f32(in_top + 8)
- }
- };
- const float32x4x3_t vmid =
- {
- {
- vld1q_f32(in_mid),
- vld1q_f32(in_mid + 4),
- vld1q_f32(in_mid + 8)
- }
- };
- const float32x4x3_t vlow =
- {
- {
- vld1q_f32(in_low),
- vld1q_f32(in_low + 4),
- vld1q_f32(in_low + 8)
- }
- };
- float32x4x2_t out =
- {
- {
- vmulq_f32(vtop.val[0], m0.val[0]),
- vmulq_f32(vtop.val[1], m0.val[0])
- }
- };
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);
-
- out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);
-
- out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
- out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);
-
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);
-
- out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);
-
- out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
- out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
- return out;
-}
-
-template <>
-inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low,
- const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
- int fixed_point_position, int input_offset)
-{
- ARM_COMPUTE_UNUSED(input_offset);
-
- float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position, input_offset);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
- return out;
-}
-
-template <>
-inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low,
- const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
- int fixed_point_position, int input_offset)
-{
- ARM_COMPUTE_UNUSED(input_offset);
-
- float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position, input_offset);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
- return out;
-}
-
-/** Perform a convolve3x3 on qint16.
- *
- * @param[in] in_top Pointer to the first row of the input.
- * @param[in] in_mid Pointer to the second row of the input.
- * @param[in] in_low Pointer to the third row of the input.
- * @param[in] m0 First row of the filter.
- * @param[in] m1 Second row of the filter.
- * @param[in] m2 Third row of the filter.
- * @param[in] fixed_point_position (Optional) Fixed point position.
- * @param[in] input_offset (Optional) Input quantization offset.
- *
- */
-template <unsigned int stridex>
-qint16x8x2_t convolve_3x3(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low,
- const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2,
- int fixed_point_position, int input_offset = 0);
-
-template <>
-inline qint16x8x2_t convolve_3x3<1>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low,
- const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2,
- int fixed_point_position, int input_offset)
-{
- ARM_COMPUTE_UNUSED(fixed_point_position);
- ARM_COMPUTE_UNUSED(input_offset);
-
- const qint8x8x3_t vtop =
- {
- {
- vld1_qs8(in_top),
- vld1_qs8(in_top + 8),
- vld1_qs8(in_top + 16)
- }
- };
- const qint8x8x3_t vmid =
- {
- {
- vld1_qs8(in_mid),
- vld1_qs8(in_mid + 8),
- vld1_qs8(in_mid + 16)
- }
- };
- const qint8x8x3_t vlow =
- {
- {
- vld1_qs8(in_low),
- vld1_qs8(in_low + 8),
- vld1_qs8(in_low + 16)
- }
- };
- qint16x8x2_t out =
- {
- {
- vmull_qs8(vtop.val[0], m0.val[0], fixed_point_position),
- vmull_qs8(vtop.val[1], m0.val[0], fixed_point_position)
- }
- };
- out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 1), m0.val[1], fixed_point_position);
- out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 2), m0.val[2], fixed_point_position);
- out.val[0] = vqmlal_qs8(out.val[0], vmid.val[0], m1.val[0], fixed_point_position);
- out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 1), m1.val[1], fixed_point_position);
- out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 2), m1.val[2], fixed_point_position);
- out.val[0] = vqmlal_qs8(out.val[0], vlow.val[0], m2.val[0], fixed_point_position);
- out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 1), m2.val[1], fixed_point_position);
- out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 2), m2.val[2], fixed_point_position);
- out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 1), m0.val[1], fixed_point_position);
- out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 2), m0.val[2], fixed_point_position);
- out.val[1] = vqmlal_qs8(out.val[1], vmid.val[1], m1.val[0], fixed_point_position);
- out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 1), m1.val[1], fixed_point_position);
- out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 2), m1.val[2], fixed_point_position);
- out.val[1] = vqmlal_qs8(out.val[1], vlow.val[1], m2.val[0], fixed_point_position);
- out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 1), m2.val[1], fixed_point_position);
- out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 2), m2.val[2], fixed_point_position);
- return out;
-}
-
-template <>
-inline qint16x8x2_t convolve_3x3<2>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low,
- const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2,
- int fixed_point_position, int input_offset)
-{
- ARM_COMPUTE_UNUSED(input_offset);
-
- qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position, input_offset);
- out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 2), out.val[0], 1);
- out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 4), out.val[0], 2);
- out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 3);
- out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 0), out.val[0], 4);
- out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 2), out.val[0], 5);
- out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 4), out.val[0], 6);
- out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 6), out.val[0], 7);
- return out;
-}
-
-template <>
-inline qint16x8x2_t convolve_3x3<3>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low,
- const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2,
- int fixed_point_position, int input_offset)
-{
- ARM_COMPUTE_UNUSED(input_offset);
-
- qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position, input_offset);
- out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 3), out.val[0], 1);
- out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 2);
- out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 1), out.val[0], 3);
- return out;
-}
-
-/** Perform a convolve3x3 on uint8_t
- *
- * @param[in] in_top Pointer to the first row of the input.
- * @param[in] in_mid Pointer to the second row of the input.
- * @param[in] in_low Pointer to the third row of the input.
- * @param[in] m0 First row of the filter.
- * @param[in] m1 Second row of the filter.
- * @param[in] m2 Third row of the filter.
- * @param[in] fixed_point_position (Optional) Fixed point position.
- * @param[in] input_offset (Optional) Input quantization offset.
- *
- */
-template <unsigned int stridex>
-int32x4x2_t convolve_3x3(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
- const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
- int fixed_point_position, int input_offset);
-
-template <>
-inline int32x4x2_t convolve_3x3<1>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
- int fixed_point_position, int input_offset)
-{
- ARM_COMPUTE_UNUSED(fixed_point_position);
-
- const int32x4_t v_input_offset = vdupq_n_s32(input_offset);
-
- const uint8x8x2_t vtop =
- {
- {
- vld1_u8(in_top),
- vld1_u8(in_top + 8)
- }
- };
- const uint8x8x2_t vmid =
- {
- {
- vld1_u8(in_mid),
- vld1_u8(in_mid + 8)
- }
- };
- const uint8x8x2_t vlow =
- {
- {
- vld1_u8(in_low),
- vld1_u8(in_low + 8)
- }
- };
-
- const int32x4x3_t vtop_s32 =
- {
- {
- vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[0])))),
- vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vtop.val[0])))),
- vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[1])))),
- }
- };
- const int32x4x3_t vmid_s32 =
- {
- {
- vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[0])))),
- vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vmid.val[0])))),
- vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[1])))),
- }
- };
- const int32x4x3_t vlow_s32 =
- {
- {
- vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[0])))),
- vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vlow.val[0])))),
- vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[1])))),
- }
- };
-
- int32x4x2_t out
- {
- {
- vdupq_n_s32(0),
- vdupq_n_s32(0),
- }
- };
-
- // 0
- out.val[0] = vmlaq_s32(out.val[0], vtop_s32.val[0], m0.val[0]);
- out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vtop_s32.val[0], vtop_s32.val[1], 1), m0.val[1]);
- out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vtop_s32.val[0], vtop_s32.val[1], 2), m0.val[2]);
-
- out.val[0] = vmlaq_s32(out.val[0], vmid_s32.val[0], m1.val[0]);
- out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vmid_s32.val[0], vmid_s32.val[1], 1), m1.val[1]);
- out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vmid_s32.val[0], vmid_s32.val[1], 2), m1.val[2]);
-
- out.val[0] = vmlaq_s32(out.val[0], vlow_s32.val[0], m2.val[0]);
- out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vlow_s32.val[0], vlow_s32.val[1], 1), m2.val[1]);
- out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vlow_s32.val[0], vlow_s32.val[1], 2), m2.val[2]);
-
- // 1
- out.val[1] = vmlaq_s32(out.val[1], vtop_s32.val[1], m0.val[0]);
- out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vtop_s32.val[1], vtop_s32.val[2], 1), m0.val[1]);
- out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vtop_s32.val[1], vtop_s32.val[2], 2), m0.val[2]);
-
- out.val[1] = vmlaq_s32(out.val[1], vmid_s32.val[1], m1.val[0]);
- out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vmid_s32.val[1], vmid_s32.val[2], 1), m1.val[1]);
- out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vmid_s32.val[1], vmid_s32.val[2], 2), m1.val[2]);
-
- out.val[1] = vmlaq_s32(out.val[1], vlow_s32.val[1], m2.val[0]);
- out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vlow_s32.val[1], vlow_s32.val[2], 1), m2.val[1]);
- out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vlow_s32.val[1], vlow_s32.val[2], 2), m2.val[2]);
-
- return out;
-}
-
-template <>
-inline int32x4x2_t convolve_3x3<2>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
- const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
- int fixed_point_position, int input_offset)
-{
- ARM_COMPUTE_UNUSED(fixed_point_position);
-
- int32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position, input_offset);
- out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[0], 2), out.val[0], 1);
- out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[1], 0), out.val[0], 2);
- out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[1], 2), out.val[0], 3);
- return out;
-}
-
-template <>
-inline int32x4x2_t convolve_3x3<3>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low,
- const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
- int fixed_point_position, int input_offset)
-{
- ARM_COMPUTE_UNUSED(fixed_point_position);
- int32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position, input_offset);
- out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[0], 3), out.val[0], 1);
- return out;
-}
-
-/** Stores a float32x4x2_t array into a memory location.
- *
- * @param[in] buffer Pointer to the memory location where the values will be stored.
- * @param[in] values Values that will be stored.
- *
- */
-template <unsigned int stridex>
-void store_results(float *buffer, const float32x4x2_t &values);
-
-template <>
-inline void store_results<1>(float *buffer, const float32x4x2_t &values)
-{
- vst1q_f32(buffer, values.val[0]);
- vst1q_f32(buffer + 4, values.val[1]);
-}
-
-template <>
-inline void store_results<2>(float *buffer, const float32x4x2_t &values)
-{
- vst1q_f32(buffer, values.val[0]);
-}
-
-template <>
-inline void store_results<3>(float *buffer, const float32x4x2_t &values)
-{
- vst1_f32(buffer, vget_low_f32(values.val[0]));
-}
-
-/** Stores a qint16_t array into a memory location.
- *
- * @param[in] buffer Pointer to the memory location where the values will be stored.
- * @param[in] values Values that will be stored.
- *
- */
-template <unsigned int stridex>
-void store_results(qint16_t *buffer, const qint16x8x2_t &values);
-
-template <>
-inline void store_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
-{
- vst1q_qs16(buffer, values.val[0]);
- vst1q_qs16(buffer + 8, values.val[1]);
-}
-
-template <>
-inline void store_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
-{
- vst1q_qs16(buffer, values.val[0]);
-}
-
-template <>
-inline void store_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
-{
- vst1_qs16(buffer, vget_low_s16(values.val[0]));
-}
-
-/** Stores a uint32_t array into a memory location.
- *
- * @param[in] buffer Pointer to the memory location where the values will be stored.
- * @param[in] values Values that will be stored.
- *
- */
-template <unsigned int stridex>
-void store_results(int32_t *buffer, const int32x4x2_t &values);
-
-template <>
-inline void store_results<1>(int32_t *buffer, const int32x4x2_t &values)
-{
- vst1q_s32(buffer, values.val[0]);
- vst1q_s32(buffer + 4, values.val[1]);
-}
-
-template <>
-inline void store_results<2>(int32_t *buffer, const int32x4x2_t &values)
-{
- vst1q_s32(buffer, values.val[0]);
-}
-
-template <>
-inline void store_results<3>(int32_t *buffer, const int32x4x2_t &values)
-{
- vst1_s32(buffer, vget_low_s32(values.val[0]));
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-/** Loads a 3x3 matrix as a row (float16_t).
- *
- * @param[in] ptr Pointer to a float 3x3 matrix.
- *
- * @return The loaded matrix.
- */
-inline float16x8x3_t load_matrix_row(const float16_t *ptr)
-{
- /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
- r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
- const float16x8x3_t r =
- {
- {
- vld1q_dup_f16(ptr),
- vld1q_dup_f16(1 + ptr),
- vld1q_dup_f16(2 + ptr)
- }
- };
- return r;
-}
-
-/** Perform a convolve3x3 on float16.
- *
- * @param[in] in_top Pointer to the first row of the input.
- * @param[in] in_mid Pointer to the second row of the input.
- * @param[in] in_low Pointer to the third row of the input.
- * @param[in] m0 First row of the filter.
- * @param[in] m1 Second row of the filter.
- * @param[in] m2 Third row of the filter.
- * @param[in] fixed_point_position (Optional) Fixed point position.
- *
- */
-template <unsigned int stridex>
-float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
- int fixed_point_position);
-
-template <>
-inline float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
- int fixed_point_position)
-{
- ARM_COMPUTE_UNUSED(fixed_point_position);
-
- const float16x8x3_t vtop =
- {
- {
- vld1q_f16(in_top),
- vld1q_f16(in_top + 8),
- vld1q_f16(in_top + 16)
- }
- };
- const float16x8x3_t vmid =
- {
- {
- vld1q_f16(in_mid),
- vld1q_f16(in_mid + 8),
- vld1q_f16(in_mid + 16)
- }
- };
- const float16x8x3_t vlow =
- {
- {
- vld1q_f16(in_low),
- vld1q_f16(in_low + 8),
- vld1q_f16(in_low + 16)
- }
- };
- float16x8x2_t out =
- {
- {
- vmulq_f16(vtop.val[0], m0.val[0]),
- vmulq_f16(vtop.val[1], m0.val[0])
- }
- };
- out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
- out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
- out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
- out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
- out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
- out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
- out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
- out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
- out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
- out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
- out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
- out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
- out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
- out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
- out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
- out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
- return out;
-}
-
-template <>
-inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
- int fixed_point_position)
-{
- float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 2);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 3);
- return out;
-}
-
-template <>
-inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
- int fixed_point_position)
-{
- float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
- return out;
-}
-
-/** Stores a float16x8x2_t array into a memory location.
- *
- * @param[in] buffer Pointer to the memory location where the values will be stored.
- * @param[in] values Values that will be stored.
- *
- */
-template <unsigned int stridex>
-void store_results(float16_t *buffer, const float16x8x2_t &values);
-
-template <>
-inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
-{
- vst1q_f16(buffer, values.val[0]);
- vst1q_f16(buffer + 8, values.val[1]);
-}
-
-template <>
-inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
-{
- vst1q_f16(buffer, values.val[0]);
-}
-
-template <>
-inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
-{
- vst1_f16(buffer, vget_low_f16(values.val[0]));
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-/** Get the number of elements processed on 3x3 convolution.
- *
- * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution.
- *
- * @return The number of elements processed.
- */
-template <unsigned int stridex>
-int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);
-
-template <>
-inline int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
-{
- return num_elems_written_per_iteration;
-}
-
-template <>
-inline int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
-{
- return num_elems_written_per_iteration << 1;
-}
-
-template <>
-inline int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
-{
- return num_elems_written_per_iteration * 3;
-}
-inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
-{
- switch(stridex)
- {
- case 1:
- return get_input_num_elems_processed<1>(num_elems_written_per_iteration);
- case 2:
- return get_input_num_elems_processed<2>(num_elems_written_per_iteration);
- case 3:
- return get_input_num_elems_processed<3>(num_elems_written_per_iteration);
- default:
- ARM_COMPUTE_ERROR("stridex not supported");
- return 0;
- }
-}
-}
-} // namespace arm_compute
-#endif /* __ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H__ */
diff --git a/arm_compute/core/NEON/kernels/convolution/common/alloc.hpp b/arm_compute/core/NEON/kernels/convolution/common/alloc.hpp
new file mode 100644
index 0000000000..799e95d3e6
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/common/alloc.hpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#ifdef ALLOC_ALIGN
+#define ALLOCATE(x) aligned_alloc(ALLOC_ALIGN, x)
+#else
+#define ALLOCATE(x) malloc(x)
+#endif
diff --git a/arm_compute/core/NEON/kernels/convolution/common/arm.hpp b/arm_compute/core/NEON/kernels/convolution/common/arm.hpp
new file mode 100644
index 0000000000..90e7828553
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/common/arm.hpp
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/** Sets the macro __arm_any__ if compiling for Aarch32 or Aarch64.
+ * Includes `arm_neon.h` if compiling for either architecture.
+ */
+
+#ifdef __arm__
+#define __arm_any__
+#endif // __arm__
+
+#ifdef __aarch64__
+#define __arm_any__
+#endif // __aarch64__
+
+#ifdef __arm_any__
+#include <arm_neon.h>
+#endif // __arm_any__
diff --git a/arm_compute/core/NEON/kernels/convolution/common/convolution.hpp b/arm_compute/core/NEON/kernels/convolution/common/convolution.hpp
new file mode 100644
index 0000000000..2ab2597785
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/common/convolution.hpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+enum PaddingType {
+ PADDING_SAME, PADDING_VALID
+};
diff --git a/arm_compute/core/NEON/kernels/convolution/common/perf.h b/arm_compute/core/NEON/kernels/convolution/common/perf.h
new file mode 100644
index 0000000000..3c0d36646d
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/common/perf.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+/* Prototypes from perf.c */
+
+void start_counter(int fd);
+long long get_counter(int fd);
+long long stop_counter(int fd);
+int open_instruction_counter(void);
+int open_cycle_counter(void);
diff --git a/arm_compute/core/NEON/kernels/convolution/common/profiler.hpp b/arm_compute/core/NEON/kernels/convolution/common/profiler.hpp
new file mode 100644
index 0000000000..01fafa9604
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/common/profiler.hpp
@@ -0,0 +1,326 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <cstdio>
+#include <map>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+#include "perf.h"
+#include <unistd.h>
+
+#ifdef CYCLE_PROFILING
+class EventIDContainer
+{
+ public:
+ EventIDContainer() : container_lock(), event_ids()
+ {
+ }
+
+ int get_event_id(const char *id)
+ {
+ std::lock_guard<std::mutex> lock(container_lock);
+ if (!event_ids.count(id)) {
+ event_ids.emplace(id, event_ids.size());
+ }
+ return event_ids[id];
+ }
+
+ unsigned int size() const
+ {
+ return event_ids.size();
+ }
+
+ auto begin()
+ {
+ return event_ids.begin();
+ }
+
+ auto end()
+ {
+ return event_ids.end();
+ }
+
+ private:
+ std::mutex container_lock;
+ std::map<const char *, int> event_ids;
+};
+
+
+class ThreadEventCounterContainer
+{
+ public:
+ ThreadEventCounterContainer() : container_lock(), thread_counter_fds()
+ {
+ }
+
+ int get_counter_fd()
+ {
+ const auto id = std::this_thread::get_id();
+ std::lock_guard<std::mutex> lock(container_lock);
+ if (!thread_counter_fds.count(id))
+ {
+ thread_counter_fds.emplace(id, open_cycle_counter());
+ }
+ return thread_counter_fds[id];
+ }
+
+ ~ThreadEventCounterContainer()
+ {
+ // Close all counter file descriptors
+ for (auto& fd : thread_counter_fds)
+ {
+ close(fd.second);
+ }
+ }
+
+ private:
+ std::mutex container_lock;
+ std::map<std::thread::id, int> thread_counter_fds;
+};
+#endif // CYCLE_PROFILING
+
+
+class profiler {
+private:
+#ifdef CYCLE_PROFILING
+ struct ProfileEntry {
+ int event_id;
+ long int bytes_read, ops, bytes_written;
+ long int duration;
+ };
+
+ static const int maxevents = 10000;
+ ProfileEntry events[maxevents];
+ int currentevent;
+ std::mutex event_lock;
+
+ EventIDContainer event_ids;
+ ThreadEventCounterContainer thread_counter_fds;
+
+ int get_event_id(const char *id)
+ {
+ return event_ids.get_event_id(id);
+ }
+#endif // CYCLE_PROFILING
+
+public:
+#ifdef CYCLE_PROFILING
+ profiler() :
+ currentevent(0),
+ event_lock(),
+ event_ids(),
+ thread_counter_fds()
+ {
+ }
+
+ ~profiler() {
+ std::lock_guard<std::mutex> lock_events(event_lock);
+
+ // Compute performance from recorded events
+ struct ProfileResult {
+ ProfileResult() : total_calls(0),
+ total_duration(0),
+ total_bytes_read(0),
+ total_ops(0),
+ total_bytes_written(0) {
+ }
+
+ void operator+=(const ProfileEntry &rhs) {
+ total_calls++;
+ total_duration += rhs.duration;
+ total_bytes_read += rhs.bytes_read;
+ total_ops += rhs.ops;
+ total_bytes_written = rhs.bytes_written;
+ }
+
+ float avg_duration(void) const {
+ return static_cast<float>(total_duration) /
+ static_cast<float>(total_calls);
+ }
+
+ float bytes_read_per_cycle(void) const {
+ return static_cast<float>(total_bytes_read) /
+ static_cast<float>(total_duration);
+ }
+
+ float ops_per_cycle(void) const {
+ return static_cast<float>(total_ops) /
+ static_cast<float>(total_duration);
+ }
+
+ float bytes_written_per_cycle(void) const {
+ return static_cast<float>(total_bytes_written) /
+ static_cast<float>(total_duration);
+ }
+
+ long int total_calls,
+ total_duration,
+ total_bytes_read,
+ total_ops,
+ total_bytes_written;
+ };
+
+ std::vector<ProfileResult> totals;
+ totals.resize(event_ids.size());
+ for (int i = 0; i < currentevent; i++) {
+ const auto &event = events[i];
+ totals[event.event_id] += event;
+ }
+
+ // Get the longest label
+ int len_label = 0;
+ for (const auto &kv : event_ids) {
+ len_label = std::max(len_label, static_cast<int>(strlen(kv.first)));
+ }
+
+ // Get the longest values for every other field
+ const auto get_length_of_field =
+ [totals] (const char *title, auto f, auto len) -> size_t {
+ size_t l = strlen(title);
+ for (const auto &v : totals) {
+ l = std::max(l, len(f(v)));
+ }
+ return l;
+ };
+
+ // Get the strlen for an int
+ const auto intlen = [] (long int x) -> size_t {
+ size_t len = 0;
+ do {
+ x /= 10;
+ len++;
+ } while (x);
+ return len;
+ };
+
+ // Get the strlen for a float
+ const auto floatlen = [] (const int precision) {
+ return [precision] (float x) {
+ size_t len = 0;
+
+ if (!std::isfinite(x)) {
+ return static_cast<size_t>(3);
+ }
+
+ do {
+ x /= 10.0f;
+ len++;
+ } while (x > 1.0f);
+ return len + 1 + precision;
+ };
+ };
+
+ const int len_calls = get_length_of_field(
+ "Calls", [] (const auto &v) {return v.total_calls;},
+ intlen
+ );
+ const int len_duration = get_length_of_field(
+ "Duration", [] (const auto &v) {return v.total_duration;},
+ intlen
+ );
+ const int len_average_duration = get_length_of_field(
+ "Average", [] (const auto &v) {return v.avg_duration();},
+ floatlen(2)
+ );
+ const int len_reads_per_cycle = get_length_of_field(
+ "Reads / cycle",
+ [] (const auto &v) {return v.bytes_read_per_cycle();},
+ floatlen(6)
+ );
+ const int len_ops_per_cycle = get_length_of_field(
+ "Ops / cycle",
+ [] (const auto &v) {return v.ops_per_cycle();},
+ floatlen(6)
+ );
+ const int len_writes_per_cycle = get_length_of_field(
+ "Writes / cycle",
+ [] (const auto &v) {return v.bytes_written_per_cycle();},
+ floatlen(6)
+ );
+
+ // Print header
+ printf(
+ "%*s %*s %*s %*s %*s %*s %*s\n",
+ len_label, "",
+ len_calls, "Calls",
+ len_duration, "Duration",
+ len_average_duration, "Average",
+ len_reads_per_cycle, "Reads / cycle",
+ len_ops_per_cycle, "Ops / cycle",
+ len_writes_per_cycle, "Writes / cycle"
+ );
+ for (const auto &kv : event_ids) {
+ const auto id = kv.second;
+ printf(
+ "%*s %*ld %*ld %*.2f %*.6f %*.6f %*.6f\n",
+ len_label, kv.first,
+ len_calls, totals[id].total_calls,
+ len_duration, totals[id].total_duration,
+ len_average_duration, totals[id].avg_duration(),
+ len_reads_per_cycle, totals[id].bytes_read_per_cycle(),
+ len_ops_per_cycle, totals[id].ops_per_cycle(),
+ len_writes_per_cycle, totals[id].bytes_written_per_cycle()
+ );
+ }
+ printf("\n");
+ }
+#endif // CYCLE_PROFILING
+
+ template <typename T>
+ void operator() (const char * event,
+ T func,
+ long int bytes_read = 0,
+ long int ops = 0,
+ long int bytes_written = 0) {
+#ifdef CYCLE_PROFILING
+ if (currentevent==maxevents) {
+ func();
+ } else {
+ const auto countfd = thread_counter_fds.get_counter_fd();
+ start_counter(countfd);
+ func();
+ long long cycs = stop_counter(countfd);
+
+ // Store the profiling data
+ std::lock_guard<std::mutex> lock_events(event_lock);
+ events[currentevent++] = {
+ get_event_id(event), bytes_read, ops, bytes_written, cycs
+ };
+ }
+#else
+ (void) event;
+ (void) bytes_read;
+ (void) ops;
+ (void) bytes_written;
+ func();
+#endif // CYCLE_PROFILING
+ }
+};
diff --git a/arm_compute/core/NEON/kernels/convolution/common/shims.hpp b/arm_compute/core/NEON/kernels/convolution/common/shims.hpp
new file mode 100644
index 0000000000..09e14577ff
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/common/shims.hpp
@@ -0,0 +1,747 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include <cstdint>
+#include "arm.hpp"
+
+namespace reorder {
+/** Re-order a tensor from NCHW format to NHWC.
+ *
+ * @note The stride parameters are optional and are provided to allow padding in either input or output tensors.
+ *
+ * @param[in] in Input tensor in NCHW format.
+ * @param[out] out Output tensor, to be written in NHWC format.
+ * @param n_batches Number of batches in the tensors.
+ * @param n_channels Number of channels in the tensors
+ * @param n_rows Height of the tensor
+ * @param n_cols Width of the tensor
+ * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_channels * in_channel_stride`.
+ * @param in_channel_stride Stride over channels in the input tensor. If `0` defaults to `n_rows * in_row_stride`.
+ * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols`.
+ * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_rows * out_row_stride`.
+ * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols * out_col_stride`.
+ * @param out_col_stride Stride over columns in the output tensor. If `0` defaults to `n_channels`.
+ */
+template <typename T>
+inline void nchw_to_nhwc(
+ const T* const in,
+ T* const out,
+ const int n_batches,
+ const int n_channels,
+ const int n_rows,
+ const int n_cols,
+ int in_batch_stride=0,
+ int in_channel_stride=0,
+ int in_row_stride=0,
+ int out_batch_stride=0,
+ int out_row_stride=0,
+ int out_col_stride=0
+);
+
+/** Re-order a tensor from NHWC format to NCHW.
+ *
+ * @note The stride parameters are optional and are provided to allow padding in either input or output tensors.
+ *
+ * @param[in] in Input tensor in NHWC format.
+ * @param[out] out Output tensor, to be written in NCHW format.
+ * @param n_batches Number of batches in the tensors.
+ * @param n_rows Height of the tensor
+ * @param n_cols Width of the tensor
+ * @param n_channels Number of channels in the tensors
+ * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_rows * in_row_stride`.
+ * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols * in_col_stride`.
+ * @param in_col_stride Stride over columns in the input tensor. If `0` defaults to `n_channels`.
+ * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_channels * out_channel_stride`.
+ * @param out_channel_stride Stride over channels in the output tensor. If `0` defaults to `n_rows * out_row_stride`.
+ * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols`.
+ */
+template <typename T>
+inline void nhwc_to_nchw(
+ const T* const in, // Input data in NHWC form
+ T* const out, // Output data in NCHW form
+ const int n_batches,
+ const int n_rows,
+ const int n_cols,
+ const int n_channels,
+ int in_batch_stride=0,
+ int in_row_stride=0,
+ int in_col_stride=0,
+ int out_batch_stride=0,
+ int out_channel_stride=0,
+ int out_row_stride=0
+);
+
+/** Re-order a weight tensor from [Output feature map x Input feature map x
+ * Height x Width] format to [Height x Width x Input feature map x Output
+ * feature map] format.
+ */
+template <typename T>
+inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
+ const T* const in, // Input in [Output x Input x Height x Width] form
+ T* const out, // Output in [Height x Width x Input x Output] form
+ const int n_output_feature_maps,
+ const int n_input_feature_maps,
+ const int n_rows,
+ const int n_cols,
+ int in_output_feature_map_stride=0,
+ int in_input_feature_map_stride=0,
+ int in_row_stride=0,
+ int out_row_stride=0,
+ int out_col_stride=0,
+ int out_input_feature_map_stride=0
+);
+
+/** Re-order a weight tensor from [Height x Width x Input feature map x Output
+ * feature map] format to [Output feature map x Input feature map x Height x
+ * Width] format.
+ */
+template <typename T>
+inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
+ const T* const in, // Input in [Height x Width x Input x Output] form
+ T* const out, // Output in [Output x Input x Height x Width] form
+ const int n_rows,
+ const int n_cols,
+ const int n_input_feature_maps,
+ const int n_output_feature_maps,
+ int in_row_stride=0,
+ int in_col_stride=0,
+ int in_input_feature_map_stride=0,
+ int out_output_feature_map_stride=0,
+ int out_input_feature_map_stride=0,
+ int out_row_stride=0
+);
+
+/*****************************************************************************/
+/* 32-bit implementation : NCHW -> NHWC
+ */
+template <>
+inline void nchw_to_nhwc(
+ const int32_t* const in,
+ int32_t* const out,
+ const int n_batches,
+ const int n_channels,
+ const int n_rows,
+ const int n_cols,
+ int in_batch_stride,
+ int in_channel_stride,
+ int in_row_stride,
+ int out_batch_stride,
+ int out_row_stride,
+ int out_col_stride
+)
+{
+ typedef int32_t T;
+
+ // Fill in the stride values
+ in_row_stride = (in_row_stride) ? in_row_stride : n_cols;
+ in_channel_stride = (in_channel_stride) ? in_channel_stride
+ : n_rows * in_row_stride;
+ in_batch_stride = (in_batch_stride) ? in_batch_stride
+ : n_channels * in_channel_stride;
+
+ out_col_stride = (out_col_stride) ? out_col_stride : n_channels;
+ out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride;
+ out_batch_stride = (out_batch_stride) ? out_batch_stride
+ : n_rows * out_row_stride;
+
+ // Perform the re-ordering
+ for (int n = 0; n < n_batches; n++)
+ {
+ const T* const in_batch = in + n*in_batch_stride;
+ T* const out_batch = out + n*out_batch_stride;
+
+ for (int i = 0; i < n_rows; i++)
+ {
+ const T* const in_row = in_batch + i*in_row_stride;
+ T* const out_row = out_batch + i*out_row_stride;
+
+ int j = 0, j_remaining = n_cols;
+#ifdef __arm_any__
+ for (; j_remaining >= 4; j += 4, j_remaining -= 4)
+ {
+ int c = 0, c_remaining = n_channels;
+ for (; c_remaining >= 4; c += 4, c_remaining -= 4)
+ {
+ // Read 4 channels worth of 4 columns, then zip to produce 4 columns
+ // worth of 4 channels.
+ int32x4_t channel_pixels[4];
+ channel_pixels[0] = vld1q_s32(in_row + (c + 0)*in_channel_stride + j);
+ channel_pixels[1] = vld1q_s32(in_row + (c + 1)*in_channel_stride + j);
+ channel_pixels[2] = vld1q_s32(in_row + (c + 2)*in_channel_stride + j);
+ channel_pixels[3] = vld1q_s32(in_row + (c + 3)*in_channel_stride + j);
+
+ const auto zip1 = vzipq_s32(channel_pixels[0], channel_pixels[2]);
+ const auto zip2 = vzipq_s32(channel_pixels[1], channel_pixels[3]);
+ const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]);
+ const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]);
+
+ vst1q_s32(out_row + (j + 0)*out_col_stride + c, out_0.val[0]);
+ vst1q_s32(out_row + (j + 1)*out_col_stride + c, out_0.val[1]);
+ vst1q_s32(out_row + (j + 2)*out_col_stride + c, out_1.val[0]);
+ vst1q_s32(out_row + (j + 3)*out_col_stride + c, out_1.val[1]);
+ }
+ for (; c_remaining; c++, c_remaining--)
+ {
+ for (int _j = 0; _j < 4; _j++)
+ {
+ const T* const in_col = in_row + j + _j;
+ T* const out_col = out_row + (j + _j)*out_col_stride;
+ const T* const in_channel = in_col + c*in_channel_stride;
+ out_col[c] = *(in_channel);
+ }
+ }
+ }
+ for (; j_remaining >= 2; j += 2, j_remaining -= 2)
+ {
+ int c = 0, c_remaining = n_channels;
+ for (; c_remaining >= 2; c += 2, c_remaining -= 2)
+ {
+ // Read 2 channels worth of 2 columns, then zip to produce 2 columns
+ // worth of 2 channels.
+ int32x2_t channel_pixels[2];
+ channel_pixels[0] = vld1_s32(in_row + (c + 0)*in_channel_stride + j);
+ channel_pixels[1] = vld1_s32(in_row + (c + 1)*in_channel_stride + j);
+
+ const auto output = vzip_s32(channel_pixels[0], channel_pixels[1]);
+
+ vst1_s32(out_row + (j + 0)*out_col_stride + c, output.val[0]);
+ vst1_s32(out_row + (j + 1)*out_col_stride + c, output.val[1]);
+ }
+ for (; c_remaining; c++, c_remaining--)
+ {
+ for (int _j = 0; _j < 2; _j++)
+ {
+ const T* const in_col = in_row + j + _j;
+ T* const out_col = out_row + (j + _j)*out_col_stride;
+ const T* const in_channel = in_col + c*in_channel_stride;
+ out_col[c] = *(in_channel);
+ }
+ }
+ }
+#endif // __arm_any__
+ for (; j_remaining; j++, j_remaining--)
+ {
+ const T* const in_col = in_row + j;
+ T* const out_col = out_row + j*out_col_stride;
+
+ for (int c = 0; c < n_channels; c++)
+ {
+ const T* const in_channel = in_col + c*in_channel_stride;
+ out_col[c] = *(in_channel);
+ }
+ }
+ }
+ }
+}
+
+template <>
+inline void nchw_to_nhwc(
+ const uint32_t* const in,
+ uint32_t* const out,
+ const int n_batches,
+ const int n_channels,
+ const int n_rows,
+ const int n_cols,
+ int in_batch_stride,
+ int in_channel_stride,
+ int in_row_stride,
+ int out_batch_stride,
+ int out_row_stride,
+ int out_col_stride
+)
+{
+ nchw_to_nhwc(
+ reinterpret_cast<const int32_t*>(in),
+ reinterpret_cast<int32_t*>(out),
+ n_batches, n_channels, n_rows, n_cols,
+ in_batch_stride, in_channel_stride, in_row_stride,
+ out_batch_stride, out_row_stride, out_col_stride
+ );
+}
+
+template <>
+inline void nchw_to_nhwc(
+ const float* const in,
+ float* const out,
+ const int n_batches,
+ const int n_channels,
+ const int n_rows,
+ const int n_cols,
+ int in_batch_stride,
+ int in_channel_stride,
+ int in_row_stride,
+ int out_batch_stride,
+ int out_row_stride,
+ int out_col_stride
+)
+{
+ nchw_to_nhwc(
+ reinterpret_cast<const int32_t*>(in),
+ reinterpret_cast<int32_t*>(out),
+ n_batches, n_channels, n_rows, n_cols,
+ in_batch_stride, in_channel_stride, in_row_stride,
+ out_batch_stride, out_row_stride, out_col_stride
+ );
+}
+
+/*****************************************************************************/
+/* Generic implementation : NCHW -> NHWC
+ */
+template <typename T>
+inline void nchw_to_nhwc(
+ const T* const in,
+ T* const out,
+ const int n_batches,
+ const int n_channels,
+ const int n_rows,
+ const int n_cols,
+ int in_batch_stride,
+ int in_channel_stride,
+ int in_row_stride,
+ int out_batch_stride,
+ int out_row_stride,
+ int out_col_stride
+)
+{
+ // Fill in the stride values
+ in_row_stride = (in_row_stride) ? in_row_stride : n_cols;
+ in_channel_stride = (in_channel_stride) ? in_channel_stride
+ : n_rows * in_row_stride;
+ in_batch_stride = (in_batch_stride) ? in_batch_stride
+ : n_channels * in_channel_stride;
+
+ out_col_stride = (out_col_stride) ? out_col_stride : n_channels;
+ out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride;
+ out_batch_stride = (out_batch_stride) ? out_batch_stride
+ : n_rows * out_row_stride;
+
+ // Perform the re-ordering
+ for (int n = 0; n < n_batches; n++)
+ {
+ const T* const in_batch = in + n*in_batch_stride;
+ T* const out_batch = out + n*out_batch_stride;
+
+ for (int i = 0; i < n_rows; i++)
+ {
+ const T* const in_row = in_batch + i*in_row_stride;
+ T* const out_row = out_batch + i*out_row_stride;
+
+ for (int j = 0; j < n_cols; j++)
+ {
+ const T* const in_col = in_row + j;
+ T* const out_col = out_row + j*out_col_stride;
+
+ for (int c = 0; c < n_channels; c++)
+ {
+ const T* const in_channel = in_col + c*in_channel_stride;
+ out_col[c] = *(in_channel);
+ }
+ }
+ }
+ }
+}
+
+/*****************************************************************************/
+/* 32-bit implementation : NHWC -> NCHW
+ */
+template <>
+inline void nhwc_to_nchw(
+ const int32_t* const in, // Input data in NHWC form
+ int32_t* const out, // Output data in NCHW form
+ const int n_batches,
+ const int n_rows,
+ const int n_cols,
+ const int n_channels,
+ int in_batch_stride,
+ int in_row_stride,
+ int in_col_stride,
+ int out_batch_stride,
+ int out_channel_stride,
+ int out_row_stride
+)
+{
+ typedef int32_t T;
+
+ // Fill in stride values
+ in_col_stride = (in_col_stride) ? in_col_stride : n_channels;
+ in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride;
+ in_batch_stride = (in_batch_stride) ? in_batch_stride
+ : n_rows * in_row_stride;
+
+ out_row_stride = (out_row_stride) ? out_row_stride : n_cols;
+ out_channel_stride = (out_channel_stride) ? out_channel_stride
+ : n_rows * out_row_stride;
+ out_batch_stride = (out_batch_stride) ? out_batch_stride
+ : n_channels * out_channel_stride;
+
+ // Perform the re-ordering
+ // For every batch
+ for (int n = 0; n < n_batches; n++)
+ {
+ const T* const in_batch = in + n*in_batch_stride;
+ T* const out_batch = out + n*out_batch_stride;
+
+ // For every row
+ for (int i = 0; i < n_rows; i++)
+ {
+ const T* const in_i = in_batch + i*in_row_stride;
+ T* const out_i = out_batch + i*out_row_stride;
+
+ // For every column, beginning with chunks of 4
+ int j = 0, j_remaining = n_cols;
+#ifdef __arm_any__
+ for (; j_remaining >= 4; j += 4, j_remaining -=4)
+ {
+ // For every channel, beginning with chunks of 4
+ int c = 0, c_remaining = n_channels;
+ for (; c_remaining >= 4; c += 4, c_remaining -= 4)
+ {
+ // Read 4 columns worth of 4 channels then zip to produce 4 channels
+ // worth of 4 columns.
+ int32x4_t pixel_channels[4];
+ pixel_channels[0] = vld1q_s32(in_i + (j + 0)*in_col_stride + c);
+ pixel_channels[1] = vld1q_s32(in_i + (j + 1)*in_col_stride + c);
+ pixel_channels[2] = vld1q_s32(in_i + (j + 2)*in_col_stride + c);
+ pixel_channels[3] = vld1q_s32(in_i + (j + 3)*in_col_stride + c);
+
+ const auto zip1 = vzipq_s32(pixel_channels[0], pixel_channels[2]);
+ const auto zip2 = vzipq_s32(pixel_channels[1], pixel_channels[3]);
+ const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]);
+ const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]);
+
+ vst1q_s32(out_i + j + (c + 0)*out_channel_stride, out_0.val[0]);
+ vst1q_s32(out_i + j + (c + 1)*out_channel_stride, out_0.val[1]);
+ vst1q_s32(out_i + j + (c + 2)*out_channel_stride, out_1.val[0]);
+ vst1q_s32(out_i + j + (c + 3)*out_channel_stride, out_1.val[1]);
+ }
+ for (; c_remaining; c++, c_remaining--)
+ {
+ for (int _j = 0; _j < 4; _j++)
+ {
+ const T* const in_j = in_i + (j + _j)*in_col_stride;
+ T* const out_j = out_i + (j + _j);
+
+ const T* const in_channel = in_j + c;
+ T* const out_channel = out_j + c*out_channel_stride;
+ *(out_channel) = *(in_channel);
+ }
+ }
+ }
+ for (; j_remaining >= 2; j += 2, j_remaining -=2)
+ {
+ int c = 0, c_remaining = n_channels;
+ for (; c_remaining >= 2; c += 2, c_remaining -= 2)
+ {
+ // Read 2 columns worth of 2 channels then zip to produce 2 channels
+ // worth of 2 columns.
+ int32x2_t pixel_channels[2];
+ pixel_channels[0] = vld1_s32(in_i + (j + 0)*in_col_stride + c);
+ pixel_channels[1] = vld1_s32(in_i + (j + 1)*in_col_stride + c);
+
+ const auto output = vzip_s32(pixel_channels[0], pixel_channels[1]);
+
+ vst1_s32(out_i + j + (c + 0)*out_channel_stride, output.val[0]);
+ vst1_s32(out_i + j + (c + 1)*out_channel_stride, output.val[1]);
+ }
+ for (; c_remaining; c++, c_remaining--)
+ {
+ for (int _j = 0; _j < 2; _j++)
+ {
+ const T* const in_j = in_i + (j + _j)*in_col_stride;
+ T* const out_j = out_i + (j + _j);
+
+ const T* const in_channel = in_j + c;
+ T* const out_channel = out_j + c*out_channel_stride;
+ *(out_channel) = *(in_channel);
+ }
+ }
+ }
+#endif // __arm_any__
+ for (; j_remaining; j++, j_remaining--)
+ {
+ const T* const in_j = in_i + j*in_col_stride;
+ T* const out_j = out_i + j;
+
+ // For every channel
+ for (int c = 0; c < n_channels; c++)
+ {
+ const T* const in_channel = in_j + c;
+ T* const out_channel = out_j + c*out_channel_stride;
+ *(out_channel) = *(in_channel);
+ }
+ }
+ }
+ }
+}
+
+template <>
+inline void nhwc_to_nchw(
+ const uint32_t* const in, // Input data in NHWC form
+ uint32_t* const out, // Output data in NCHW form
+ const int n_batches,
+ const int n_rows,
+ const int n_cols,
+ const int n_channels,
+ int in_batch_stride,
+ int in_row_stride,
+ int in_col_stride,
+ int out_batch_stride,
+ int out_channel_stride,
+ int out_row_stride
+)
+{
+ // Redirect to generic 32-bit implementation
+ nhwc_to_nchw(
+ reinterpret_cast<const int32_t*>(in),
+ reinterpret_cast<int32_t*>(out),
+ n_batches, n_rows, n_cols, n_channels,
+ in_batch_stride, in_row_stride, in_col_stride,
+ out_batch_stride, out_channel_stride, out_row_stride
+ );
+}
+
+template <>
+inline void nhwc_to_nchw(
+ const float* const in, // Input data in NHWC form
+ float* const out, // Output data in NCHW form
+ const int n_batches,
+ const int n_rows,
+ const int n_cols,
+ const int n_channels,
+ int in_batch_stride,
+ int in_row_stride,
+ int in_col_stride,
+ int out_batch_stride,
+ int out_channel_stride,
+ int out_row_stride
+)
+{
+ // Redirect to generic 32-bit implementation
+ nhwc_to_nchw(
+ reinterpret_cast<const int32_t*>(in),
+ reinterpret_cast<int32_t*>(out),
+ n_batches, n_rows, n_cols, n_channels,
+ in_batch_stride, in_row_stride, in_col_stride,
+ out_batch_stride, out_channel_stride, out_row_stride
+ );
+}
+
+/*****************************************************************************/
+/* Generic implementation : NHWC -> NCHW
+ */
+template <typename T>
+inline void nhwc_to_nchw(
+ const T* const in, // Input data in NHWC form
+ T* const out, // Output data in NCHW form
+ const int n_batches,
+ const int n_rows,
+ const int n_cols,
+ const int n_channels,
+ int in_batch_stride,
+ int in_row_stride,
+ int in_col_stride,
+ int out_batch_stride,
+ int out_channel_stride,
+ int out_row_stride
+)
+{
+ // Fill in stride values
+ in_col_stride = (in_col_stride) ? in_col_stride : n_channels;
+ in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride;
+ in_batch_stride = (in_batch_stride) ? in_batch_stride
+ : n_rows * in_row_stride;
+
+ out_row_stride = (out_row_stride) ? out_row_stride : n_cols;
+ out_channel_stride = (out_channel_stride) ? out_channel_stride
+ : n_rows * out_row_stride;
+ out_batch_stride = (out_batch_stride) ? out_batch_stride
+ : n_channels * out_channel_stride;
+
+ // Perform the re-ordering
+ // For every batch
+ for (int n = 0; n < n_batches; n++)
+ {
+ const T* const in_batch = in + n*in_batch_stride;
+ T* const out_batch = out + n*out_batch_stride;
+
+ // For every row
+ for (int i = 0; i < n_rows; i++)
+ {
+ const T* const in_i = in_batch + i*in_row_stride;
+ T* const out_i = out_batch + i*out_row_stride;
+
+ // For every column
+ for (int j = 0; j < n_cols; j++)
+ {
+ const T* const in_j = in_i + j*in_col_stride;
+ T* const out_j = out_i + j;
+
+ // For every channel
+ for (int c = 0; c < n_channels; c++)
+ {
+ const T* const in_channel = in_j + c;
+ T* const out_channel = out_j + c*out_channel_stride;
+ *(out_channel) = *(in_channel);
+ }
+ }
+ }
+ }
+}
+
+/*****************************************************************************/
+/* Generic weight re-order implementation.
+ */
+template <typename T>
+inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
+ const T* const in, // Input in [Output x Input x Height x Width] form
+ T* const out, // Output in [Height x Width x Input x Output] form
+ const int n_output_feature_maps,
+ const int n_input_feature_maps,
+ const int n_rows,
+ const int n_cols,
+ int in_output_feature_map_stride,
+ int in_input_feature_map_stride,
+ int in_row_stride,
+ int out_row_stride,
+ int out_col_stride,
+ int out_input_feature_map_stride
+)
+{
+ // Fill in stride values
+ in_row_stride = (in_row_stride)
+ ? in_row_stride
+ : n_cols;
+ in_input_feature_map_stride = (in_input_feature_map_stride)
+ ? in_input_feature_map_stride
+ : n_rows * in_row_stride;
+ in_output_feature_map_stride = (in_output_feature_map_stride)
+ ? in_output_feature_map_stride
+ : n_input_feature_maps * in_input_feature_map_stride;
+
+ out_input_feature_map_stride = (out_input_feature_map_stride)
+ ? out_input_feature_map_stride
+ : n_output_feature_maps;
+ out_col_stride = (out_col_stride)
+ ? out_col_stride
+ : n_input_feature_maps * out_input_feature_map_stride;
+ out_row_stride = (out_row_stride)
+ ? out_row_stride
+ : n_cols * out_col_stride;
+
+ // Perform the re-ordering
+ for (int i = 0; i < n_rows; i++)
+ {
+ const T* const in_row = in + i * in_row_stride;
+ T* out_row = out + i * out_row_stride;
+
+ for (int j = 0; j < n_cols; j++)
+ {
+ const T* const in_col = in_row + j;
+ T* const out_col = out_row + j * out_col_stride;
+
+ for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
+ {
+ const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
+ T* const out_ifm = out_col + ifm * out_input_feature_map_stride;
+
+ for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
+ {
+ const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride;
+ T* const out_ofm = out_ifm + ofm;
+ *(out_ofm) = *(in_ofm);
+ }
+ }
+ }
+ }
+}
+
+/*****************************************************************************/
+/* Generic weight re-order implementation.
+ */
+template <typename T>
+inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
+ const T* const in, // Input in [Height x Width x Input x Output] form
+ T* const out, // Output in [Output x Input x Height x Width] form
+ const int n_rows,
+ const int n_cols,
+ const int n_input_feature_maps,
+ const int n_output_feature_maps,
+ int in_row_stride,
+ int in_col_stride,
+ int in_input_feature_map_stride,
+ int out_output_feature_map_stride,
+ int out_input_feature_map_stride,
+ int out_row_stride
+)
+{
+ // Fill in the stride values
+ in_input_feature_map_stride = (in_input_feature_map_stride)
+ ? in_input_feature_map_stride
+ : n_output_feature_maps;
+ in_col_stride = (in_col_stride)
+ ? in_col_stride
+ : n_input_feature_maps * in_input_feature_map_stride;
+ in_row_stride = (in_row_stride)
+ ? in_row_stride
+ : n_cols * in_col_stride;
+
+ out_row_stride = (out_row_stride)
+ ? out_row_stride
+ : n_cols;
+ out_input_feature_map_stride = (out_input_feature_map_stride)
+ ? out_input_feature_map_stride
+ : n_rows * out_row_stride;
+ out_output_feature_map_stride = (out_output_feature_map_stride)
+ ? out_output_feature_map_stride
+ : n_input_feature_maps * out_input_feature_map_stride;
+
+ // Perform the re-ordering
+ for (int i = 0; i < n_rows; i++)
+ {
+ const T* const in_row = in + i * in_row_stride;
+ T* const out_row = out + i * out_row_stride;
+
+ for (int j = 0; j < n_cols; j++)
+ {
+ const T* const in_col = in_row + j * in_col_stride;
+ T* const out_col = out_row + j;
+
+ for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
+ {
+ const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
+ T* const out_ifm = out_col + ifm * out_input_feature_map_stride;
+
+ for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
+ {
+ const T* const in_ofm = in_ifm + ofm;
+ T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride;
+ *(out_ofm) = *(in_ofm);
+ }
+ }
+ }
+ }
+}
+
+} // namespace reorder
diff --git a/arm_compute/core/NEON/kernels/convolution/common/tensor.hpp b/arm_compute/core/NEON/kernels/convolution/common/tensor.hpp
new file mode 100644
index 0000000000..6567eeb23d
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/common/tensor.hpp
@@ -0,0 +1,177 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include <cstdlib>
+#include <random>
+
+#include "alloc.hpp"
+
+enum TensorOrder
+{
+ NHWC, ///< [Batch x Height x Width x Channels]
+ NCHW, ///< [Batch x Channels x Height x Width]
+};
+
+struct Tensor4DShape
+{
+ int n_batches, n_rows, n_cols, n_channels;
+ TensorOrder ordering;
+
+ // Create a new tensor with the default (NHWC) ordering
+ inline Tensor4DShape(
+ const int n_batches,
+ const int n_rows,
+ const int n_cols,
+ const int n_channels,
+ const TensorOrder ordering=NHWC
+ ) : n_batches(n_batches),
+ n_rows(n_rows),
+ n_cols(n_cols),
+ n_channels(n_channels),
+ ordering(ordering)
+ {
+ }
+
+ inline int size() const
+ {
+ return n_batches * n_rows * n_cols * n_channels;
+ }
+
+ inline bool TestEq(const Tensor4DShape& other) const
+ {
+ return (n_batches == other.n_batches &&
+ n_rows == other.n_rows &&
+ n_cols == other.n_cols &&
+ n_channels == other.n_channels);
+ }
+};
+
+
+enum WeightOrder
+{
+ HWIO, ///< [Height x Width x Input channels x Output channels]
+ OIHW, ///< [Output channels x Input channels x Height x Width]
+};
+
+struct KernelShape
+{
+ int n_output_channels, n_rows, n_cols, n_input_channels;
+ WeightOrder ordering;
+
+ inline KernelShape(
+ const int n_output_channels,
+ const int n_rows,
+ const int n_cols,
+ const int n_input_channels,
+ const WeightOrder ordering=HWIO
+ ) : n_output_channels(n_output_channels),
+ n_rows(n_rows),
+ n_cols(n_cols),
+ n_input_channels(n_input_channels),
+ ordering(ordering)
+ {
+ }
+
+ inline int size(void) const
+ {
+ return n_output_channels * n_rows * n_cols * n_input_channels;
+ }
+};
+
+
+template <typename ShapeT, typename T>
+class Tensor4D final
+{
+ public:
+ Tensor4D(ShapeT shape) :
+ shape(shape),
+ _data(reinterpret_cast<T*>(ALLOCATE(size_bytes())))
+ {
+ Clear();
+ }
+
+ Tensor4D(const Tensor4D<ShapeT, T>&) = delete;
+ Tensor4D operator=(const Tensor4D<ShapeT, T>&) = delete;
+
+ ~Tensor4D() {
+ free(_data);
+ }
+
+ inline T* ptr() const {
+ return _data;
+ }
+
+ inline size_t size_bytes() const {
+ return shape.size() * sizeof(T);
+ }
+
+ inline T& element(int, int, int, int) const;
+
+ inline void Clear() {
+ Fill(static_cast<T>(0));
+ }
+
+ inline void Fill(T val) {
+ for (int i = 0; i < shape.size(); i++)
+ _data[i] = val;
+ }
+
+ const ShapeT shape;
+
+ private:
+ T* const _data;
+};
+
+
+template <>
+inline float& Tensor4D<Tensor4DShape, float>::element(int n, int i, int j, int c) const
+{
+ int index;
+ if (shape.ordering == NHWC)
+ {
+ index = ((n*shape.n_rows + i)*shape.n_cols + j)*shape.n_channels + c;
+ }
+ else // NCHW
+ {
+ index = ((n*shape.n_channels + c)*shape.n_rows + i)*shape.n_cols + j;
+ }
+ return _data[index];
+}
+
+
+template <>
+inline float& Tensor4D<KernelShape, float>::element(int oc, int i, int j, int ic) const
+{
+ int index;
+ if (shape.ordering == HWIO)
+ {
+ index = ((i*shape.n_cols + j)*shape.n_input_channels + ic)*shape.n_output_channels + oc;
+ }
+ else // OIHW
+ {
+ index = ((oc*shape.n_input_channels + ic)*shape.n_rows + i)*shape.n_cols + j;
+ }
+ return _data[index];
+}
diff --git a/arm_compute/core/NEON/kernels/convolution/common/tensor_utils.hpp b/arm_compute/core/NEON/kernels/convolution/common/tensor_utils.hpp
new file mode 100644
index 0000000000..68a5c6a178
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/common/tensor_utils.hpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "tensor.hpp"
+
+// Methods to print tensors and weights
+void PrintTensor(const Tensor4D<Tensor4DShape, float>& tensor);
+void PrintWeights(const Tensor4D<KernelShape, float>& weights);
+
+// Test the equivalence of two tensors
+bool CmpTensors(const Tensor4D<Tensor4DShape, float>& a,
+ const Tensor4D<Tensor4DShape, float>& b,
+ const float max_delta=0.0f);
+
+// Fill the tensor with a test pattern
+void TestPattern(Tensor4D<Tensor4DShape, float>& tensor);
+void TestPattern(Tensor4D<KernelShape, float>& weights);
+
+// Fill the tensor with random values
+void Randomise(Tensor4D<Tensor4DShape, float>& tensor, const int seed=0);
+void Randomise(Tensor4D<KernelShape, float>& weights, const int seed=0);
diff --git a/arm_compute/core/NEON/kernels/convolution/common/utils.hpp b/arm_compute/core/NEON/kernels/convolution/common/utils.hpp
new file mode 100644
index 0000000000..d8b9c3b7d3
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/common/utils.hpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+double TimeInUs(void);
+void PrintMatrix(const float* const m, const int M, const int N, const int row_stride);
+
+inline int iceildiv(const int a, const int b) {
+ return (a + b - 1) / b;
+}
+
+template <typename T>
+inline T roundup(const T a, const T b) {
+ return a + b - (a % b);
+}
diff --git a/arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp b/arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp
new file mode 100644
index 0000000000..80b0614015
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp
@@ -0,0 +1,209 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+namespace depthwise
+{
+
+class IDepthwiseConvolution
+{
+public:
+ virtual ~IDepthwiseConvolution() = default;
+ virtual int output_size(const int dim_size, const bool padding_same) const = 0;
+ virtual unsigned int get_window(void) const = 0;
+ virtual void run(const unsigned int start, const unsigned int stop) = 0;
+};
+
+template <
+ int OutputTileRows,
+ int OutputTileCols,
+ int KernelRows,
+ int KernelCols,
+ int StrideRows,
+ int StrideCols,
+ typename TIn,
+ typename TOut
+>
+class DepthwiseConvolution : public IDepthwiseConvolution
+{
+ public:
+ typedef TIn InputType;
+ typedef TOut OutputType;
+
+ // Information about the specific convolution instance
+ static constexpr int output_tile_rows = OutputTileRows;
+ static constexpr int output_tile_cols = OutputTileCols;
+ static constexpr int kernel_rows = KernelRows;
+ static constexpr int kernel_cols = KernelCols;
+ static constexpr int stride_rows = StrideRows;
+ static constexpr int stride_cols = StrideCols;
+ static constexpr int inner_tile_rows = stride_rows * output_tile_rows + kernel_rows - 1;
+ static constexpr int inner_tile_cols = stride_cols * output_tile_cols + kernel_cols - 1;
+
+ /** Create a new depthwise convolution engine.
+ *
+ * @param[in] n_batches Number of batches tensors.
+ * @param[in] n_input_rows Number of rows in input tensor.
+ * @param[in] n_input_cols Number of columns in input tensor.
+ * @param[in] n_channels Number of channels in input and output tensors.
+ * @param[in] padding_same True if padding is SAME, else VALID.
+ * @param[in] weights Pointer to Height x Width x Channel ordered weights.
+ * @param[in] input Pointer to NHWC ordered input tensor.
+ * @param[output] output Pointer to NHWC ordered output tensor.
+ */
+ DepthwiseConvolution(
+ const int n_batches, const int n_input_rows, const int n_input_cols,
+ const int n_channels, const bool padding_same,
+ const TIn* const weights,
+ const TIn* const input,
+ TOut* const output
+ );
+
+ // Cannot copy or move a DepthwiseConvolution.
+ DepthwiseConvolution(DepthwiseConvolution&) = delete;
+ DepthwiseConvolution operator=(DepthwiseConvolution&) = delete;
+
+ /** Get the number of output rows/columns.
+ *
+ * @param[in] dim_size Number of elements in the dimension (rows/columns)
+ * @param[in] same_padding True if the padding is SAME, otherwise false.
+ */
+ static int get_output_size(const int dim_size, const bool padding_same);
+
+ /** Get the number of output rows/columns.
+ *
+ * @param[in] dim_size Number of elements in the dimension (rows/columns)
+ * @param[in] same_padding True if the padding is SAME, otherwise false.
+ */
+ int output_size(const int dim_size, const bool padding_same) const override
+ {
+ return DepthwiseConvolution<OutputTileRows,
+ OutputTileCols,
+ KernelRows,
+ KernelCols,
+ StrideRows,
+ StrideCols,
+ TIn,
+ TOut>::get_output_size(dim_size, padding_same);
+ }
+
+ /** Get the window of work to be performed by an instance of the operator.
+ */
+ unsigned int get_window(void) const override;
+
+ /** Perform a portion of the work associated with the operator.
+ *
+ * Will perform the window of work described by $[start, stop)$.
+ *
+ * @param[in] start Start of the window of work to perform.
+ * @param[in] stop End of the work to perform.
+ */
+ void run(const unsigned int start, const unsigned int stop) override;
+
+ protected:
+ /** Process a tile-row of the tensors.
+ */
+ static void process_tile_row(
+ const int n_channels,
+ const TIn* const weights,
+ const TIn* const inptr,
+ const int in_row_stride,
+ const int in_col_stride,
+ TOut* const outptr,
+ const int out_row_stride,
+ const int out_col_stride,
+ const int row_pad_in_top,
+ const int row_pad_in_left,
+ const int row_pad_in_bottom,
+ const int row_pad_out_bottom,
+ const int n_tiles,
+ const int n_input_cols,
+ const int n_output_cols
+ );
+
+ /** Process a single tile of the tensors.
+ *
+ * @param[in] n_channels Number of channels.
+ * @param[in] weights Pointer to Height x Width x Channels ordered weights.
+ * @param[in] inptr Pointer to the top-left unpadded value of the tile.
+ * @param[in] in_row_stride Stride between rows of the input tensor.
+ * @param[in] in_col_stride Stride between columns of the input tensor.
+ * @param[out] outptr Pointer to the top-left output value for the tile.
+ * @param[in] out_row_stride Stride between rows of the output tensor.
+ * @param[in] out_col_stride Stride between columns of the output tensor.
+ */
+ template <
+ int in_pad_top, int in_pad_left, int in_pad_bottom, int in_pad_right,
+ int out_pad_bottom, int out_pad_right
+ >
+ static void process_tile(
+ const int n_channels,
+ const TIn* const weights,
+ const TIn* const inptr,
+ const int in_row_stride,
+ const int in_col_stride,
+ TOut* const outptr,
+ const int out_row_stride,
+ const int out_col_stride
+ );
+
+ // Type of a pointer to a `process_tile` instance
+ typedef void (*TileFn)(
+ const int,
+ const TIn* const,
+ const TIn* const, const int, const int,
+ TOut* const, const int, const int
+ );
+
+ // Determine the maximum padding values which can be applied to tiles of
+ // the tensors involved in this class of convolution.
+ static constexpr int max_in_pad_top = 2;
+ static constexpr int max_in_pad_left = 2;
+ static constexpr int max_in_pad_bottom = inner_tile_rows - 1;
+ static constexpr int max_in_pad_right = inner_tile_cols - 1;
+ static constexpr int max_out_pad_bottom = output_tile_rows;
+ static constexpr int max_out_pad_right = output_tile_cols;
+
+ /** Array of methods to process tensor tiles.
+ *
+ * Allows dynamic dispatch to specialized implementations based on
+ * different padding configurations.
+ */
+ static const TileFn tile_fns[
+ max_in_pad_top][max_in_pad_left][max_in_pad_bottom][max_in_pad_right][
+ max_out_pad_bottom][max_out_pad_right
+ ];
+
+ private:
+ // Member variables of instances of a convolution engine.
+ const TIn* const _weights;
+ const TIn* const _input;
+ TOut* const _output;
+ const int _n_batches, _n_input_rows, _n_input_cols, _n_channels,
+ _n_output_rows, _n_output_cols, _n_tile_rows, _n_tile_cols;
+ const bool _padding_same;
+};
+
+} // namespace depthwise
diff --git a/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp
new file mode 100644
index 0000000000..f9671fc426
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp
@@ -0,0 +1,348 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/*
+ * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ *
+ * NOTE: Header to be included by implementation files only.
+ *
+ * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ */
+
+#include <algorithm>
+#include "arm_compute/core/NEON/kernels/convolution/depthwise/depthwise.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
+
+#pragma once
+
+namespace depthwise
+{
+
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+int DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::get_output_size(
+ const int dim_size, const bool same_padding
+)
+{
+ return iceildiv(dim_size - (same_padding ? 0 : (KC - 1)), SR);
+}
+
+
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::DepthwiseConvolution(
+ const int n_batches, const int n_input_rows, const int n_input_cols,
+ const int n_channels, const bool padding_same,
+ const TIn* const weights,
+ const TIn* const input,
+ TOut* const output
+) : _weights(weights), _input(input), _output(output),
+ _n_batches(n_batches),
+ _n_input_rows(n_input_rows),
+ _n_input_cols(n_input_cols),
+ _n_channels(n_channels),
+ _n_output_rows(get_output_size(n_input_rows, padding_same)),
+ _n_output_cols(get_output_size(n_input_cols, padding_same)),
+ _n_tile_rows(iceildiv(_n_output_rows, output_tile_rows)),
+ _n_tile_cols(iceildiv(_n_output_cols, output_tile_cols)),
+ _padding_same(padding_same)
+{
+}
+
+
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+unsigned int DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::get_window() const
+{
+ // TODO Later support parallelisation over tile rows.
+ return 1; // _n_tile_rows;
+}
+
+
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+void DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::run(
+ const unsigned int start,
+ const unsigned int stop
+)
+{
+ // TODO Later support parallelisation over tile rows.
+ (void) start;
+ (void) stop;
+
+ // Compute input striding
+ const int input_col_stride = _n_channels;
+ const int input_row_stride = _n_input_cols * input_col_stride;
+ const int input_batch_stride = _n_input_rows * input_row_stride;
+
+ // Compute output striding
+ const int output_col_stride = _n_channels;
+ const int output_row_stride = _n_output_cols * output_col_stride;
+ const int output_batch_stride = _n_output_rows * output_row_stride;
+
+ // Compute top and bottom padding for input and output
+ const int input_pad_top = _padding_same ?
+ ((_n_output_rows - 1)*stride_rows + kernel_rows - _n_input_rows) / 2 : 0;
+ const int input_pad_left = _padding_same ?
+ ((_n_output_cols - 1)*stride_cols + kernel_cols - _n_input_cols) / 2 : 0;
+ constexpr int tile_overlap = kernel_rows - 1;
+
+ // Perform the convolution by calling `process_tile_row` for each tile row in
+ // each batch.
+ for (int batch = 0; batch < _n_batches; batch++)
+ {
+ const TIn* const inptr_batch = _input + batch*input_batch_stride;
+ TOut* const outptr_batch = _output + batch*output_batch_stride;
+
+ // Loop over rows of tiles
+ for (int tile_i = 0; tile_i < _n_tile_rows; tile_i++)
+ {
+ // Pointer to the row
+ const int input_row_offset = (tile_i == 0) ? 0 : input_pad_top;
+ const TIn* const inptr_row = (inptr_batch + ((inner_tile_rows - tile_overlap)*tile_i - input_row_offset)*input_row_stride);
+ TOut* const outptr_row = outptr_batch + output_tile_rows * tile_i * output_row_stride;
+
+ // Input padding (top + bottom) for the row
+ const int input_row_top = tile_i*(inner_tile_rows - tile_overlap) - input_pad_top;
+ const int input_row_bottom = input_row_top + inner_tile_rows;
+ const int input_row_pad_top = (tile_i == 0) ? input_pad_top : 0;
+ const int input_row_pad_bottom = std::max(0, input_row_bottom - _n_input_rows);
+
+ // Output padding (bottom) for the row
+ const int output_row_bottom = (tile_i + 1)*output_tile_rows;
+ const int output_row_pad_bottom = std::max(0, output_row_bottom - _n_output_rows);
+
+ // Process the row
+ process_tile_row(
+ _n_channels, _weights,
+ inptr_row, input_row_stride, input_col_stride,
+ outptr_row, output_row_stride, output_col_stride,
+ input_row_pad_top, input_pad_left, input_row_pad_bottom,
+ output_row_pad_bottom,
+ _n_tile_cols, _n_input_cols, _n_output_cols
+ );
+ }
+ }
+}
+
+
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+void DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::process_tile_row(
+ const int n_channels,
+ const TIn* const weights,
+ const TIn* const inptr,
+ const int in_row_stride,
+ const int in_col_stride,
+ TOut* const outptr,
+ const int out_row_stride,
+ const int out_col_stride,
+ const int row_pad_in_top,
+ const int row_pad_in_left,
+ const int row_pad_in_bottom,
+ const int row_pad_out_bottom,
+ const int n_tiles,
+ const int n_input_cols,
+ const int n_output_cols
+)
+{
+ constexpr int tile_overlap = kernel_cols - 1;
+
+ // Loop over columns of tiles
+ for (int tile_j = 0; tile_j < n_tiles; tile_j++)
+ {
+ // Input padding (left + right) for the tile
+ const int t_pad_in_left = (tile_j == 0) ? row_pad_in_left : 0;
+ const int t_in_start = tile_j*(inner_tile_cols - tile_overlap) - row_pad_in_left;
+ const int t_in_end = t_in_start + inner_tile_cols;
+ const int t_pad_in_right = std::max(0, t_in_end - n_input_cols);
+
+ // Output padding (right) for the tile
+ const int t_out_end = (tile_j + 1) * output_tile_cols;
+ const int t_pad_out_right = std::max(0, t_out_end - n_output_cols);
+
+ // Get pointers into the inputs and outputs
+ const int col_offset = (tile_j == 0) ? 0 : row_pad_in_left;
+ const TIn* const inptr_col = (inptr + ((inner_tile_cols - tile_overlap)*tile_j - col_offset)*in_col_stride);
+ TOut* const outptr_col = outptr + tile_j * output_tile_cols * out_col_stride;
+
+ // Apply the specific tile processing function
+ tile_fns[row_pad_in_top][t_pad_in_left][row_pad_in_bottom][t_pad_in_right][row_pad_out_bottom][t_pad_out_right](
+ n_channels, weights,
+ inptr_col, in_row_stride, in_col_stride,
+ outptr_col, out_row_stride, out_col_stride
+ );
+ }
+}
+
+
+template <int OTR, int OTC, int KR, int KC, int SR, int SC, typename TIn, typename TOut>
+template <
+ int in_pad_top, int in_pad_left, int in_pad_bottom, int in_pad_right,
+ int out_pad_bottom, int out_pad_right
+>
+void DepthwiseConvolution<OTR, OTC, KR, KC, SR, SC, TIn, TOut>::process_tile(
+ const int n_channels,
+ const TIn* const weights,
+ const TIn* const inptr,
+ const int in_row_stride,
+ const int in_col_stride,
+ TOut* const outptr,
+ const int out_row_stride,
+ const int out_col_stride
+)
+{
+ // Compute valid ranges of the tile
+ constexpr int in_cells_i = inner_tile_rows - in_pad_bottom;
+ constexpr int in_cells_j = inner_tile_cols - in_pad_right;
+ constexpr int out_cells_i = output_tile_rows - out_pad_bottom;
+ constexpr int out_cells_j = output_tile_cols - out_pad_right;
+
+ // Instantiate pointers
+ const TIn* inptr_base = inptr;
+ const TIn* wptr_base = weights;
+ TOut* outptr_base = outptr;
+
+ const int weight_col_stride = n_channels;
+ const int weight_row_stride = kernel_cols * n_channels;
+
+ // Perform the depthwise convolution
+ int channels_remaining = n_channels;
+ for (; channels_remaining; channels_remaining--)
+ {
+ // Load input tile
+ TIn u[inner_tile_rows][inner_tile_cols];
+ for (int i = 0; i < inner_tile_rows; i++)
+ {
+ const TIn* const inptr_row = inptr_base + (i - in_pad_top)*in_row_stride;
+ for (int j = 0; j < inner_tile_cols; j++)
+ {
+ if (i < in_pad_top || in_cells_i <= i ||
+ j < in_pad_left || in_cells_j <= j)
+ {
+ u[i][j] = static_cast<TIn>(0);
+ }
+ else
+ {
+ u[i][j] = *(inptr_row + (j - in_pad_left)*in_col_stride);
+ }
+ }
+ }
+ inptr_base++;
+
+ // Load weights tile
+ TIn w[kernel_rows][kernel_cols];
+ for (int i = 0; i < kernel_rows; i++)
+ {
+ const TIn* const wptr_row = wptr_base + i*weight_row_stride;
+ for (int j = 0; j < kernel_cols; j++)
+ {
+ w[i][j] = *(wptr_row + j*weight_col_stride);
+ }
+ }
+ wptr_base++;
+
+ // Perform the convolution
+ TOut v[out_cells_i][out_cells_j];
+ for (int out_i = 0; out_i < out_cells_i; out_i++)
+ {
+ for (int out_j = 0; out_j < out_cells_j; out_j++)
+ {
+ // Clear the accumulator
+ v[out_i][out_j] = static_cast<TOut>(0);
+
+ // Base co-ordinate
+ const int base_i = out_i * stride_rows;
+ const int base_j = out_j * stride_cols;
+
+ // Fill the accumulator
+ for (int in_i = 0; in_i < kernel_rows; in_i++)
+ {
+ const int i = base_i + in_i;
+ for (int in_j = 0; in_j < kernel_cols; in_j++)
+ {
+ const int j = base_j + in_j;
+ v[out_i][out_j] += w[in_i][in_j] * u[i][j];
+ }
+ }
+ }
+ }
+
+ // Store the output tile
+ for (int i = 0; i < out_cells_i; i++)
+ {
+ TOut* const outptr_row = outptr_base + i*out_row_stride;
+ for (int j = 0; j < out_cells_j; j++)
+ {
+ *(outptr_row + j*out_col_stride) = v[i][j];
+ }
+ }
+ outptr_base++;
+ }
+}
+
+
+// New templated struct used solely as a way to provide tile processing
+// specialisations.
+template <int OutputTileRows, int OutputTileCols,
+ int KernelRows, int KernelCols,
+ int StrideRows, int StrideCols,
+ typename TIn, typename TOut>
+struct DepthwiseConvolutionImpl : public DepthwiseConvolution<
+ OutputTileRows, OutputTileCols,
+ KernelRows, KernelCols,
+ StrideRows, StrideCols, TIn, TOut
+>
+{
+ template <
+ int in_pad_top, int in_pad_left, int in_pad_bottom, int in_pad_right,
+ int out_pad_bottom, int out_pad_right
+ >
+ static void process_tile(
+ const int n_channels,
+ const TIn* const weights,
+ const TIn* const inptr,
+ const int in_row_stride,
+ const int in_col_stride,
+ TOut* const outptr,
+ const int out_row_stride,
+ const int out_col_stride
+ )
+ {
+ // By default, redirect to parent. Specialised implementations can be added
+ // by overriding this method.
+ DepthwiseConvolution<OutputTileRows, OutputTileCols,
+ KernelRows, KernelCols,
+ StrideRows, StrideCols,
+ TIn, TOut>::
+ template process_tile<in_pad_top, in_pad_left, in_pad_bottom, in_pad_right,
+ out_pad_bottom, out_pad_right>(
+ n_channels,
+ weights,
+ inptr,
+ in_row_stride,
+ in_col_stride,
+ outptr,
+ out_row_stride,
+ out_col_stride
+ );
+ }
+};
+
+} // namespace depthwise
diff --git a/arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp
new file mode 100644
index 0000000000..e7f0609b0c
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/depthwise/impl_fp32_fp32.hpp
@@ -0,0 +1,263 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/*
+ * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ *
+ * NOTE: Header to be included by implementation files only.
+ *
+ * !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ */
+
+#include "arm_compute/core/NEON/kernels/convolution/common/arm.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/depthwise/impl_base.hpp"
+
+#pragma once
+
+namespace depthwise
+{
+// Partial specialisation for FP32 to FP32
+template <int OutputTileRows, int OutputTileCols,
+ int KernelRows, int KernelCols,
+ int StrideRows, int StrideCols>
+struct DepthwiseConvolutionImpl<OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows, StrideCols, float, float>
+{
+ typedef DepthwiseConvolution<
+ OutputTileRows, OutputTileCols,
+ KernelRows, KernelCols,
+ StrideRows, StrideCols,
+ float, float
+ > DWC;
+
+ template <
+ int in_pad_top, int in_pad_left, int in_pad_bottom, int in_pad_right,
+ int out_pad_bottom, int out_pad_right
+ >
+ static void process_tile(
+ const int n_channels,
+ const float* const weights,
+ const float* const inptr,
+ const int in_row_stride,
+ const int in_col_stride,
+ float* const outptr,
+ const int out_row_stride,
+ const int out_col_stride
+ );
+};
+
+
+template <int OTR, int OTC, int KR, int KC, int SR, int SC>
+template <
+ int in_pad_top, int in_pad_left, int in_pad_bottom, int in_pad_right,
+ int out_pad_bottom, int out_pad_right
+>
+void DepthwiseConvolutionImpl<OTR, OTC, KR, KC, SR, SC, float, float>::process_tile(
+ const int n_channels,
+ const float* const weights,
+ const float* const inptr,
+ const int in_row_stride,
+ const int in_col_stride,
+ float* const outptr,
+ const int out_row_stride,
+ const int out_col_stride
+)
+{
+ constexpr auto inner_tile_rows = DWC::inner_tile_rows;
+ constexpr auto inner_tile_cols = DWC::inner_tile_cols;
+ constexpr auto kernel_rows = DWC::kernel_rows;
+ constexpr auto kernel_cols = DWC::kernel_cols;
+ constexpr auto output_tile_rows = DWC::output_tile_rows;
+ constexpr auto output_tile_cols = DWC::output_tile_cols;
+ constexpr auto stride_rows = DWC::stride_rows;
+ constexpr auto stride_cols = DWC::stride_cols;
+
+ // Compute valid ranges of the tile
+ constexpr int in_cells_i = inner_tile_rows - in_pad_bottom;
+ constexpr int in_cells_j = inner_tile_cols - in_pad_right;
+ constexpr int out_cells_i = output_tile_rows - out_pad_bottom;
+ constexpr int out_cells_j = output_tile_cols - out_pad_right;
+
+ // Instantiate pointers
+ const float* inptr_base = inptr;
+ const float* wptr_base = weights;
+ float* outptr_base = outptr;
+
+ const int weight_col_stride = n_channels;
+ const int weight_row_stride = kernel_cols * n_channels;
+
+ // Perform the depthwise convolution
+ int channels_remaining = n_channels;
+#ifdef __aarch64__
+ for (; channels_remaining >= 4; channels_remaining -= 4)
+ {
+ // Load input tile
+ float32x4_t u[inner_tile_rows][inner_tile_cols];
+ for (int i = 0; i < inner_tile_rows; i++)
+ {
+ const float* const inptr_row = inptr_base + (i - in_pad_top)*in_row_stride;
+ for (int j = 0; j < inner_tile_cols; j++)
+ {
+ if (i < in_pad_top || in_cells_i <= i ||
+ j < in_pad_left || in_cells_j <= j)
+ {
+ u[i][j] = vdupq_n_f32(0.0f);
+ }
+ else
+ {
+ u[i][j] = vld1q_f32(inptr_row + (j - in_pad_left)*in_col_stride);
+ }
+ }
+ }
+ inptr_base += 4;
+
+ // Load weights tile
+ float32x4_t w[kernel_rows][kernel_cols];
+ for (int i = 0; i < kernel_rows; i++)
+ {
+ const float* const wptr_row = wptr_base + i*weight_row_stride;
+ for (int j = 0; j < kernel_cols; j++)
+ {
+ w[i][j] = vld1q_f32(wptr_row + j*weight_col_stride);
+ }
+ }
+ wptr_base += 4;
+
+ // Perform the convolution
+ float32x4_t v[out_cells_i][out_cells_j];
+ for (int out_i = 0; out_i < out_cells_i; out_i++)
+ {
+ for (int out_j = 0; out_j < out_cells_j; out_j++)
+ {
+ // Base co-ordinate
+ const int base_i = out_i * stride_rows;
+ const int base_j = out_j * stride_cols;
+
+ // Fill the accumulator
+ for (int in_i = 0; in_i < kernel_rows; in_i++)
+ {
+ const int i = base_i + in_i;
+ for (int in_j = 0; in_j < kernel_cols; in_j++)
+ {
+ const int j = base_j + in_j;
+ if (in_i == 0 && in_j == 0)
+ {
+ // v[out_i][out_j] = w[in_i][in_j] * u[i][j];
+ v[out_i][out_j] = vmulq_f32(w[in_i][in_j], u[i][j]);
+ }
+ else
+ {
+ // v[out_i][out_j] += w[in_i][in_j] * u[i][j];
+ v[out_i][out_j] = vmlaq_f32(v[out_i][out_j], w[in_i][in_j], u[i][j]);
+ }
+ }
+ }
+ }
+ }
+
+ // Store the output tile
+ for (int i = 0; i < out_cells_i; i++)
+ {
+ float* const outptr_row = outptr_base + i*out_row_stride;
+ for (int j = 0; j < out_cells_j; j++)
+ {
+ vst1q_f32(outptr_row + j*out_col_stride, v[i][j]);
+ }
+ }
+ outptr_base += 4;
+ }
+#endif // __aarch64__
+ for (; channels_remaining; channels_remaining--)
+ {
+ // Load input tile
+ float u[inner_tile_rows][inner_tile_cols];
+ for (int i = 0; i < inner_tile_rows; i++)
+ {
+ const float* const inptr_row = inptr_base + (i - in_pad_top)*in_row_stride;
+ for (int j = 0; j < inner_tile_cols; j++)
+ {
+ if (i < in_pad_top || in_cells_i <= i ||
+ j < in_pad_left || in_cells_j <= j)
+ {
+ u[i][j] = static_cast<float>(0);
+ }
+ else
+ {
+ u[i][j] = *(inptr_row + (j - in_pad_left)*in_col_stride);
+ }
+ }
+ }
+ inptr_base++;
+
+ // Load weights tile
+ float w[kernel_rows][kernel_cols];
+ for (int i = 0; i < kernel_rows; i++)
+ {
+ const float* const wptr_row = wptr_base + i*weight_row_stride;
+ for (int j = 0; j < kernel_cols; j++)
+ {
+ w[i][j] = *(wptr_row + j*weight_col_stride);
+ }
+ }
+ wptr_base++;
+
+ // Perform the convolution
+ float v[out_cells_i][out_cells_j];
+ for (int out_i = 0; out_i < out_cells_i; out_i++)
+ {
+ for (int out_j = 0; out_j < out_cells_j; out_j++)
+ {
+ // Clear the accumulator
+ v[out_i][out_j] = static_cast<float>(0);
+
+ // Base co-ordinate
+ const int base_i = out_i * stride_rows;
+ const int base_j = out_j * stride_cols;
+
+ // Fill the accumulator
+ for (int in_i = 0; in_i < kernel_rows; in_i++)
+ {
+ const int i = base_i + in_i;
+ for (int in_j = 0; in_j < kernel_cols; in_j++)
+ {
+ const int j = base_j + in_j;
+ v[out_i][out_j] += w[in_i][in_j] * u[i][j];
+ }
+ }
+ }
+ }
+
+ // Store the output tile
+ for (int i = 0; i < out_cells_i; i++)
+ {
+ float* const outptr_row = outptr_base + i*out_row_stride;
+ for (int j = 0; j < out_cells_j; j++)
+ {
+ *(outptr_row + j*out_col_stride) = v[i][j];
+ }
+ }
+ outptr_base++;
+ }
+}
+
+} // namespace depthwise
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp
new file mode 100644
index 0000000000..663b3c414f
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/batched_blocked_gemm.hpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+namespace winograd
+{
+
+template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
+class BatchedBlockedGemm
+{
+ public:
+ /** Create a new batched blocked GEMM operator. */
+ BatchedBlockedGemm(
+ const unsigned int n_gemms,
+ const int M, const int K, const int N,
+ const int a_matrix_stride,
+ const int a_row_stride,
+ const int b_matrix_stride,
+ const int b_row_stride,
+ const int c_matrix_stride,
+ const int c_row_stride,
+ const TIn* const a_ptr,
+ const TIn* const b_ptr,
+ TOut* const c_ptr
+ );
+
+ BatchedBlockedGemm(const BatchedBlockedGemm&) = delete;
+ BatchedBlockedGemm operator=(const BatchedBlockedGemm&) = delete;
+
+ /** Get a window of work performed by the operator. */
+ unsigned int get_window() const;
+
+ /** Perform a portion of the work of the operator. */
+ void run(const unsigned int start, const unsigned int stop);
+
+ private:
+ const unsigned int n_gemms;
+ const int M, N, K;
+ const int a_matrix_stride, a_row_stride;
+ const int b_matrix_stride, b_row_stride;
+ const int c_matrix_stride, c_row_stride;
+ const TIn* const a_ptr;
+ const TIn* const b_ptr;
+ TOut* const c_ptr;
+};
+
+} // namespace winograd
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp
new file mode 100644
index 0000000000..62a20c9eea
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/gemm.hpp
@@ -0,0 +1,127 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
+
+template <typename TIn, typename TOut>
+inline void Gemm(const TIn* const a, const TIn* const b, TOut *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride,
+ const bool a_transposed=false,
+ const bool b_transposed=false) {
+ // Array access methods
+ const auto A = [a, a_transposed, M, K, a_row_stride] (const int i, const int j) -> TIn {
+ return a[(!a_transposed) ? i*a_row_stride + j : i + j*M];
+ };
+
+ const auto B = [b, b_transposed, K, N, b_row_stride] (const int i, const int j) -> TIn {
+ return b[(!b_transposed) ? i*b_row_stride + j : i + j*N];
+ };
+
+ const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
+ return c[i*c_row_stride + j];
+ };
+
+ // Perform the matrix multiplication
+ for (int i = 0; i < M; i++) {
+ for (int j = 0; j < N; j++) {
+ for (int k = 0; k < K; k++) {
+ C(i, j) += A(i, k) * B(k, j);
+ }
+ }
+ }
+}
+
+template <const int M_BLOCK, const int N_BLOCK, typename TIn, typename TOut>
+inline void BlockedGemm(
+ const TIn* const a, const TIn* const b, TOut *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ // Array access methods
+ const auto A = [a, M, K, a_row_stride] (const int i, const int j) -> TIn {
+ return a[i*a_row_stride + j];
+ };
+
+ const auto B = [b, K, N, b_row_stride] (const int i, const int j) -> TIn {
+ return b[i*b_row_stride + j];
+ };
+
+ const auto C = [c, c_row_stride] (const int i, const int j) -> TOut& {
+ return c[i*c_row_stride + j];
+ };
+
+ const int M_BLOCKS = iceildiv(M, M_BLOCK);
+ const int N_BLOCKS = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < M_BLOCKS; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < N_BLOCKS; nblock++) {
+ // Create an appropriately sized block of accumulators
+ TOut accum[M_BLOCK][N_BLOCK];
+ for (int i = 0; i < M_BLOCK; i++) {
+ for (int j = 0; j < N_BLOCK; j++) {
+ accum[i][j] = static_cast<TOut>(0);
+ }
+ }
+
+ // Perform this portion of the matrix multiply
+ for (int k = 0; k < K; k++) {
+ // Load elements of A
+ TIn elems_a[M_BLOCK];
+ for (int i = 0; i < M_BLOCK; i++) {
+ elems_a[i] = A(mblock*M_BLOCK + i, k);
+ }
+
+ // Load elements of B
+ TIn elems_b[N_BLOCK];
+ for (int j = 0; j < N_BLOCK; j++) {
+ elems_b[j] = B(k, nblock*N_BLOCK + j);
+ }
+
+ // Perform the partial matrix multiply
+ for (int i = 0; i < M_BLOCK; i++) {
+ for (int j = 0; j < N_BLOCK; j++) {
+ accum[i][j] += elems_a[i] * elems_b[j];
+ }
+ }
+ }
+
+ // Store the partial product
+ for (int i = 0; i < M_BLOCK; i++) {
+ for (int j = 0; j < N_BLOCK; j++) {
+ C(mblock*M_BLOCK + i, nblock*N_BLOCK + j) = accum[i][j];
+ }
+ }
+ }
+ }
+}
+
+#include "gemm/a64_sgemm.hpp"
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp
new file mode 100644
index 0000000000..8073cb1896
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm.hpp
@@ -0,0 +1,355 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include <cassert>
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
+
+#ifdef __aarch64__
+
+template <>
+inline void BlockedGemm<8, 12, float, float>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int M_BLOCK = 8;
+ const int N_BLOCK = 12;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = K;
+
+ asm volatile (
+ // Create an 8x12 block of accumulators
+ " A_1 .req v27\n"
+ "sA_1 .req s27\n"
+ " A_2 .req v28\n"
+ "sA_2 .req s28\n"
+ " A_3 .req v29\n"
+ "sA_3 .req s29\n"
+ " A_4 .req v30\n"
+ "sA_4 .req s30\n"
+
+ " B_1 .req v24\n" " B_2 .req v25\n" " B_3 .req v26\n"
+ "qB_1 .req q24\n" "qB_2 .req q25\n" "qB_3 .req q26\n"
+
+ " C_11 .req v0\n" " C_12 .req v1\n" " C_13 .req v2\n"
+ " C_21 .req v3\n" " C_22 .req v4\n" " C_23 .req v5\n"
+ " C_31 .req v6\n" " C_32 .req v7\n" " C_33 .req v8\n"
+ " C_41 .req v9\n" " C_42 .req v10\n" " C_43 .req v11\n"
+ " C_51 .req v12\n" " C_52 .req v13\n" " C_53 .req v14\n"
+ " C_61 .req v15\n" " C_62 .req v16\n" " C_63 .req v17\n"
+ " C_71 .req v18\n" " C_72 .req v19\n" " C_73 .req v20\n"
+ " C_81 .req v21\n" " C_82 .req v22\n" " C_83 .req v23\n"
+
+ "qC_11 .req q0\n" "qC_12 .req q1\n" "qC_13 .req q2\n"
+ "qC_21 .req q3\n" "qC_22 .req q4\n" "qC_23 .req q5\n"
+ "qC_31 .req q6\n" "qC_32 .req q7\n" "qC_33 .req q8\n"
+ "qC_41 .req q9\n" "qC_42 .req q10\n" "qC_43 .req q11\n"
+ "qC_51 .req q12\n" "qC_52 .req q13\n" "qC_53 .req q14\n"
+ "qC_61 .req q15\n" "qC_62 .req q16\n" "qC_63 .req q17\n"
+ "qC_71 .req q18\n" "qC_72 .req q19\n" "qC_73 .req q20\n"
+ "qC_81 .req q21\n" "qC_82 .req q22\n" "qC_83 .req q23\n"
+
+ "aptr1 .req x17\n"
+ "aptr2 .req x18\n"
+ "aptr3 .req x19\n"
+ "aptr4 .req x20\n"
+ "aptr5 .req x21\n"
+ "aptr6 .req x22\n"
+ "aptr7 .req x23\n"
+
+ // Initialise accumulators with 0
+ // Initialise pointers
+ "movi C_11.4s, #0\n"
+ "add aptr1, %x[aptr], %x[a_row_stride]\n"
+ "movi C_12.4s, #0\n"
+ "add aptr2, aptr1, %x[a_row_stride]\n"
+ "movi C_13.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride]\n"
+ "movi C_21.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride]\n"
+ "movi C_22.4s, #0\n"
+ "add aptr5, aptr4, %x[a_row_stride]\n"
+ "movi C_23.4s, #0\n"
+ "add aptr6, aptr5, %x[a_row_stride]\n"
+ "movi C_31.4s, #0\n"
+ "add aptr7, aptr6, %x[a_row_stride]\n"
+ "movi C_32.4s, #0\n"
+ "ldr qB_1, [%x[bptr]]\n"
+ "movi C_33.4s, #0\n"
+ "ldr qB_2, [%x[bptr], #0x10]\n"
+ "movi C_41.4s, #0\n"
+ "prfm pldl1keep, [%x[bptr], #0x00]\n"
+ "movi C_42.4s, #0\n"
+ "prfm pldl1keep, [%x[bptr], #0x10]\n"
+ "movi C_43.4s, #0\n"
+ "prfm pldl1keep, [%x[bptr], #0x20]\n"
+ "movi C_51.4s, #0\n"
+ "prfm pldl1keep, [%x[aptr], #0x00]\n"
+ "movi C_52.4s, #0\n"
+ "prfm pldl1keep, [ aptr1, #0x00]\n"
+ "movi C_53.4s, #0\n"
+ "prfm pldl1keep, [ aptr2, #0x00]\n"
+ "movi C_61.4s, #0\n"
+ "prfm pldl1keep, [ aptr3, #0x00]\n"
+ "movi C_62.4s, #0\n"
+ "prfm pldl1keep, [ aptr4, #0x00]\n"
+ "movi C_63.4s, #0\n"
+ "prfm pldl1keep, [ aptr5, #0x00]\n"
+ "movi C_71.4s, #0\n"
+ "prfm pldl1keep, [ aptr6, #0x00]\n"
+ "movi C_72.4s, #0\n"
+ "prfm pldl1keep, [ aptr7, #0x00]\n"
+ "movi C_73.4s, #0\n"
+ "ldr sA_1, [%x[aptr]], #0x4\n"
+ "movi C_81.4s, #0\n"
+ "ldr sA_2, [ aptr1], #0x4\n"
+ "movi C_82.4s, #0\n"
+ "ldr sA_3, [ aptr2], #0x4\n"
+ "movi C_83.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 2f\n"
+
+ "1:"
+ "fmla C_11.4s, B_1.4s, A_1.s[0]\n"
+ "ldr qB_3, [%x[bptr], #0x20]\n"
+ "fmla C_12.4s, B_2.4s, A_1.s[0]\n"
+ "ldr sA_4, [ aptr3], #0x4\n"
+ "fmla C_13.4s, B_3.4s, A_1.s[0]\n"
+ "ldr sA_1, [ aptr4], #0x04\n"
+
+ "fmla C_21.4s, B_1.4s, A_2.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride]\n"
+ "fmla C_22.4s, B_2.4s, A_2.s[0]\n"
+ "prfm pldl1keep, [ aptr3, #0x10]\n"
+ "fmla C_23.4s, B_3.4s, A_2.s[0]\n"
+ "ldr sA_2, [ aptr5], #0x04\n"
+
+ "fmla C_31.4s, B_1.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [%x[bptr], #0x00]\n"
+ "fmla C_32.4s, B_2.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [%x[bptr], #0x10]\n"
+ "fmla C_33.4s, B_3.4s, A_3.s[0]\n"
+ "ldr sA_3, [ aptr6], #0x04\n"
+
+ "fmla C_41.4s, B_1.4s, A_4.s[0]\n"
+ "prfm pldl1keep, [%x[bptr], #0x20]\n"
+ "fmla C_42.4s, B_2.4s, A_4.s[0]\n"
+ "prfm pldl1keep, [ aptr4, #0x10]\n"
+ "fmla C_43.4s, B_3.4s, A_4.s[0]\n"
+ "ldr sA_4, [ aptr7], #0x04\n"
+
+ "fmla C_51.4s, B_1.4s, A_1.s[0]\n"
+ "prfm pldl1keep, [ aptr5, #0x10]\n"
+ "fmla C_52.4s, B_2.4s, A_1.s[0]\n"
+ "prfm pldl1keep, [ aptr6, #0x10]\n"
+ "fmla C_53.4s, B_3.4s, A_1.s[0]\n"
+ "ldr sA_1, [%x[aptr]], #0x04\n"
+
+ "fmla C_61.4s, B_1.4s, A_2.s[0]\n"
+ "prfm pldl1keep, [ aptr7, #0x10]\n"
+ "fmla C_62.4s, B_2.4s, A_2.s[0]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla C_63.4s, B_3.4s, A_2.s[0]\n"
+ "ldr sA_2, [ aptr1], #0x04\n"
+
+ "fmla C_71.4s, B_1.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [%x[aptr], #0x10]\n"
+ "fmla C_72.4s, B_2.4s, A_3.s[0]\n"
+ "prfm pldl1keep, [ aptr1, #0x10]\n"
+ "fmla C_73.4s, B_3.4s, A_3.s[0]\n"
+ "ldr sA_3, [ aptr2], #0x04\n"
+
+ "fmla C_81.4s, B_1.4s, A_4.s[0]\n"
+ "prfm pldl1keep, [ aptr2, #0x10]\n"
+ "fmla C_82.4s, B_2.4s, A_4.s[0]\n"
+ "ldp qB_1, qB_2, [%x[bptr]]\n"
+ "fmla C_83.4s, B_3.4s, A_4.s[0]\n"
+ "bne 1b\n"
+
+ "2:"
+ "fmla C_11.4s, B_1.4s, A_1.s[0]\n"
+ "ldr qB_3, [%x[bptr], #0x20]\n"
+ "fmla C_12.4s, B_2.4s, A_1.s[0]\n"
+ "stp qC_11, qC_12, [%x[cptr]]\n"
+ "fmla C_13.4s, B_3.4s, A_1.s[0]\n"
+ "str qC_13, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_1, [ aptr4], #0x04\n"
+
+ "fmla C_21.4s, B_1.4s, A_2.s[0]\n"
+ "ldr sA_4, [ aptr3], #0x4\n"
+ "fmla C_22.4s, B_2.4s, A_2.s[0]\n"
+ "stp qC_21, qC_22, [%x[cptr]]\n"
+ "fmla C_23.4s, B_3.4s, A_2.s[0]\n"
+ "str qC_23, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_2, [ aptr5], #0x04\n"
+
+ "fmla C_31.4s, B_1.4s, A_3.s[0]\n"
+ "fmla C_32.4s, B_2.4s, A_3.s[0]\n"
+ "stp qC_31, qC_32, [%x[cptr]]\n"
+ "fmla C_33.4s, B_3.4s, A_3.s[0]\n"
+ "str qC_33, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_3, [ aptr6], #0x04\n"
+
+ "fmla C_41.4s, B_1.4s, A_4.s[0]\n"
+ "fmla C_42.4s, B_2.4s, A_4.s[0]\n"
+ "stp qC_41, qC_42, [%x[cptr]]\n"
+ "fmla C_43.4s, B_3.4s, A_4.s[0]\n"
+ "str qC_43, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+ "ldr sA_4, [ aptr7], #0x04\n"
+
+ "fmla C_51.4s, B_1.4s, A_1.s[0]\n"
+ "fmla C_52.4s, B_2.4s, A_1.s[0]\n"
+ "stp qC_51, qC_52, [%x[cptr]]\n"
+ "fmla C_53.4s, B_3.4s, A_1.s[0]\n"
+ "str qC_53, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ "fmla C_61.4s, B_1.4s, A_2.s[0]\n"
+ "fmla C_62.4s, B_2.4s, A_2.s[0]\n"
+ "stp qC_61, qC_62, [%x[cptr]]\n"
+ "fmla C_63.4s, B_3.4s, A_2.s[0]\n"
+ "str qC_63, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ "fmla C_71.4s, B_1.4s, A_3.s[0]\n"
+ "fmla C_72.4s, B_2.4s, A_3.s[0]\n"
+ "stp qC_71, qC_72, [%x[cptr]]\n"
+ "fmla C_73.4s, B_3.4s, A_3.s[0]\n"
+ "str qC_73, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ "fmla C_81.4s, B_1.4s, A_4.s[0]\n"
+ "fmla C_82.4s, B_2.4s, A_4.s[0]\n"
+ "stp qC_81, qC_82, [%x[cptr]]\n"
+ "fmla C_83.4s, B_3.4s, A_4.s[0]\n"
+ "str qC_83, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride]\n"
+
+ // Clear aliases
+ ".unreq aptr1\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+ ".unreq aptr5\n"
+ ".unreq aptr6\n"
+ ".unreq aptr7\n"
+
+ ".unreq A_1\n" ".unreq A_2\n" ".unreq A_3\n" ".unreq A_4\n"
+ ".unreq sA_1\n" ".unreq sA_2\n" ".unreq sA_3\n" ".unreq sA_4\n"
+
+ ".unreq B_1\n" ".unreq B_2\n" ".unreq B_3\n"
+ ".unreq qB_1\n" ".unreq qB_2\n" ".unreq qB_3\n"
+
+ ".unreq C_11\n" ".unreq C_12\n" ".unreq C_13\n"
+ ".unreq C_21\n" ".unreq C_22\n" ".unreq C_23\n"
+ ".unreq C_31\n" ".unreq C_32\n" ".unreq C_33\n"
+ ".unreq C_41\n" ".unreq C_42\n" ".unreq C_43\n"
+ ".unreq C_51\n" ".unreq C_52\n" ".unreq C_53\n"
+ ".unreq C_61\n" ".unreq C_62\n" ".unreq C_63\n"
+ ".unreq C_71\n" ".unreq C_72\n" ".unreq C_73\n"
+ ".unreq C_81\n" ".unreq C_82\n" ".unreq C_83\n"
+
+ ".unreq qC_11\n" ".unreq qC_12\n" ".unreq qC_13\n"
+ ".unreq qC_21\n" ".unreq qC_22\n" ".unreq qC_23\n"
+ ".unreq qC_31\n" ".unreq qC_32\n" ".unreq qC_33\n"
+ ".unreq qC_41\n" ".unreq qC_42\n" ".unreq qC_43\n"
+ ".unreq qC_51\n" ".unreq qC_52\n" ".unreq qC_53\n"
+ ".unreq qC_61\n" ".unreq qC_62\n" ".unreq qC_63\n"
+ ".unreq qC_71\n" ".unreq qC_72\n" ".unreq qC_73\n"
+ ".unreq qC_81\n" ".unreq qC_82\n" ".unreq qC_83\n"
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
+ "v29", "v30", "x17", "x18", "x19", "x20", "x21", "x22", "x23"
+ );
+ }
+ }
+}
+
+/*****************************************************************************/
+/* 4x16 blocked GEMM with specialised tails
+ */
+#include "a64_sgemm_4x16.hpp"
+
+template <>
+inline void BlockedGemm<4, 16, float, float>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ // Despatch based on tail of K
+ switch (K % 4) {
+ case 3:
+ sgemm_4x16_impl<3>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ case 2:
+ sgemm_4x16_impl<2>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ case 1:
+ sgemm_4x16_impl<1>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ case 0:
+ sgemm_4x16_impl<0>(
+ a, b, c, M, K, N, a_row_stride, b_row_stride, c_row_stride
+ );
+ break;
+ default:
+ assert(false);
+ }
+}
+
+#endif // __aarch64__
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp
new file mode 100644
index 0000000000..5cd37de7a0
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/gemm/a64_sgemm_4x16.hpp
@@ -0,0 +1,1446 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+template <const unsigned int tail>
+inline void sgemm_4x16_impl(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+);
+
+template <>
+inline void sgemm_4x16_impl<0>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 0;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC12.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC13.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC14.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC21.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC22.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC23.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC24.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC31.4s, #0\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 2f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "2:" // Tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
+
+template <>
+inline void sgemm_4x16_impl<1>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 1;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC12.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC13.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC14.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC21.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC22.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC23.4s, #0\n"
+ "cbnz %x[k], 3f\n"
+
+ // Prepare for tail in K
+ "movi vC24.4s, #0\n"
+ "ldr sA1, [%x[aptr]], #0x04\n"
+ "movi vC31.4s, #0\n"
+ "ldr sA2, [ aptr2], #0x04\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "b 2f\n" // Jump to tail
+
+ "3:" // Prepare for loop over K
+ "movi vC24.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC31.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 4f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "4:" // Tail iteration
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr sA1, [%x[aptr]], #0x04\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr sA2, [ aptr2], #0x04\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+ "2:" // Common tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "ldr sA3, [ aptr3], #0x04\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "ldr sA4, [ aptr4], #0x04\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
+
+template <>
+inline void sgemm_4x16_impl<2>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 2;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC12.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC13.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC14.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC21.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC22.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC23.4s, #0\n"
+ "cbnz %x[k], 3f\n"
+
+ // Prepare for tail in K
+ "movi vC24.4s, #0\n"
+ "ldr dA1, [%x[aptr]], #0x08\n"
+ "movi vC31.4s, #0\n"
+ "ldr dA2, [ aptr2], #0x08\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "b 2f\n" // Jump to tail
+
+ "3:" // Prepare for loop over K
+ "movi vC24.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC31.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 4f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "4:" // Tail iteration
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr dA1, [%x[aptr]], #0x08\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr dA2, [ aptr2], #0x08\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+ "2:" // Common tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr dA3, [ aptr3], #0x08\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr dA4, [ aptr4], #0x08\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
+
+template <>
+inline void sgemm_4x16_impl<3>(
+ const float* const a, const float* const b, float *c,
+ const int M, const int K, const int N,
+ const int a_row_stride,
+ const int b_row_stride,
+ const int c_row_stride
+) {
+ const int TAIL_SIZE = 3;
+ const int M_BLOCK = 4;
+ const int N_BLOCK = 16;
+
+ const int m_blocks = iceildiv(M, M_BLOCK);
+ const int n_blocks = iceildiv(N, N_BLOCK);
+
+ // For each block of output rows
+ for (int mblock = 0; mblock < m_blocks; mblock++) {
+ // For each block of output columns
+ for (int nblock = 0; nblock < n_blocks; nblock++) {
+ const float *aptr = a + mblock*M_BLOCK*a_row_stride;
+ const float *bptr = b + nblock*N_BLOCK;
+ float *cptr = c + mblock*M_BLOCK*c_row_stride + nblock*N_BLOCK;
+ int k = (K - TAIL_SIZE) / 4;
+
+ asm volatile(
+ "aptr2 .req X20\n"
+ "aptr3 .req X21\n"
+ "aptr4 .req X22\n"
+ "vC11 .req v0\n" "vC12 .req v1\n" "vC13 .req v2\n" "vC14 .req v3\n"
+ "qC11 .req q0\n" "qC12 .req q1\n" "qC13 .req q2\n" "qC14 .req q3\n"
+ "vC21 .req v4\n" "vC22 .req v5\n" "vC23 .req v6\n" "vC24 .req v7\n"
+ "qC21 .req q4\n" "qC22 .req q5\n" "qC23 .req q6\n" "qC24 .req q7\n"
+ "vC31 .req v8\n" "vC32 .req v9\n" "vC33 .req v10\n" "vC34 .req v11\n"
+ "qC31 .req q8\n" "qC32 .req q9\n" "qC33 .req q10\n" "qC34 .req q11\n"
+ "vC41 .req v12\n" "vC42 .req v13\n" "vC43 .req v14\n" "vC44 .req v15\n"
+ "qC41 .req q12\n" "qC42 .req q13\n" "qC43 .req q14\n" "qC44 .req q15\n"
+ "vA1 .req v16\n" "qA1 .req q16\n" "dA1 .req d16\n" "sA1 .req s16\n"
+ "vA2 .req v17\n" "qA2 .req q17\n" "dA2 .req d17\n" "sA2 .req s17\n"
+ "vA3 .req v18\n" "qA3 .req q18\n" "dA3 .req d18\n" "sA3 .req s18\n"
+ "vA4 .req v19\n" "qA4 .req q19\n" "dA4 .req d19\n" "sA4 .req s19\n"
+ "vB1 .req v20\n" "qB1 .req q20\n"
+ "vB2 .req v21\n" "qB2 .req q21\n"
+ "vB3 .req v22\n" "qB3 .req q22\n"
+ "vB4 .req v23\n" "qB4 .req q23\n"
+
+ // Clear accumulators, initialise pointers
+ "movi vC11.4s, #0\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "movi vC12.4s, #0\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "movi vC13.4s, #0\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "movi vC14.4s, #0\n"
+ "add aptr2, %x[aptr], %x[a_row_stride_bytes]\n"
+ "movi vC21.4s, #0\n"
+ "add aptr3, aptr2, %x[a_row_stride_bytes]\n"
+ "movi vC22.4s, #0\n"
+ "add aptr4, aptr3, %x[a_row_stride_bytes]\n"
+ "movi vC23.4s, #0\n"
+ "cbnz %x[k], 3f\n"
+
+ // Prepare for tail in K
+ "movi vC24.4s, #0\n"
+ "ldr dA1, [%x[aptr]], #0x08\n"
+ "movi vC31.4s, #0\n"
+ "ldr dA2, [ aptr2], #0x08\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "b 2f\n" // Jump to tail
+
+ "3:" // Prepare for loop over K
+ "movi vC24.4s, #0\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "movi vC31.4s, #0\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "movi vC32.4s, #0\n"
+ "movi vC33.4s, #0\n"
+ "movi vC34.4s, #0\n"
+ "movi vC41.4s, #0\n"
+ "movi vC42.4s, #0\n"
+ "movi vC43.4s, #0\n"
+ "movi vC44.4s, #0\n"
+ "subs %x[k], %x[k], #1\n"
+ "beq 4f\n"
+
+ "1:" // Loop proper
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "subs %x[k], %x[k], #1\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr qA1, [%x[aptr]], #0x10\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr qA2, [ aptr2], #0x10\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+ "bne 1b\n"
+
+ "4:" // Tail iteration
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qA3, [ aptr3], #0x10\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr qA4, [ aptr4], #0x10\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[2]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[2]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[2]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[2]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[2]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[2]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[2]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[2]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[2]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[2]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[2]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[2]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[2]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[2]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[2]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[2]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[3]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[3]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[3]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[3]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[3]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[3]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[3]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[3]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[3]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[3]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[3]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[3]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[3]\n"
+ "ldr dA1, [%x[aptr]], #0x08\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[3]\n"
+ "ldr dA2, [ aptr2], #0x08\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[3]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[3]\n"
+
+ "2:" // Common tail
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr dA3, [ aptr3], #0x08\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "ldr dA4, [ aptr4], #0x08\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[1]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[1]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[1]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[1]\n"
+ "add %x[bptr], %x[bptr], %x[b_row_stride_bytes]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[1]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[1]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[1]\n"
+ "ldr qB1, [%x[bptr], #0x00]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[1]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[1]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[1]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[1]\n"
+ "ldr qB2, [%x[bptr], #0x10]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[1]\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[1]\n"
+ "ldr sA1, [%x[aptr]], #0x04\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[1]\n"
+ "ldr sA2, [ aptr2], #0x04\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[1]\n"
+ "ldr qB3, [%x[bptr], #0x20]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[1]\n"
+
+ "fmla vC11.4s, vB1.4s, vA1.s[0]\n"
+ "ldr qB4, [%x[bptr], #0x30]\n"
+ "fmla vC12.4s, vB2.4s, vA1.s[0]\n"
+ "stp qC11, qC12, [%x[cptr], #0x00]\n"
+ "fmla vC13.4s, vB3.4s, vA1.s[0]\n"
+ "ldr sA3, [ aptr3], #0x04\n"
+ "fmla vC14.4s, vB4.4s, vA1.s[0]\n"
+ "stp qC13, qC14, [%x[cptr], #0x20]\n"
+ "fmla vC21.4s, vB1.4s, vA2.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC22.4s, vB2.4s, vA2.s[0]\n"
+ "stp qC21, qC22, [%x[cptr], #0x00]\n"
+ "fmla vC23.4s, vB3.4s, vA2.s[0]\n"
+ "ldr sA4, [ aptr4], #0x04\n"
+ "fmla vC24.4s, vB4.4s, vA2.s[0]\n"
+ "stp qC23, qC24, [%x[cptr], #0x20]\n"
+ "fmla vC31.4s, vB1.4s, vA3.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC32.4s, vB2.4s, vA3.s[0]\n"
+ "stp qC31, qC32, [%x[cptr], #0x00]\n"
+ "fmla vC33.4s, vB3.4s, vA3.s[0]\n"
+ "fmla vC34.4s, vB4.4s, vA3.s[0]\n"
+ "stp qC33, qC34, [%x[cptr], #0x20]\n"
+ "fmla vC41.4s, vB1.4s, vA4.s[0]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+ "fmla vC42.4s, vB2.4s, vA4.s[0]\n"
+ "stp qC41, qC42, [%x[cptr], #0x00]\n"
+ "fmla vC43.4s, vB3.4s, vA4.s[0]\n"
+ "fmla vC44.4s, vB4.4s, vA4.s[0]\n"
+ "stp qC43, qC44, [%x[cptr], #0x20]\n"
+ "add %x[cptr], %x[cptr], %x[c_row_stride_bytes]\n"
+
+ ".unreq vB4\n" ".unreq qB4\n"
+ ".unreq vB3\n" ".unreq qB3\n"
+ ".unreq vB2\n" ".unreq qB2\n"
+ ".unreq vB1\n" ".unreq qB1\n"
+ ".unreq vA4\n" ".unreq qA4\n" ".unreq dA4\n" ".unreq sA4\n"
+ ".unreq vA3\n" ".unreq qA3\n" ".unreq dA3\n" ".unreq sA3\n"
+ ".unreq vA2\n" ".unreq qA2\n" ".unreq dA2\n" ".unreq sA2\n"
+ ".unreq vA1\n" ".unreq qA1\n" ".unreq dA1\n" ".unreq sA1\n"
+ ".unreq qC41\n" ".unreq qC42\n" ".unreq qC43\n" ".unreq qC44\n"
+ ".unreq vC41\n" ".unreq vC42\n" ".unreq vC43\n" ".unreq vC44\n"
+ ".unreq qC31\n" ".unreq qC32\n" ".unreq qC33\n" ".unreq qC34\n"
+ ".unreq vC31\n" ".unreq vC32\n" ".unreq vC33\n" ".unreq vC34\n"
+ ".unreq qC21\n" ".unreq qC22\n" ".unreq qC23\n" ".unreq qC24\n"
+ ".unreq vC21\n" ".unreq vC22\n" ".unreq vC23\n" ".unreq vC24\n"
+ ".unreq qC11\n" ".unreq qC12\n" ".unreq qC13\n" ".unreq qC14\n"
+ ".unreq vC11\n" ".unreq vC12\n" ".unreq vC13\n" ".unreq vC14\n"
+ ".unreq aptr2\n"
+ ".unreq aptr3\n"
+ ".unreq aptr4\n"
+
+ : [aptr] "+r" (aptr),
+ [bptr] "+r" (bptr),
+ [cptr] "+r" (cptr),
+ [k] "+r" (k)
+ : [a_row_stride_bytes] "r" (a_row_stride * sizeof(float)),
+ [b_row_stride_bytes] "r" (b_row_stride * sizeof(float)),
+ [c_row_stride_bytes] "r" (c_row_stride * sizeof(float))
+ : "cc", "memory", "x20", "x21", "x22",
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+ "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
+ "v21", "v22", "v23"
+ );
+ }
+ }
+}
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
new file mode 100644
index 0000000000..6dd8f5460a
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/input.hpp
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
+
+namespace winograd
+{
+ /***************************************************************************/
+ /* Instance-less API */
+ template <int output_tile_rows, int output_tile_cols,
+ int kernel_rows, int kernel_cols>
+ template <typename T>
+ void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::InputTransform<T>::execute(
+ const T *inptr,
+ const Tensor4DShape& input_shape,
+ const PaddingType padding_type,
+ const int tile_M,
+ const int tile_N,
+ T *outptr_base,
+ const int matrix_stride,
+ const int matrix_batch_stride,
+ const int matrix_row_stride
+ )
+ {
+ // Compute the padding required on each edge of the image
+ const bool base_padding = (padding_type == PADDING_SAME) ? 1 : 0;
+ const int pad_top = base_padding;
+ const int pad_left = base_padding;
+ const int tile_overlap = kernel_rows - 1;
+
+ // Compute striding values (assuming NHWC ordered data)
+ const int input_col_stride = input_shape.n_channels;
+ const int input_row_stride = input_shape.n_cols * input_col_stride;
+ const int input_batch_stride = input_shape.n_rows * input_row_stride;
+ const int output_col_stride = matrix_row_stride;
+ const int output_row_stride = tile_N * output_col_stride;
+
+ // Loop over batches
+ for (int batch = 0; batch < input_shape.n_batches; batch++)
+ {
+ // Pointer to the batch
+ const T* const input_base_batch = inptr + batch * input_batch_stride;
+ T* const outptr_base_batch = outptr_base + batch * matrix_batch_stride;
+
+ // Loop over rows of tiles
+ for (int tile_i = 0; tile_i < tile_M; tile_i++)
+ {
+ // Pointer to the row
+ const int row_offset = (tile_i == 0) ?
+ 0 : ((padding_type == PADDING_VALID) ? 0 : 1);
+ const T* const input_base_row = (
+ input_base_batch + ((inner_tile_rows - (kernel_rows - 1))*tile_i - row_offset)*input_row_stride
+ );
+ T* const outptr_base_row = outptr_base_batch + tile_i*output_row_stride;
+
+ // Padding (top + bottom) for the row
+ const int row_top = tile_i*(inner_tile_rows - tile_overlap) - pad_top;
+ const int row_bottom = row_top + inner_tile_rows;
+ const int row_pad_top = (tile_i == 0) ? pad_top : 0;
+ const int row_pad_bottom = (row_bottom <= input_shape.n_rows) ? 0 : row_bottom - input_shape.n_rows;
+
+ // Process the row
+ process_tile_row(
+ tile_N, input_shape.n_channels,
+ input_base_row, input_row_stride, input_col_stride,
+ outptr_base_row, matrix_stride, matrix_row_stride,
+ row_pad_top, pad_left, row_pad_bottom, input_shape.n_cols
+ );
+ }
+ }
+ }
+
+ template <int output_tile_rows, int output_tile_cols,
+ int kernel_rows, int kernel_cols>
+ template <typename T>
+ void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::InputTransform<T>::process_tile_row(
+ const int tile_N,
+ int n_channels,
+ const T* const input_base,
+ const int input_row_stride,
+ const int input_col_stride,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const int pad_top,
+ const int row_pad_left,
+ const int pad_bottom,
+ const int n_cols
+ )
+ {
+ constexpr int tile_overlap = kernel_cols - 1;
+
+ // Loop over columns of tiles
+ for (int tile_j = 0; tile_j < tile_N; tile_j++)
+ {
+ // Padding (left + right) for the tile
+ const int t_pad_left = (tile_j == 0) ? row_pad_left : 0;
+ const int t_start = tile_j*(inner_tile_cols - tile_overlap) - row_pad_left;
+ const int t_end = t_start + inner_tile_cols;
+ const int t_pad_right = (t_end <= n_cols) ? 0 : t_end - n_cols;
+
+ // Get pointers into the inputs and outputs
+ const int col_offset = (tile_j == 0) ? 0 : row_pad_left;
+ const T* const input_base_col = (
+ input_base + ((inner_tile_cols - tile_overlap)*tile_j - col_offset)*input_col_stride
+ );
+ T* const outptr = matrix_base + tile_j*matrix_row_stride;
+
+ // Apply the specific tile processing function
+ tile_fns[pad_top][t_pad_left][pad_bottom][t_pad_right](
+ n_channels,
+ input_base_col,
+ input_row_stride,
+ input_col_stride,
+ outptr,
+ matrix_stride
+ );
+ }
+ }
+
+ /***************************************************************************/
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::InputTransform(
+ const T* const input, /** Input tensor data */
+ const int n_batches, /** Number of batches in input tensor. */
+ const int n_rows, /** Number of rows in input tensor. */
+ const int n_cols, /** Number of columns in input tensor. */
+ const int n_channels, /** Number of channels in input tensor. */
+ const PaddingType padding, /** Padding type. */
+ T* const output, /** Base of output matrices. */
+ const int matrix_stride, /** Stride between output matrices. */
+ const int matrix_row_stride /** Stride within matrices. */
+ ) : _inptr(input), _outptr(output),
+ _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
+ _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
+ _tiles_M(iceildiv((padding == PADDING_SAME) ? n_rows : n_rows - 2, output_tile_rows)),
+ _tiles_N(iceildiv((padding == PADDING_SAME) ? n_cols : n_cols - 2, output_tile_cols)),
+ _padding_type(padding)
+ {
+ }
+
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ unsigned int WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::get_window() const
+ {
+ // TODO When the input transform supports multithreading, return the total
+ // number of tile rows (allowing for multiple batches). For now we return 1
+ // to indicate that the activations must be transformed as a single block.
+ return 1; // TODO _tiles_M * _n_batches;
+ }
+
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ void WinogradGEMM<otr, otc, kr, kc>::InputTransform<T>::run(
+ const unsigned int start, const unsigned int stop
+ )
+ {
+ // TODO When the input transform supports multithreading call execute for a
+ // portion of the tile rows.
+ (void) start;
+ (void) stop;
+
+ // For now, just do all of the work.
+ const Tensor4DShape input_shape = {
+ _n_batches, _n_rows, _n_cols, _n_channels, NHWC
+ };
+ execute(
+ _inptr, input_shape, _padding_type, _tiles_M, _tiles_N, _outptr,
+ _matrix_stride, _matrix_row_stride * _tiles_M * _tiles_N, _matrix_row_stride
+ );
+ }
+}
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp
new file mode 100644
index 0000000000..bad3ef2249
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/kernel.hpp
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
+using namespace winograd;
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::WeightsTransform(
+ const T* const input,
+ T* const output,
+ const int matrix_stride, /** Stride across matrices in the output. */
+ const int matrix_row_stride, /** Stride across rows of the matrix. */
+ const int n_output_channels,
+ const int n_input_channels
+) : inptr(input), outptr(output),
+ matrix_stride(matrix_stride), matrix_row_stride(matrix_row_stride),
+ n_output_channels(n_output_channels), n_input_channels(n_input_channels)
+{
+}
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+unsigned int WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::get_window() const
+{
+ // TODO When the weights transform supports multithreading, return the number
+ // of output channels. For now we return 1 to indicate that the weights must
+ // be transformed as a single block.
+ // return n_output_channels;
+ return 1;
+}
+
+
+template <int otr, int otc, int kr, int kc>
+template <typename T>
+void WinogradGEMM<otr, otc, kr, kc>::WeightsTransform<T>::run(
+ const unsigned int start, const unsigned int stop
+)
+{
+ // TODO When the weights transform supports multithreading call execute for a
+ // portion of the output channels.
+ (void) start;
+ (void) stop;
+
+ // For now, just do all of the work.
+ execute(
+ n_output_channels,
+ n_input_channels,
+ inptr,
+ outptr,
+ matrix_stride,
+ matrix_row_stride
+ );
+}
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp
new file mode 100644
index 0000000000..401b2816be
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp
@@ -0,0 +1,181 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp"
+
+namespace winograd
+{
+ template <int output_tile_rows, int output_tile_cols,
+ int kernel_rows, int kernel_cols>
+ template <typename T>
+ void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::OutputTransform<T>::execute(
+ const Tensor4DShape &output_shape,
+ const T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const T* const biases,
+ T* const output
+ )
+ {
+ // Compute the number of tiles and hence the padding required on the bottom
+ // and right of the image.
+ const int tile_M = iceildiv(output_shape.n_rows, output_tile_rows);
+ const int tile_N = iceildiv(output_shape.n_cols, output_tile_cols);
+ const int pad_bottom = output_tile_rows*tile_M - output_shape.n_rows;
+ const int pad_right = output_tile_cols*tile_N - output_shape.n_cols;
+
+ const int matrix_tile_row_stride = tile_N * matrix_row_stride;
+ const int matrix_batch_stride = tile_M * matrix_tile_row_stride;
+ const int output_col_stride = output_shape.n_channels;
+ const int output_row_stride = output_shape.n_cols * output_col_stride;
+ const int output_batch_stride = output_shape.n_rows * output_row_stride;
+
+ // Perform the output transformation for each batch
+ for (int batch = 0; batch < output_shape.n_batches; batch++)
+ {
+ // Get batch offset for input and outputs.
+ const T* const matrix_batch = matrix_base + batch*matrix_batch_stride;
+ T* const outptr_batch = output + batch*output_batch_stride;
+
+ // Perform the output transformation for each row of the output tensor.
+ for (int tile_i = 0; tile_i < tile_M; tile_i++)
+ {
+ // Compute properties of this row of output tiles
+ const int row_pad_bottom = (tile_i < tile_M - 1) ? 0: pad_bottom;
+ const T* const matrix_tile_row = matrix_batch + tile_i * matrix_tile_row_stride;
+ T* const outptr_row = outptr_batch + output_tile_rows*tile_i*output_row_stride;
+
+ // Process the row
+ process_tile_row(
+ tile_N, output_shape.n_channels, matrix_tile_row, matrix_stride,
+ matrix_row_stride, biases,
+ outptr_row, output_row_stride, output_col_stride, row_pad_bottom,
+ pad_right
+ );
+ }
+ }
+ }
+
+ template <int output_tile_rows, int output_tile_cols,
+ int kernel_rows, int kernel_cols>
+ template <typename T>
+ void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>::OutputTransform<T>::process_tile_row(
+ const int tile_N,
+ const int n_channels,
+ const T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const T* const biases,
+ T* const output,
+ const int output_row_stride,
+ const int output_col_stride,
+ const int row_pad_bottom,
+ const int row_pad_right
+ )
+ {
+ // Loop over columns of tiles
+ for (int tile_j = 0; tile_j < tile_N; tile_j++)
+ {
+ // Properties of this tile
+ const int tile_pad_right = (tile_j < tile_N - 1) ? 0 : row_pad_right;
+ const T* const matrix_row = matrix_base + tile_j * matrix_row_stride;
+ T* const outptr = output + output_tile_cols*tile_j*output_col_stride;
+
+ // Perform the output transformation
+ tile_fns[row_pad_bottom][tile_pad_right](
+ n_channels, matrix_row, matrix_stride, biases,
+ outptr, output_row_stride, output_col_stride
+ );
+ }
+ }
+
+ template <int output_tile_rows, int output_tile_cols, int kr, int kc>
+ template <typename T>
+ size_t WinogradGEMM<output_tile_rows, output_tile_cols, kr, kc>::OutputTransform<T>::bytes_read(const Tensor4DShape &shape)
+ {
+ const int M = iceildiv(shape.n_rows, output_tile_rows) *
+ iceildiv(shape.n_cols, output_tile_cols);
+ const int N = shape.n_channels;
+ return inner_tile_rows * inner_tile_cols * M * N * sizeof(T);
+ }
+
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ size_t WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::bytes_written(const Tensor4DShape &shape)
+ {
+ return shape.size() * sizeof(T);
+ }
+
+ template <int output_tile_rows, int output_tile_cols, int kr, int kc>
+ template <typename T>
+ WinogradGEMM<output_tile_rows, output_tile_cols, kr, kc>::OutputTransform<T>::OutputTransform(
+ const T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const T* const biases,
+ T* const output,
+ const int n_batches,
+ const int n_rows,
+ const int n_cols,
+ const int n_channels
+ ) : _matrix_base(matrix_base), _biases(biases),
+ _matrix_stride(matrix_stride), _matrix_row_stride(matrix_row_stride),
+ _outptr(output), _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols),
+ _n_channels(n_channels), _tile_M(iceildiv(n_rows, output_tile_rows)),
+ _tile_N(iceildiv(n_cols, output_tile_cols))
+ {
+ }
+
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ unsigned int WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::get_window() const
+ {
+ // TODO When the output transform supports multithreading, return the total
+ // number of tile rows (allowing for multiple batches). For now we return 1
+ // to indicate that the activations must be transformed as a single block.
+ return 1; // TODO _tile_M * _n_batches;
+ }
+
+ template <int otr, int otc, int kr, int kc>
+ template <typename T>
+ void WinogradGEMM<otr, otc, kr, kc>::OutputTransform<T>::run(
+ const unsigned int start, const unsigned int stop
+ )
+ {
+ // TODO When the output transform supports multithreading call execute for a
+ // portion of the tile rows.
+ (void) start;
+ (void) stop;
+
+ // For now, just do all of the work.
+ const Tensor4DShape output_shape = {
+ _n_batches, _n_rows, _n_cols, _n_channels, NHWC
+ };
+ execute(
+ output_shape, _matrix_base, _matrix_stride, _matrix_row_stride, _biases,
+ _outptr
+ );
+ }
+} // namespace winograd
diff --git a/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
new file mode 100644
index 0000000000..f3b2bb10ed
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp
@@ -0,0 +1,447 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "arm_compute/core/NEON/kernels/convolution/common/alloc.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/convolution.hpp"
+#include "gemm.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/profiler.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/shims.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/tensor.hpp"
+#include "arm_compute/core/NEON/kernels/convolution/common/utils.hpp"
+
+#include <thread>
+#include <utility>
+#include <vector>
+
+// Generic Winograd implementation using GEMM
+namespace winograd
+{
+
+template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+class WinogradGEMM
+{
+ public:
+ // Information about the specific Winograd instance
+ static constexpr int output_tile_rows = OutputTileRows;
+ static constexpr int output_tile_cols = OutputTileCols;
+ static constexpr int kernel_rows = KernelRows;
+ static constexpr int kernel_cols = KernelCols;
+ static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1; // TODO Check
+ static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1; // TODO Check
+ static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols;
+
+ /** Transform weights from the spatial to the Winograd domain. */
+ template <typename T>
+ struct WeightsTransform
+ {
+ /** Get the bytes read during the transform. */
+ static inline size_t bytes_read(const KernelShape &shape)
+ {
+ return shape.size() * sizeof(T);
+ }
+
+ /** Get the bytes written during the transform. */
+ static inline size_t bytes_written(const KernelShape &shape)
+ {
+ const int inner_tile_size = inner_tile_rows * inner_tile_cols;
+ return (inner_tile_size * shape.n_input_channels *
+ shape.n_output_channels * sizeof(T));
+ }
+
+ /** Get the count of operations performed by the transform. */
+ static int ops_performed(const KernelShape &shape);
+
+ /** Apply the transform to a tensor. */
+ static void execute(
+ const int n_output_channels,
+ const int n_input_channels,
+ const T* const input,
+ T* const output,
+ const int matrix_stride,
+ const int matrix_row_stride
+ );
+
+ /** Create a WeightsTransform operator fixed on a given problem and set
+ * of pointers.
+ */
+ WeightsTransform(
+ const T* const input,
+ T* const output,
+ const int matrix_stride, /** Stride across matrices in the output. */
+ const int matrix_row_stride, /** Stride across rows of the matrix. */
+ const int n_output_channels, /** Number of filters. */
+ const int n_input_channels /** Number of channels in each filter. */
+ );
+
+ /** Get the window of work a given operator can perform. */
+ unsigned int get_window() const;
+
+ /** Perform work upon a window of the input. */
+ void run(const unsigned int start, const unsigned int stop);
+
+ private:
+ const T* const inptr; /** Fixed pointer to input data. */
+ T* const outptr; /** Fixed pointer to output memory. */
+ const int matrix_stride; /** Stride between output matrices. */
+ const int matrix_row_stride; /** Stride within output matrices. */
+ const int n_output_channels; /** Number of filters. */
+ const int n_input_channels; /** Number of channels in each filter. */
+ };
+
+ /** Transform input feature maps from the spatial to the Winograd domain.
+ */
+ template <typename T>
+ struct InputTransform
+ {
+ /** Get the bytes read during the transform. */
+ static size_t bytes_read(const Tensor4DShape &shape)
+ {
+ return shape.size() * sizeof(T);
+ }
+
+ /** Get the bytes written during the transform. */
+ static size_t bytes_written(const Tensor4DShape &shape)
+ {
+ const int M = iceildiv(shape.n_rows, inner_tile_rows) *
+ iceildiv(shape.n_cols, inner_tile_cols);
+ const int K = shape.n_channels;
+ return inner_tile_rows * inner_tile_cols * M * K * sizeof(T);
+ }
+
+ /** Get the count of operations performed by the transform. */
+ static int ops_performed(const Tensor4DShape &shape);
+
+ /** Apply the transform to a tensor. */
+ static void execute(
+ const T *inptr,
+ const Tensor4DShape& input_shape,
+ const PaddingType padding_type,
+ const int tile_M,
+ const int tile_N,
+ T *outptr_base,
+ const int matrix_stride,
+ const int matrix_batch_stride,
+ const int matrix_row_stride
+ );
+
+ /***********************************************************************/
+ /** Create an InputTransform operator fixed on a given problem and set of
+ * pointers.
+ */
+ InputTransform(
+ const T* const input, /** Input tensor data */
+ const int n_batches, /** Number of batches in input tensor. */
+ const int n_rows, /** Number of rows in input tensor. */
+ const int n_cols, /** Number of columns in input tensor. */
+ const int n_channels, /** Number of channels in input tensor. */
+ const PaddingType padding, /** Padding type. */
+ T* const output, /** Base of output matrices. */
+ const int matrix_stride, /** Stride between output matrices. */
+ const int matrix_row_stride /** Stride within matrices. */
+ );
+
+ /** Get the winodw of work a given operator can perform. */
+ unsigned int get_window() const;
+
+ /** Perform work upon a window of the input. */
+ void run(const unsigned int start, const unsigned int stop);
+ /***********************************************************************/
+
+ private:
+ static void process_tile_row(
+ const int tile_N,
+ int n_channels,
+ const T* const input_base,
+ const int input_row_stride,
+ const int input_col_stride,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const int row_pad_top,
+ const int row_pad_left,
+ const int row_pad_bottom,
+ const int n_cols
+ );
+
+ static constexpr int max_pad_bottom = inner_tile_rows - 1;
+ static constexpr int max_pad_right = inner_tile_cols - 1;
+
+ /** Process a single tile of the input tensor. */
+ template <int pad_top, int pad_left, int pad_bottom, int pad_right>
+ static void process_tile(int, const T*, int, int, T*, int);
+
+ // Array of methods to transform tiles of the input tensor.
+ typedef void (*TileFn)(int, const T*, int, int, T*, int);
+ static const TileFn tile_fns[2][2][max_pad_bottom][max_pad_right];
+
+ /* Member values for instance-based API. */
+ const T* const _inptr;
+ T* const _outptr;
+ const int _n_batches, _n_rows, _n_cols, _n_channels, _matrix_stride,
+ _matrix_row_stride, _tiles_M, _tiles_N;
+ const PaddingType _padding_type;
+ };
+
+ /** Transform output feature maps from the Winograd to the spatial domain.
+ */
+ template <typename T>
+ struct OutputTransform
+ {
+ /** Get the bytes read during the transform. */
+ static size_t bytes_read(const Tensor4DShape &shape);
+
+ /** Get the bytes written during the transform. */
+ static size_t bytes_written(const Tensor4DShape &shape);
+
+ /** Get the count of operations performed by the transform. */
+ static int ops_performed(const Tensor4DShape &shape);
+
+ /** Apply the transform to create a tensor. */
+ static void execute(
+ const Tensor4DShape &output_shape,
+ const T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const T* const biases,
+ T* const output
+ );
+
+ /***********************************************************************/
+ /** Create an OutputTransform operator fixed on a given problem and set
+ * of pointers.
+ */
+ OutputTransform(
+ const T* const matrix_base, /** Pointer to base of matrices. */
+ const int matrix_stride, /** Stride between matrices. */
+ const int matrix_row_stride, /** Stride within a matrix. */
+ const T* const biases, /** Pointer to biases vector. */
+ T* const output, /** Pointer to output tensor. */
+ const int n_batches, /** Number of batches in output tensor. */
+ const int n_rows, /** Number of rows in output tensor. */
+ const int n_cols, /** Number of columns in output tensor. */
+ const int n_channels /** Number of channels in output tensor. */
+ );
+
+ /** Get the window of work a given operator can perform. */
+ unsigned int get_window() const;
+
+ /** Perform work upon a window of the input. */
+ void run(const unsigned int start, const unsigned int stop);
+ /***********************************************************************/
+
+ private:
+ static void process_tile_row(
+ const int tile_N,
+ const int n_channels,
+ const T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride,
+ const T* const biases,
+ T* const output,
+ const int output_row_stride,
+ const int output_col_stride,
+ const int row_pad_bottom,
+ const int row_pad_right
+ );
+
+ // Limits on the amount of anti-padding to be applied
+ static constexpr int max_pad_bottom = output_tile_rows;
+ static constexpr int max_pad_right = output_tile_cols;
+
+ /** Prepare a single tile of the output tensor. */
+ template <int pad_bottom, int pad_right>
+ static void process_tile(int, const T*, int, const T*, T*, int, int);
+
+ // Array of methods to produce tiles of output tensor.
+ typedef void (*TileFn)(int, const T*, int, const T*, T*, int, int);
+ static const TileFn tile_fns[max_pad_bottom][max_pad_right];
+
+ /** Member constants for instances of the transform. */
+ const T* const _matrix_base;
+ const T* const _biases;
+ const int _matrix_stride, _matrix_row_stride;
+ T* const _outptr;
+ const int _n_batches, _n_rows, _n_cols, _n_channels, _tile_M, _tile_N;
+ };
+
+ /** Perform a convolution.
+ */
+ template <typename TOut, typename TIn>
+ class Convolution
+ {
+ public:
+ // Information about the typed Winograd instance
+ typedef TOut OutputType;
+ typedef TIn InputType;
+
+ /** Create a new Winograd operator. */
+ Convolution(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding,
+ void *kernel_storage=NULL
+ );
+
+ Convolution(const Convolution&) = delete;
+ Convolution operator=(const Convolution&) = delete;
+
+ /** Create a new Winograd operator and initialise the weights. */
+ Convolution(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding,
+ const TIn* const kernel,
+ void *kernel_storage=NULL,
+ void *transform_working_space=NULL
+ );
+
+ /** Clean up a convolution engine. */
+ ~Convolution();
+
+ /** Transform the weights into the Winograd domain. */
+ template <typename WeightsTransform=WeightsTransform<TIn>>
+ void transform_weights(
+ const TIn* const kernel,
+ void *transform_working_space=NULL
+ );
+
+ /* Apply the Winograd operator to some input. */
+ void execute(
+ TOut* const output,
+ const TIn* const input,
+ const TOut* const biases,
+ void* working_space=NULL,
+ const int n_threads=1
+ );
+
+ /* Apply the Winograd operator to some input. */
+ void execute(
+ TOut* const output,
+ const TIn* const input,
+ const TOut* const biases,
+ const int n_threads
+ );
+
+ /** Get the output shape of a convolution. */
+ static Tensor4DShape get_output_shape(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &in_shape,
+ const PaddingType padding
+ );
+
+ /* Get the memory required to transform the kernel.
+ */
+ static size_t get_kernel_transform_working_size(const KernelShape &shape);
+
+ /** Get the memory required to store the kernel transformed into the
+ * Winograd domain.
+ */
+ static size_t get_kernel_storage_size(const KernelShape &shape);
+
+ /** Get the memory required to store the input tensor transformed into
+ * the Winograd domain.
+ */
+ static size_t get_input_storage_size(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ /** Get the memory required to store the output tensor in the Winograd
+ * domain.
+ */
+ static size_t get_output_storage_size(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ /** Get the memory required to apply a Winograd operator to some input.
+ */
+ static size_t get_working_space_size(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ /* Get the memory required by a single "input" matrix.
+ */
+ static size_t get_input_matrix_size(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ static int get_input_matrix_stride(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ /* Get the memory required by a single "output" matrix.
+ */
+ static size_t get_output_matrix_size(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ static int get_output_matrix_stride(
+ const KernelShape &kernel_shape,
+ const Tensor4DShape &input_shape,
+ const PaddingType padding_type
+ );
+
+ /* Get the memory required by a single "kernel" matrix.
+ */
+ static size_t get_kernel_matrix_size(const KernelShape &shape);
+ static int get_kernel_matrix_stride(const KernelShape &shape);
+
+ static constexpr int M_BLOCK = 4; /** Size of block used by GEMM. */
+ static constexpr int N_BLOCK = 16; /** Size of block used by GEMM. */
+
+ private:
+ const KernelShape kernel_shape; /** Shape of the kernel to be applied. */
+ TIn *kernel_matrices[N_GEMMS]; /** Pointers into the kernel matrices. */
+ const int kernel_matrix_row_stride; /** Stride within the kernel matrices. */
+
+ const bool manage_kernel_storage; /** Kernel storage is managed by the instance. */
+ void* const _kernel_storage; /** Base pointer for kernel storage. */
+
+ const Tensor4DShape input_shape; /** Shape of the input tensor. */
+ const PaddingType padding; /** Padding applied by the operator. */
+
+ const Tensor4DShape output_shape; /** Output shape produced by the operator. */
+
+ const int tile_rows; /** Number of rows of tiles. */
+ const int tile_cols; /** Number of columns of tiles. */
+ const int M, K, N; /** Sizes of underlying fundamental matrix multiplications. */
+
+ profiler prof;
+ };
+};
+
+} // namespace winograd