Diffstat (limited to 'arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp')
-rw-r--r--  arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp  195
1 file changed, 195 insertions, 0 deletions
diff --git a/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp
new file mode 100644
index 0000000000..033442aa14
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/winograd/transforms/kernel_2x2_3x3.hpp
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#include "arm_compute/core/Error.h"  // Provides ARM_COMPUTE_ERROR, used below
+
+namespace winograd {
+ /* Transform a kernel into the Winograd domain.
+ *
+ * NOTE: It is assumed that the kernel is in the form [height x width x
+   * input_channels x output_channels].
+ */
+ template <typename T>
+  struct winograd2x2_3x3_gemm_kernel_transform_impl {
+ static void execute(
+ const KernelShape &shape,
+ const T* const kernel,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride
+ );
+
+ protected:
+ template <const int output_channel_tail>
+ static void transform_kernel(
+ const T* const kernel,
+ const int n_input_channels,
+ const int n_output_channels,
+ T* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride
+ );
+ };
+}
+
+/*****************************************************************************/
+/* Transform an fp32 kernel into the Winograd domain.
+ */
+#include "kernel_2x2_3x3/a64_float.hpp" // AArch64 specialisations
+
+namespace winograd
+{
+template <>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::execute(
+ const KernelShape &shape,
+ const float* const kernel,
+ float* const matrix_base,
+ const int matrix_stride,
+ const int matrix_row_stride
+) {
+ // Delegate based on tail size
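+  // (n_output_channels % 4). Baking the tail into a template parameter lets
+  // the specialisations unroll over blocks of four output channels and handle
+  // the remainder at compile time instead of branching in the inner loop.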
+ const int n_input_channels = shape.n_input_channels;
+ const int n_output_channels = shape.n_output_channels;
+
+ switch (n_output_channels % 4) {
+ case 0:
+ transform_kernel<0>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ case 1:
+ transform_kernel<1>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ case 2:
+ transform_kernel<2>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ case 3:
+ transform_kernel<3>(
+ kernel, n_input_channels, n_output_channels,
+ matrix_base, matrix_stride, matrix_row_stride
+ );
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Cannot happen");
+ break;
+ }
+}
+
+template <>
+template<const int output_channel_tail>
+inline void winograd2x2_3x3_gemm_kernel_transform_impl<float>::transform_kernel(
+ const float* const kernel,
+ const int n_input_channels,
+ const int n_output_channels,
+ float* const matrix_base,
+ const int mstride,
+ const int matrix_row_stride
+) {
+  // Use one input pointer for each row of the kernel, plus two additional
+  // column-stride offsets to extract the columns.
+ const int kernel_col_stride = n_input_channels * n_output_channels;
+ const int kernel_row_stride = 3 * kernel_col_stride;
+ const float *inptr0 = kernel;
+ const float *inptr1 = kernel + kernel_row_stride;
+ const float *inptr2 = kernel + kernel_row_stride*2;
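+  // Since input and output channels are the two innermost dimensions of the
+  // kernel, these pointers walk contiguously over the
+  // [input_channels x output_channels] plane of their respective rows.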
+
+  // Use four output pointers, for output matrices 0, 4, 8 and 12; the
+  // remaining matrices are reached by adding multiples of mstride to these
+  // pointers.
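+  // In total the transformed kernel occupies 16 matrices, laid out row-major
+  // over the 4x4 Winograd-domain tile: matrix 4*r + c holds element (r, c) of
+  // the product W w W.T for every (input channel, output channel) pair.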
+ float *outptr0 = matrix_base;
+ float *outptr4 = matrix_base + mstride * 4;
+ float *outptr8 = matrix_base + mstride * 8;
+ float *outptr12 = matrix_base + mstride * 12;
+
+ // For every input channel
+ for (int in_c = 0; in_c < n_input_channels; in_c++) {
+ // For every output channel
+ for (int c = 0; c < n_output_channels; c++) {
+ // Read in the kernel
+ float w11 = inptr0[0], w12 = inptr0[kernel_col_stride], w13 = inptr0[kernel_col_stride*2];
+ float w21 = inptr1[0], w22 = inptr1[kernel_col_stride], w23 = inptr1[kernel_col_stride*2];
+ float w31 = inptr2[0], w32 = inptr2[kernel_col_stride], w33 = inptr2[kernel_col_stride*2];
+
+ // Progress input pointers
+ inptr0++;
+ inptr1++;
+ inptr2++;
+
+      // Compute the product W w. Note that only the middle two rows (2 and 3)
+      // need computing, because the first and last rows are copies of the
+      // first and last rows of w.
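+      //
+      // For Winograd F(2x2, 3x3) the kernel transform matrix is:
+      //       [ 1.0   0.0   0.0 ]
+      //   W = [ 0.5   0.5   0.5 ]
+      //       [ 0.5  -0.5   0.5 ]
+      //       [ 0.0   0.0   1.0 ]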
+ float Ww11 = w11, Ww12 = w12, Ww13 = w13;
+      float Ww21 = 0.5f*(w11 + w21 + w31), Ww22 = 0.5f*(w12 + w22 + w32), Ww23 = 0.5f*(w13 + w23 + w33);
+      float Ww31 = 0.5f*(w11 - w21 + w31), Ww32 = 0.5f*(w12 - w22 + w32), Ww33 = 0.5f*(w13 - w23 + w33);
+ float Ww41 = w31, Ww42 = w32, Ww43 = w33;
+
+      // Hence compute W w W.T; again, only the middle two columns need
+      // computing, since the first and last columns are copies of the first
+      // and last columns of the previous matrix.
+      float WwWT11 = Ww11, WwWT12 = 0.5f*(Ww11 + Ww12 + Ww13), WwWT13 = 0.5f*(Ww11 - Ww12 + Ww13), WwWT14 = Ww13;
+      float WwWT21 = Ww21, WwWT22 = 0.5f*(Ww21 + Ww22 + Ww23), WwWT23 = 0.5f*(Ww21 - Ww22 + Ww23), WwWT24 = Ww23;
+      float WwWT31 = Ww31, WwWT32 = 0.5f*(Ww31 + Ww32 + Ww33), WwWT33 = 0.5f*(Ww31 - Ww32 + Ww33), WwWT34 = Ww33;
+      float WwWT41 = Ww41, WwWT42 = 0.5f*(Ww41 + Ww42 + Ww43), WwWT43 = 0.5f*(Ww41 - Ww42 + Ww43), WwWT44 = Ww43;
+
+ // Store the computed weights
+ outptr0[0 * mstride] = WwWT11;
+ outptr0[1 * mstride] = WwWT12;
+ outptr0[2 * mstride] = WwWT13;
+ outptr0[3 * mstride] = WwWT14;
+
+ outptr4[0 * mstride] = WwWT21;
+ outptr4[1 * mstride] = WwWT22;
+ outptr4[2 * mstride] = WwWT23;
+ outptr4[3 * mstride] = WwWT24;
+
+ outptr8[0 * mstride] = WwWT31;
+ outptr8[1 * mstride] = WwWT32;
+ outptr8[2 * mstride] = WwWT33;
+ outptr8[3 * mstride] = WwWT34;
+
+ outptr12[0 * mstride] = WwWT41;
+ outptr12[1 * mstride] = WwWT42;
+ outptr12[2 * mstride] = WwWT43;
+ outptr12[3 * mstride] = WwWT44;
+
+ // Progress output pointers
+ outptr0++;
+ outptr4++;
+ outptr8++;
+ outptr12++;
+ }
+
+    // Advance the output pointers to complete the matrix row stride
+ outptr0 += matrix_row_stride - n_output_channels;
+ outptr4 += matrix_row_stride - n_output_channels;
+ outptr8 += matrix_row_stride - n_output_channels;
+ outptr12 += matrix_row_stride - n_output_channels;
+ }
+}
+}
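
Usage sketch (not part of this patch): a minimal, hypothetical call site for the
fp32 specialisation above. It assumes that KernelShape lives in namespace
winograd and exposes the n_input_channels and n_output_channels fields read by
execute(), that the kernel buffer holds 3 x 3 x input_channels x output_channels
floats in the HWIO order noted in the header, and that the caller packs the 16
transformed matrices contiguously.

    #include <vector>
    #include "kernel_2x2_3x3.hpp"

    void transform_weights(const winograd::KernelShape &shape,
                           const float *kernel_hwio)
    {
      const int n_in  = shape.n_input_channels;
      const int n_out = shape.n_output_channels;

      // One [n_in x n_out] matrix per Winograd-domain position; rows of a
      // matrix lie matrix_row_stride apart, whole matrices matrix_stride apart.
      const int matrix_row_stride = n_out;
      const int matrix_stride     = n_in * matrix_row_stride;
      std::vector<float> transformed(16 * matrix_stride);

      winograd::winograd2x2_3x3_gemm_kernel_transform_impl<float>::execute(
          shape, kernel_hwio, transformed.data(), matrix_stride, matrix_row_stride);
    }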