aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2018-01-23 09:36:04 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:45:00 +0000
commitd6ca478a7e410f8f529c2e505305b46d9fe21a9b (patch)
tree5c50c06e07f812890f127b1c4933996987f74f17 /src
parentd05dce46a14a7b67f322328ecd95bf96bdd30bae (diff)
downloadComputeLibrary-d6ca478a7e410f8f529c2e505305b46d9fe21a9b.tar.gz
COMPMID-784: Added support for biases in WinogradLayer.
1) Updated to the latest code from the RSH repo. 2) Moved winograd transforms into kernels. 3) Added support for biases Change-Id: I7f39f34a599b49d7d9b549cc10a4f4d4a8007ab8 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/117474 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/core/NEON/kernels/NEWinogradLayerKernel.cpp144
-rw-r--r--src/core/NEON/kernels/winograd/transforms/input_2x2_5x5_fp32.cpp458
-rw-r--r--src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp25
-rw-r--r--src/core/NEON/kernels/winograd/transforms/output_2x2_5x5_fp32.cpp242
-rw-r--r--src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp19
-rw-r--r--src/core/NEON/kernels/winograd/transforms/weights_2x2_5x5_fp32.cpp408
-rw-r--r--src/core/NEON/kernels/winograd/winograd_gemm.cpp13
-rw-r--r--src/core/NEON/kernels/winograd/winograd_layer.cpp4
-rw-r--r--src/runtime/NEON/functions/NEWinogradLayer.cpp41
9 files changed, 1309 insertions, 45 deletions
diff --git a/src/core/NEON/kernels/NEWinogradLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp
index ea48e1f32b..e2e4e40fe4 100644
--- a/src/core/NEON/kernels/NEWinogradLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEWinogradLayerKernel.cpp
@@ -55,7 +55,7 @@ public:
float *const output, /** Pointer to NHWC ordered output tensor, in the spatial domain. */
float *const winograd_output /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */
)
- : convolver(n_batches, n_input_channels, n_input_rows, n_input_cols, n_output_channels, same_padding, weights, weights_storage, input, winograd_input, output, winograd_output)
+ : convolver(n_batches, n_input_channels, n_input_rows, n_input_cols, n_output_channels, same_padding, weights, weights_storage, input, winograd_input, nullptr, output, winograd_output)
{
}
T convolver;
@@ -65,24 +65,6 @@ Winograd3x3F32::~Winograd3x3F32()
{
}
-void Winograd3x3F32::transform_output()
-{
- auto win = _pimpl->convolver.output_transform.get_window();
- _pimpl->convolver.output_transform.run(0, win);
-}
-
-void Winograd3x3F32::transform_input()
-{
- auto win = _pimpl->convolver.input_transform.get_window();
- _pimpl->convolver.input_transform.run(0, win);
-}
-
-void Winograd3x3F32::transform_weights()
-{
- auto win = _pimpl->convolver.weights_transform.get_window();
- _pimpl->convolver.weights_transform.run(0, win);
-}
-
Winograd3x3F32::Winograd3x3F32(
const int n_batches, /** Number of batches in the input and output tensors. */
const int n_input_channels, /** Number of feature maps in a batch of the input tensor. */
@@ -146,4 +128,128 @@ void NEWinogradLayerKernel::run(const Window &window, const ThreadInfo &info)
const size_t last_gemm = window.x().end();
_convolver->_pimpl->convolver.gemms.run(first_gemm, last_gemm);
}
+
+INEWinogradLayerTransformKernel::INEWinogradLayerTransformKernel()
+ : _convolver(nullptr)
+{
+}
+
+void INEWinogradLayerTransformKernel::configure(Winograd3x3F32 *convolver)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(convolver);
+ _convolver = convolver;
+}
+
+// Weights transform
+
+void NEWinogradLayerTransformWeightsKernel::configure(Winograd3x3F32 *convolver)
+{
+ INEWinogradLayerTransformKernel::configure(convolver);
+ Window win;
+ auto win_last = _convolver->_pimpl->convolver.weights_transform.get_window();
+ win.set(Window::DimX, Window::Dimension(0, win_last, 1));
+ INEKernel::configure(win);
+}
+
+void NEWinogradLayerTransformWeightsKernel::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ const size_t fst = window.x().start();
+ const size_t lst = window.x().end();
+ _convolver->_pimpl->convolver.weights_transform.run(fst, lst);
+}
+
+bool NEWinogradLayerTransformWeightsKernel::is_parallelisable() const
+{
+ return false;
+}
+
+// Input transform
+
+void NEWinogradLayerTransformInputKernel::configure(Winograd3x3F32 *convolver)
+{
+ INEWinogradLayerTransformKernel::configure(convolver);
+ Window win;
+ auto win_last = _convolver->_pimpl->convolver.input_transform.get_window();
+ win.set(Window::DimX, Window::Dimension(0, win_last, 1));
+ INEKernel::configure(win);
+}
+
+void NEWinogradLayerTransformInputKernel::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ const size_t fst = window.x().start();
+ const size_t lst = window.x().end();
+ _convolver->_pimpl->convolver.input_transform.run(fst, lst);
+}
+bool NEWinogradLayerTransformInputKernel::is_parallelisable() const
+{
+ return false;
+}
+
+// Output transform
+NEWinogradLayerTransformOutputKernel::NEWinogradLayerTransformOutputKernel()
+ : _biases(nullptr), _output_workspace(nullptr), _matrix_stride(0), _matrix_row_stride(0), _output(nullptr), _n_batches(0), _n_rows(0), _n_cols(0), _n_channels(0)
+{
+}
+
+void NEWinogradLayerTransformOutputKernel::configure(
+ const ITensor *biases,
+ const float *const output_workingspace,
+ const int matrix_stride,
+ float *const output,
+ const int n_batches,
+ const int n_rows,
+ const int n_cols,
+ const int n_channels)
+{
+ using WinogradBase = winograd::WinogradGEMM<2, 2, 3, 3>;
+ using OutputTransform = typename WinogradBase::template OutputTransform<float>;
+
+ _biases = biases;
+ _output_workspace = output_workingspace;
+ _matrix_stride = matrix_stride;
+ _matrix_row_stride = roundup(n_channels, WinogradBase::Convolution<float, float>::N_BLOCK);
+ _output = output;
+ _n_batches = n_batches;
+ _n_rows = n_rows;
+ _n_cols = n_cols;
+ _n_channels = n_channels;
+
+ // We don't have the biases buffer at this stage as it hasn't been allocated, we pass in nullptr OutputTransform is only used here to compute the window
+ OutputTransform output_transform(_output_workspace, _matrix_stride, _matrix_row_stride, nullptr, _output, _n_batches, _n_rows, _n_cols, _n_channels);
+ Window win;
+ auto win_last = output_transform.get_window();
+ win.set(Window::DimX, Window::Dimension(0, win_last, 1));
+ INEKernel::configure(win);
+}
+
+void NEWinogradLayerTransformOutputKernel::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_biases->buffer());
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_output_workspace);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_output);
+
+ using WinogradBase = winograd::WinogradGEMM<2, 2, 3, 3>;
+ using OutputTransform = typename WinogradBase::template OutputTransform<float>;
+
+ OutputTransform output_transform(_output_workspace, _matrix_stride, _matrix_row_stride,
+ reinterpret_cast<float *>(_biases->buffer()), _output,
+ _n_batches, _n_rows, _n_cols, _n_channels);
+
+ // The code below cannot be moved to configure because biases hasn't been allocated at that point
+ const size_t fst = window.x().start();
+ const size_t lst = window.x().end();
+ output_transform.run(fst, lst);
+}
+
+bool NEWinogradLayerTransformOutputKernel::is_parallelisable() const
+{
+ return false;
+}
+
} // namespace arm_compute
diff --git a/src/core/NEON/kernels/winograd/transforms/input_2x2_5x5_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/input_2x2_5x5_fp32.cpp
new file mode 100644
index 0000000000..a6ebca1bce
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/transforms/input_2x2_5x5_fp32.cpp
@@ -0,0 +1,458 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "transforms/input.hpp"
+#include "winograd_gemm.hpp"
+#include "arm.hpp"
+
+namespace winograd
+{
+
+using Transform = WinogradGEMM<2, 2, 5, 5>::InputTransform<float>;
+
+template <>
+template <>
+int Transform::ops_performed(const Tensor4DShape &input_shape)
+{
+ return 0; // TODO
+}
+
+/*****************************************************************************
+* F(2x2, 5x5) implies the use of a 6x6 input tile.
+*
+* Build an array of the specialised methods that deal with each of the
+* different padding combinations which may be required. These padding
+* constraints are the space:
+*
+* Padding top in {0, 1}
+* Padding left in {0, 1}
+* Padding bottom in {0, 1, 2, 3, 4}
+* Padding right in {0, 1, 2, 3, 4}
+*/
+template <>
+template <>
+template <int pad_top, int pad_left, int pad_bottom, int pad_right>
+void Transform::process_tile(
+ int n_channels,
+ const float* const input_base,
+ const int input_row_stride,
+ const int input_col_stride,
+ float* const matrix_base,
+ const int matrix_stride
+)
+{
+ constexpr int cells_i = 6 - pad_bottom;
+ constexpr int cells_j = 6 - pad_right;
+
+ float *outptr = matrix_base;
+
+ // Get pointers into the input tile
+ const float *x_ptrs[6][6];
+ for (int i = pad_top, xi = 0; i < cells_i; i++, xi++)
+ {
+ // Get a pointer into the row
+ const float* const row_ptr = input_base + xi*input_row_stride;
+
+ for (int j = pad_left, xj = 0; j < cells_j; j++, xj++)
+ {
+ x_ptrs[i][j] = row_ptr + xj*input_col_stride;
+ }
+ }
+
+ // Matrices used/computed in this kernel.
+ float x[6][6], XTx[6][6], U[6][6];
+ for (int i = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++)
+ {
+ x[i][j] = XTx[i][j] = 0.0f;
+ }
+ }
+
+ // Perform the Winograd input transformation for each channel in the input
+ // tensor.
+ int channels_remaining = n_channels;
+#ifdef __aarch64__
+ for (; channels_remaining >= 4; channels_remaining -= 4)
+ {
+ // Matrices used/computed in this kernel
+ float32x4_t x[6][6], XTx[6][6], U[6][6];
+ for (int i = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++)
+ {
+ x[i][j] = vdupq_n_f32(0.0f);
+ XTx[i][j] = vdupq_n_f32(0.0f);
+ }
+ }
+
+ // Read a 6x6 tile in the Winograd domain
+ for (int i = pad_top; i < cells_i; i++)
+ {
+ for (int j = pad_left; j < cells_j; j++)
+ {
+ x[i][j] = vld1q_f32(x_ptrs[i][j]);
+ x_ptrs[i][j] += 4;
+ }
+ }
+
+ // Compute XT . x
+ for (int j = pad_left; j < cells_j; j++)
+ {
+ // XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j];
+ XTx[0][j] = vmlsq_n_f32(vmlaq_n_f32(x[4][j], x[0][j], 4.0f), x[2][j], 5.0f);
+
+ // XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j];
+ XTx[1][j] = vmlsq_n_f32(vaddq_f32(x[3][j], x[4][j]), vaddq_f32(x[1][j], x[2][j]), 4.0f);
+
+ // XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j];
+ XTx[2][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[3][j]), vsubq_f32(x[1][j], x[2][j]), 4.0f);
+
+ // XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j];
+ XTx[3][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[2][j]), vsubq_f32(x[3][j], x[1][j]), 2.0f);
+
+ // XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j];
+ XTx[4][j] = vmlaq_n_f32(vsubq_f32(x[4][j], x[2][j]), vsubq_f32(x[1][j], x[3][j]), 2.0f);
+
+ // XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j];
+ XTx[5][j] = vmlsq_n_f32(vmlaq_n_f32(x[5][j], x[1][j], 4.0f), x[3][j], 5.0f);
+ }
+
+ // Compute U = XT . x . X
+ for (int i = 0; i < 6; i++)
+ {
+ // U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4];
+ U[i][0] = vmlsq_n_f32(vmlaq_n_f32(XTx[i][4], XTx[i][0], 4.0f), XTx[i][2], 5.0f);
+
+ // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4];
+ U[i][1] = vmlsq_n_f32(vaddq_f32(XTx[i][3], XTx[i][4]), vaddq_f32(XTx[i][1], XTx[i][2]), 4.0f);
+
+ // U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4];
+ U[i][2] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][3]), vsubq_f32(XTx[i][1], XTx[i][2]), 4.0f);
+
+ // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4];
+ U[i][3] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][2]), vsubq_f32(XTx[i][3], XTx[i][1]), 2.0f);
+
+ // U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4];
+ U[i][4] = vmlaq_n_f32(vsubq_f32(XTx[i][4], XTx[i][2]), vsubq_f32(XTx[i][1], XTx[i][3]), 2.0f);
+
+ // U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5];
+ U[i][5] = vmlsq_n_f32(vmlaq_n_f32(XTx[i][5], XTx[i][1], 4.0f), XTx[i][3], 5.0f);
+ }
+
+ // Store the transformed matrix
+ for (int i = 0, m = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++, m++)
+ {
+ vst1q_f32(outptr + m*matrix_stride, U[i][j]);
+ }
+ }
+ outptr += 4;
+ }
+#endif // __aarch64__
+#ifdef __arm_any__
+ for (; channels_remaining >= 2; channels_remaining -= 2)
+ {
+ // Matrices used/computed in this kernel
+ float32x2_t x[6][6], XTx[6][6], U[6][6];
+ for (int i = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++)
+ {
+ x[i][j] = vdup_n_f32(0.0f);
+ XTx[i][j] = vdup_n_f32(0.0f);
+ }
+ }
+
+ // Read a 6x6 tile in the Winograd domain
+ for (int i = pad_top; i < cells_i; i++)
+ {
+ for (int j = pad_left; j < cells_j; j++)
+ {
+ x[i][j] = vld1_f32(x_ptrs[i][j]);
+ x_ptrs[i][j] += 2;
+ }
+ }
+
+ // Compute XT . x
+ for (int j = pad_left; j < cells_j; j++)
+ {
+ // XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j];
+ XTx[0][j] = vmls_n_f32(vmla_n_f32(x[4][j], x[0][j], 4.0f), x[2][j], 5.0f);
+
+ // XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j];
+ XTx[1][j] = vmls_n_f32(vadd_f32(x[3][j], x[4][j]), vadd_f32(x[1][j], x[2][j]), 4.0f);
+
+ // XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j];
+ XTx[2][j] = vmla_n_f32(vsub_f32(x[4][j], x[3][j]), vsub_f32(x[1][j], x[2][j]), 4.0f);
+
+ // XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j];
+ XTx[3][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[3][j], x[1][j]), 2.0f);
+
+ // XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j];
+ XTx[4][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[1][j], x[3][j]), 2.0f);
+
+ // XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j];
+ XTx[5][j] = vmls_n_f32(vmla_n_f32(x[5][j], x[1][j], 4.0f), x[3][j], 5.0f);
+ }
+
+ // Compute U = XT . x . X
+ for (int i = 0; i < 6; i++)
+ {
+ // U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4];
+ U[i][0] = vmls_n_f32(vmla_n_f32(XTx[i][4], XTx[i][0], 4.0f), XTx[i][2], 5.0f);
+
+ // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4];
+ U[i][1] = vmls_n_f32(vadd_f32(XTx[i][3], XTx[i][4]), vadd_f32(XTx[i][1], XTx[i][2]), 4.0f);
+
+ // U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4];
+ U[i][2] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][3]), vsub_f32(XTx[i][1], XTx[i][2]), 4.0f);
+
+ // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4];
+ U[i][3] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][3], XTx[i][1]), 2.0f);
+
+ // U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4];
+ U[i][4] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][1], XTx[i][3]), 2.0f);
+
+ // U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5];
+ U[i][5] = vmls_n_f32(vmla_n_f32(XTx[i][5], XTx[i][1], 4.0f), XTx[i][3], 5.0f);
+ }
+
+ // Store the transformed matrix
+ for (int i = 0, m = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++, m++)
+ {
+ vst1_f32(outptr + m*matrix_stride, U[i][j]);
+ }
+ }
+ outptr += 2;
+ }
+#endif // __arm_any__
+ for (; channels_remaining; channels_remaining--)
+ {
+ // Load x
+ for (int i = pad_top; i < cells_i; i++)
+ {
+ for (int j = pad_left; j < cells_j; j++)
+ {
+ x[i][j] = *(x_ptrs[i][j]++);
+ }
+ }
+
+ // Compute XT . x
+ for (int j = pad_left; j < cells_j; j++)
+ {
+ XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j];
+ XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j];
+ XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j];
+ XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j];
+ XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j];
+ XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j];
+ }
+
+ // Compute U = XT . x . X
+ for (int i = 0; i < 6; i++)
+ {
+ U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4];
+ U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4];
+ U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4];
+ U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4];
+ U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4];
+ U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5];
+ }
+
+ // Store the transformed matrix
+ for (int i = 0, m = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++, m++)
+ {
+ *(outptr + m*matrix_stride) = U[i][j];
+ }
+ }
+ outptr++;
+ }
+}
+
+template <>
+template <>
+const Transform::TileFn Transform::tile_fns[2][2][max_pad_bottom][max_pad_right] =
+{
+ {
+ {
+ {
+ Transform::template process_tile<0, 0, 0, 0>, // No padding
+ Transform::template process_tile<0, 0, 0, 1>, // Right
+ Transform::template process_tile<0, 0, 0, 2>, // " "
+ Transform::template process_tile<0, 0, 0, 3>, // " "
+ Transform::template process_tile<0, 0, 0, 4>, // " "
+ },
+ {
+ Transform::template process_tile<0, 0, 1, 0>, // Bottom
+ Transform::template process_tile<0, 0, 1, 1>, // Bottom right
+ Transform::template process_tile<0, 0, 1, 2>, // " "
+ Transform::template process_tile<0, 0, 1, 3>, // " "
+ Transform::template process_tile<0, 0, 1, 4>, // " "
+ },
+ {
+ Transform::template process_tile<0, 0, 2, 0>, // Bottom
+ Transform::template process_tile<0, 0, 2, 1>, // Bottom right
+ Transform::template process_tile<0, 0, 2, 2>, // " "
+ Transform::template process_tile<0, 0, 2, 3>, // " "
+ Transform::template process_tile<0, 0, 2, 4>, // " "
+ },
+ {
+ Transform::template process_tile<0, 0, 3, 0>, // Bottom
+ Transform::template process_tile<0, 0, 3, 1>, // Bottom right
+ Transform::template process_tile<0, 0, 3, 2>, // " "
+ Transform::template process_tile<0, 0, 3, 3>, // " "
+ Transform::template process_tile<0, 0, 3, 4>, // " "
+ },
+ {
+ Transform::template process_tile<0, 0, 4, 0>, // Bottom
+ Transform::template process_tile<0, 0, 4, 1>, // Bottom right
+ Transform::template process_tile<0, 0, 4, 2>, // " "
+ Transform::template process_tile<0, 0, 4, 3>, // " "
+ Transform::template process_tile<0, 0, 4, 4>, // " "
+ }
+ },
+ {
+ {
+ Transform::template process_tile<0, 1, 0, 0>, // Left
+ Transform::template process_tile<0, 1, 0, 1>,
+ Transform::template process_tile<0, 1, 0, 2>,
+ Transform::template process_tile<0, 1, 0, 3>,
+ Transform::template process_tile<0, 1, 0, 4>,
+ },
+ {
+ Transform::template process_tile<0, 1, 1, 0>, // Bottom left
+ Transform::template process_tile<0, 1, 1, 1>,
+ Transform::template process_tile<0, 1, 1, 2>,
+ Transform::template process_tile<0, 1, 1, 3>,
+ Transform::template process_tile<0, 1, 1, 4>,
+ },
+ {
+ Transform::template process_tile<0, 1, 2, 0>, // " "
+ Transform::template process_tile<0, 1, 2, 1>,
+ Transform::template process_tile<0, 1, 2, 2>,
+ Transform::template process_tile<0, 1, 2, 3>,
+ Transform::template process_tile<0, 1, 2, 4>,
+ },
+ {
+ Transform::template process_tile<0, 1, 3, 0>, // " "
+ Transform::template process_tile<0, 1, 3, 1>,
+ Transform::template process_tile<0, 1, 3, 2>,
+ Transform::template process_tile<0, 1, 3, 3>,
+ Transform::template process_tile<0, 1, 3, 4>,
+ },
+ {
+ Transform::template process_tile<0, 1, 4, 0>, // " "
+ Transform::template process_tile<0, 1, 4, 1>,
+ Transform::template process_tile<0, 1, 4, 2>,
+ Transform::template process_tile<0, 1, 4, 3>,
+ Transform::template process_tile<0, 1, 4, 4>,
+ }
+ }
+ },
+ {
+ {
+ {
+ Transform::template process_tile<1, 0, 0, 0>, // Top
+ Transform::template process_tile<1, 0, 0, 1>, // Top right
+ Transform::template process_tile<1, 0, 0, 2>, // " "
+ Transform::template process_tile<1, 0, 0, 3>, // " "
+ Transform::template process_tile<1, 0, 0, 4>, // " "
+ },
+ {
+ Transform::template process_tile<1, 0, 1, 0>,
+ Transform::template process_tile<1, 0, 1, 1>,
+ Transform::template process_tile<1, 0, 1, 2>,
+ Transform::template process_tile<1, 0, 1, 3>,
+ Transform::template process_tile<1, 0, 1, 4>,
+ },
+ {
+ Transform::template process_tile<1, 0, 2, 0>,
+ Transform::template process_tile<1, 0, 2, 1>,
+ Transform::template process_tile<1, 0, 2, 2>,
+ Transform::template process_tile<1, 0, 2, 3>,
+ Transform::template process_tile<1, 0, 2, 4>,
+ },
+ {
+ Transform::template process_tile<1, 0, 3, 0>,
+ Transform::template process_tile<1, 0, 3, 1>,
+ Transform::template process_tile<1, 0, 3, 2>,
+ Transform::template process_tile<1, 0, 3, 3>,
+ Transform::template process_tile<1, 0, 3, 4>,
+ },
+ {
+ Transform::template process_tile<1, 0, 4, 0>,
+ Transform::template process_tile<1, 0, 4, 1>,
+ Transform::template process_tile<1, 0, 4, 2>,
+ Transform::template process_tile<1, 0, 4, 3>,
+ Transform::template process_tile<1, 0, 4, 4>,
+ },
+ },
+ {
+ {
+ Transform::template process_tile<1, 1, 0, 0>, // Top left
+ Transform::template process_tile<1, 1, 0, 1>,
+ Transform::template process_tile<1, 1, 0, 2>,
+ Transform::template process_tile<1, 1, 0, 3>,
+ Transform::template process_tile<1, 1, 0, 4>,
+ },
+ {
+ Transform::template process_tile<1, 1, 1, 0>,
+ Transform::template process_tile<1, 1, 1, 1>,
+ Transform::template process_tile<1, 1, 1, 2>,
+ Transform::template process_tile<1, 1, 1, 3>,
+ Transform::template process_tile<1, 1, 1, 4>,
+ },
+ {
+ Transform::template process_tile<1, 1, 2, 0>,
+ Transform::template process_tile<1, 1, 2, 1>,
+ Transform::template process_tile<1, 1, 2, 2>,
+ Transform::template process_tile<1, 1, 2, 3>,
+ Transform::template process_tile<1, 1, 2, 4>,
+ },
+ {
+ Transform::template process_tile<1, 1, 3, 0>,
+ Transform::template process_tile<1, 1, 3, 1>,
+ Transform::template process_tile<1, 1, 3, 2>,
+ Transform::template process_tile<1, 1, 3, 3>,
+ Transform::template process_tile<1, 1, 3, 4>,
+ },
+ {
+ Transform::template process_tile<1, 1, 4, 0>,
+ Transform::template process_tile<1, 1, 4, 1>,
+ Transform::template process_tile<1, 1, 4, 2>,
+ Transform::template process_tile<1, 1, 4, 3>,
+ Transform::template process_tile<1, 1, 4, 4>,
+ }
+ }
+ }
+};
+
+template struct WinogradGEMM<2, 2, 5, 5>::InputTransform<float>;
+} // namespace winograd
diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp
index e7907d18c0..58db7d2ecd 100644
--- a/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp
+++ b/src/core/NEON/kernels/winograd/transforms/output_2x2_3x3_fp32.cpp
@@ -65,6 +65,7 @@ void Transform::process_tile(
const int n_channels,
const float* const matrix_base,
const int matrix_stride,
+ const float* const biases,
float* const output,
const int output_row_stride,
const int output_col_stride
@@ -83,6 +84,7 @@ void Transform::process_tile(
}
}
const float *inptr = matrix_base;
+ const float *bptr = biases;
// For each channel of the output
int channels_remaining = n_channels;
@@ -90,7 +92,7 @@ void Transform::process_tile(
for (; channels_remaining >= 4; channels_remaining -= 4)
{
// Matrices used and computed during this transform
- float32x4_t F[4][4], FZ[4][2], f[2][2];
+ float32x4_t F[4][4], FZ[4][2], f[2][2], b;
// Read a 4x4 tile in the Winograd domain
for (int i = 0, m = 0; i < 4; i++)
@@ -122,12 +124,16 @@ void Transform::process_tile(
f[1][j] = vsubq_f32(vsubq_f32(FZ[1][j], FZ[2][j]), FZ[3][j]);
}
+ // Load the bias vector
+ b = vld1q_f32(bptr);
+ bptr += 4;
+
// Write out the output tile
for (int i = 0; i < cells_i; i++)
{
for (int j = 0; j < cells_j; j++)
{
- vst1q_f32(outptrs[i][j], f[i][j]);
+ vst1q_f32(outptrs[i][j], vaddq_f32(f[i][j], b));
outptrs[i][j] += 4;
}
}
@@ -137,7 +143,7 @@ void Transform::process_tile(
for (; channels_remaining >= 2; channels_remaining -= 2)
{
// Matrices used and computed during this transform
- float32x2_t F[4][4], FZ[4][2], f[2][2];
+ float32x2_t F[4][4], FZ[4][2], f[2][2], b;
// Read a 4x4 tile in the Winograd domain
for (int i = 0, m = 0; i < 4; i++)
@@ -169,12 +175,16 @@ void Transform::process_tile(
f[1][j] = vsub_f32(vsub_f32(FZ[1][j], FZ[2][j]), FZ[3][j]);
}
+ // Load the bias vector
+ b = vld1_f32(bptr);
+ bptr += 2;
+
// Write out the output tile
for (int i = 0; i < cells_i; i++)
{
for (int j = 0; j < cells_j; j++)
{
- vst1_f32(outptrs[i][j], f[i][j]);
+ vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b));
outptrs[i][j] += 2;
}
}
@@ -183,7 +193,7 @@ void Transform::process_tile(
for (; channels_remaining; channels_remaining--)
{
// Matrices used and computed during this transform
- float F[4][4], FZ[4][2], f[2][2];
+ float F[4][4], FZ[4][2], f[2][2], b;
// Read a 4x4 tile in the Winograd domain
for (int i = 0, m = 0; i < 4; i++)
@@ -209,12 +219,15 @@ void Transform::process_tile(
f[1][j] = FZ[1][j] - FZ[2][j] - FZ[3][j];
}
+ // Load the bias
+ b = *(bptr++);
+
// Write out the output tile
for (int i = 0; i < cells_i; i++)
{
for (int j = 0; j < cells_j; j++)
{
- *(outptrs[i][j]++) = f[i][j];
+ *(outptrs[i][j]++) = f[i][j] + b;
}
}
}
diff --git a/src/core/NEON/kernels/winograd/transforms/output_2x2_5x5_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_2x2_5x5_fp32.cpp
new file mode 100644
index 0000000000..bfd670090a
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/transforms/output_2x2_5x5_fp32.cpp
@@ -0,0 +1,242 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "transforms/output.hpp"
+#include "winograd_gemm.hpp"
+#include "arm.hpp"
+
+namespace winograd
+{
+
+using Transform = WinogradGEMM<2, 2, 5, 5>::OutputTransform<float>;
+
+template <>
+template <>
+int Transform::ops_performed(const Tensor4DShape &shape)
+{
+ return 0; // TODO
+}
+
+/* F(2x2, 5x5) constructs 2x2 output tiles from a 5x5 convolution. Since we use
+ * enough tiles to cover the output space each output tile may contain 0 or 1
+ * padded values to the right and bottom columns or rows of the tile, e.g.:
+ *
+ * ___ ___
+ * | | | X|
+ * |___| |__X|
+ *
+ * ___ ___
+ * | | | X|
+ * |X_X| |X_X|
+ *
+ *
+ * We provide a specialised output transform for each of these instances.
+ * Consequently we below construct an array of the various padding options, the
+ * array contains pointers to the specific implementations.
+ */
+template <>
+template <>
+template <int pad_bottom, int pad_right>
+void Transform::process_tile(
+ const int n_channels,
+ const float* const matrix_base,
+ const int matrix_stride,
+ const float* const biases,
+ float* const output,
+ const int output_row_stride,
+ const int output_col_stride
+)
+{
+ constexpr int cells_i = 2 - pad_bottom;
+ constexpr int cells_j = 2 - pad_right;
+
+ // Construct a map to the output cells
+ float *outptrs[cells_i][cells_j];
+ for (int i = 0; i < cells_i; i++)
+ {
+ for (int j = 0; j < cells_j; j++)
+ {
+ outptrs[i][j] = output + i*output_row_stride + j*output_col_stride;
+ }
+ }
+ const float *inptr = matrix_base;
+ const float *bptr = biases;
+
+ // For each channel of the output
+ int channels_remaining = n_channels;
+#ifdef __aarch64__
+ for (; channels_remaining >= 4; channels_remaining -= 4)
+ {
+ // Matrices used and computed during this transform
+ float32x4_t F[6][6], FZ[6][2], f[2][2], b;
+
+ // Read a 6x6 tile in the Winograd domain
+ for (int i = 0, m = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++, m++)
+ {
+ F[i][j] = vld1q_f32(inptr + m*matrix_stride);
+ }
+ }
+ inptr += 4;
+
+ // Compute the matrix F Z
+ for (int i = 0; i < 6; i++)
+ {
+ // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4];
+ FZ[i][0] = vaddq_f32(vaddq_f32(vaddq_f32(F[i][0], F[i][1]), vaddq_f32(F[i][2], F[i][3])), F[i][4]);
+
+ // FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4] + 1*F[i][5];
+ FZ[i][1] = vaddq_f32(vmlaq_n_f32(vsubq_f32(F[i][1], F[i][2]), vsubq_f32(F[i][3], F[i][4]), 2.0f), F[i][5]);
+ }
+
+ // Compute the output tile f = ZT F Z
+ for (int j = 0; j < 2; j++)
+ {
+ // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j];
+ f[0][j] = vaddq_f32(vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), vaddq_f32(FZ[2][j], FZ[3][j])), FZ[4][j]);
+
+ // f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j] + 1*FZ[5][j];
+ f[1][j] = vaddq_f32(vmlaq_n_f32(vsubq_f32(FZ[1][j], FZ[2][j]), vsubq_f32(FZ[3][j], FZ[4][j]), 2.0f), FZ[5][j]);
+ }
+
+ // Write out the output tile
+ b = vld1q_f32(bptr);
+ bptr += 4;
+ for (int i = 0; i < cells_i; i++)
+ {
+ for (int j = 0; j < cells_j; j++)
+ {
+ vst1q_f32(outptrs[i][j], vaddq_f32(f[i][j], b));
+ outptrs[i][j] += 4;
+ }
+ }
+ }
+#endif // __aarch64__
+#ifdef __arm_any__
+ for (; channels_remaining >= 2; channels_remaining -= 2)
+ {
+ // Matrices used and computed during this transform
+ float32x2_t F[6][6], FZ[6][2], f[2][2], b;
+
+ // Read a 6x6 tile in the Winograd domain
+ for (int i = 0, m = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++, m++)
+ {
+ F[i][j] = vld1_f32(inptr + m*matrix_stride);
+ }
+ }
+ inptr += 2;
+
+ // Compute the matrix F Z
+ for (int i = 0; i < 6; i++)
+ {
+ // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4];
+ FZ[i][0] = vadd_f32(vadd_f32(vadd_f32(F[i][0], F[i][1]), vadd_f32(F[i][2], F[i][3])), F[i][4]);
+
+ // FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4] + 1*F[i][5];
+ FZ[i][1] = vadd_f32(vmla_n_f32(vsub_f32(F[i][1], F[i][2]), vsub_f32(F[i][3], F[i][4]), 2.0f), F[i][5]);
+ }
+
+ // Compute the output tile f = ZT F Z
+ for (int j = 0; j < 2; j++)
+ {
+ // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j];
+ f[0][j] = vadd_f32(vadd_f32(vadd_f32(FZ[0][j], FZ[1][j]), vadd_f32(FZ[2][j], FZ[3][j])), FZ[4][j]);
+
+ // f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j] + 1*FZ[5][j];
+ f[1][j] = vadd_f32(vmla_n_f32(vsub_f32(FZ[1][j], FZ[2][j]), vsub_f32(FZ[3][j], FZ[4][j]), 2.0f), FZ[5][j]);
+ }
+
+ // Write out the output tile
+ b = vld1_f32(bptr);
+ bptr += 2;
+ for (int i = 0; i < cells_i; i++)
+ {
+ for (int j = 0; j < cells_j; j++)
+ {
+ vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b));
+ outptrs[i][j] += 2;
+ }
+ }
+ }
+#endif // __arm_any__
+ for (; channels_remaining; channels_remaining--)
+ {
+ // Matrices used and computed during this transform
+ float F[6][6], FZ[6][2], f[2][2], b;
+
+ // Read a 6x6 tile in the Winograd domain
+ for (int i = 0, m = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++, m++)
+ {
+ F[i][j] = *(inptr + m*matrix_stride);
+ }
+ }
+ inptr++;
+
+ // Compute the matrix F Z
+ for (int i = 0; i < 6; i++)
+ {
+ FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4];
+ FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4] + 1*F[i][5];
+ }
+
+ // Compute the output tile f = ZT F Z
+ for (int j = 0; j < 2; j++)
+ {
+ f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j];
+ f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j] + 1*FZ[5][j];
+ }
+
+ // Write out the output tile
+ b = *(bptr++);
+ for (int i = 0; i < cells_i; i++)
+ {
+ for (int j = 0; j < cells_j; j++)
+ {
+ *(outptrs[i][j]++) = f[i][j] + b;
+ }
+ }
+ }
+}
+
+template <>
+template <>
+const Transform::TileFn Transform::tile_fns[max_pad_bottom][max_pad_right] =
+{
+ {
+ Transform::template process_tile<0, 0>, // No padding
+ Transform::template process_tile<0, 1>, // Right padding
+ },
+ {
+ Transform::template process_tile<1, 0>, // Bottom padding
+ Transform::template process_tile<1, 1>, // Bottom and right padding
+ }
+};
+
+template struct WinogradGEMM<2, 2, 5, 5>::OutputTransform<float>;
+} // namespace winograd
diff --git a/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp
index 483e5c110b..45210d7976 100644
--- a/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp
+++ b/src/core/NEON/kernels/winograd/transforms/output_4x4_3x3_fp32.cpp
@@ -82,6 +82,7 @@ void Transform::process_tile(
const int n_channels,
const float* const matrix_base,
const int matrix_stride,
+ const float* const biases,
float* const output,
const int output_row_stride,
const int output_col_stride
@@ -100,6 +101,7 @@ void Transform::process_tile(
}
}
const float *inptr = matrix_base;
+ const float *bptr = biases;
// For each channel of the output
int channels_remaining = n_channels;
@@ -107,7 +109,7 @@ void Transform::process_tile(
for (; channels_remaining >= 4; channels_remaining -= 4)
{
// Matrices used and computed during this transform
- float32x4_t F[6][6], FZ[6][4], f[4][4];
+ float32x4_t F[6][6], FZ[6][4], f[4][4], b;
// Read a 6x6 tile in the Winograd domain
for (int i = 0, m = 0; i < 6; i++)
@@ -152,11 +154,13 @@ void Transform::process_tile(
}
// Write out the output tile
+ b = vld1q_f32(bptr);
+ bptr += 4;
for (int i = 0; i < cells_i; i++)
{
for (int j = 0; j < cells_j; j++)
{
- vst1q_f32(outptrs[i][j], f[i][j]);
+ vst1q_f32(outptrs[i][j], vaddq_f32(f[i][j], b));
outptrs[i][j] += 4;
}
}
@@ -166,7 +170,7 @@ void Transform::process_tile(
for (; channels_remaining >= 2; channels_remaining -= 2)
{
// Matrices used and computed during this transform
- float32x2_t F[6][6], FZ[6][4], f[4][4];
+ float32x2_t F[6][6], FZ[6][4], f[4][4], b;
// Read a 6x6 tile in the Winograd domain
for (int i = 0, m = 0; i < 6; i++)
@@ -211,11 +215,13 @@ void Transform::process_tile(
}
// Write out the output tile
+ b = vld1_f32(bptr);
+ bptr += 2;
for (int i = 0; i < cells_i; i++)
{
for (int j = 0; j < cells_j; j++)
{
- vst1_f32(outptrs[i][j], f[i][j]);
+ vst1_f32(outptrs[i][j], vadd_f32(f[i][j], b));
outptrs[i][j] += 2;
}
}
@@ -224,7 +230,7 @@ void Transform::process_tile(
for (; channels_remaining; channels_remaining--)
{
// Matrices used and computed during this transform
- float F[6][6], FZ[6][4], f[4][4];
+ float F[6][6], FZ[6][4], f[4][4], b;
// Read a 6x6 tile in the Winograd domain
for (int i = 0, m = 0; i < 6; i++)
@@ -255,11 +261,12 @@ void Transform::process_tile(
}
// Write out the output tile
+ b = *(bptr++);
for (int i = 0; i < cells_i; i++)
{
for (int j = 0; j < cells_j; j++)
{
- *(outptrs[i][j]++) = f[i][j];
+ *(outptrs[i][j]++) = f[i][j] + b;
}
}
}
diff --git a/src/core/NEON/kernels/winograd/transforms/weights_2x2_5x5_fp32.cpp b/src/core/NEON/kernels/winograd/transforms/weights_2x2_5x5_fp32.cpp
new file mode 100644
index 0000000000..acf6b913f8
--- /dev/null
+++ b/src/core/NEON/kernels/winograd/transforms/weights_2x2_5x5_fp32.cpp
@@ -0,0 +1,408 @@
+/*
+ * Copyright (c) 2017 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm.hpp"
+#include "winograd_gemm.hpp"
+#include "transforms/kernel.hpp"
+
+namespace winograd
+{
+ template <>
+ template <>
+ void WinogradGEMM<2, 2, 5, 5>::WeightsTransform<float>::execute(
+ const int n_output_channels,
+ const int n_input_channels,
+ const float* const input,
+ float* const output,
+ const int matrix_stride,
+ const int matrix_row_stride
+ )
+ {
+ // Get pointers to each cell of the weight tensor
+ const auto weight_col_stride = n_input_channels * n_output_channels;
+ const auto weight_row_stride = 5 * weight_col_stride;
+ const float *inptrs[5][5];
+ for (int i = 0; i < 5; i++)
+ {
+ for (int j = 0; j < 5; j++)
+ {
+ inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride;
+ }
+ }
+
+ // For each input channel
+ for (int ic = 0; ic < n_input_channels; ic++)
+ {
+ float *outptr = output + ic * matrix_row_stride;
+
+ // For each output channel
+ int channels_remaining = n_output_channels;
+#ifdef __aarch64__
+ for (; channels_remaining >= 4; channels_remaining -= 4)
+ {
+ // Matrices used and computed in this kernel
+ float32x4_t w[5][5], Ww[6][5], V[6][6];
+
+ // Read weights
+ for (int i = 0; i < 5; i++)
+ {
+ for (int j = 0; j < 5; j++)
+ {
+ w[i][j] = vld1q_f32(inptrs[i][j]);
+ inptrs[i][j] += 4;
+ }
+ }
+
+ // Compute the matrix W w
+ for (int j = 0; j < 5; j++)
+ {
+ // Ww[0][j] = w[0][j]/4.0f;
+ Ww[0][j] = vmulq_n_f32(w[0][j], 1.0f/4.0f);
+
+ // Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f;
+ Ww[1][j] = vmulq_n_f32(
+ vaddq_f32(
+ vaddq_f32(
+ vaddq_f32(w[1][j], w[0][j]),
+ vaddq_f32(w[3][j], w[2][j])
+ ),
+ w[4][j]
+ ),
+ -1.0f/6.0f
+ );
+
+ // Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f;
+ // Ww[2][j] = ((w[1][j] - w[0][j]) + (w[3][j] - w[2][j]) - w[4][j])/6.0f;
+ Ww[2][j] = vmulq_n_f32(
+ vsubq_f32(
+ vaddq_f32(
+ vsubq_f32(w[1][j], w[0][j]),
+ vsubq_f32(w[3][j], w[2][j])
+ ),
+ w[4][j]
+ ),
+ 1.0f/6.0f
+ );
+
+ // Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f;
+ Ww[3][j] = vmulq_n_f32(
+ vmlaq_n_f32(
+ vaddq_f32(
+ vaddq_f32(vmulq_n_f32(w[0][j], 1.0f/8.0f), vmulq_n_f32(w[1][j], 1.0f/4.0f)),
+ vaddq_f32(vmulq_n_f32(w[2][j], 1.0f/2.0f), w[3][j])
+ ),
+ w[4][j], 2.0f
+ ),
+ 1.0f/3.0f
+ );
+
+ // Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f;
+ Ww[4][j] = vmulq_n_f32(
+ vmlaq_n_f32(
+ vaddq_f32(
+ vsubq_f32(vmulq_n_f32(w[0][j], 1.0f/8.0f), vmulq_n_f32(w[1][j], 1.0f/4.0f)),
+ vsubq_f32(vmulq_n_f32(w[2][j], 1.0f/2.0f), w[3][j])
+ ),
+ w[4][j], 2.0f
+ ),
+ 1.0f/3.0f
+ );
+
+ // Ww[5][j] = w[4][j];
+ Ww[5][j] = w[4][j];
+ }
+
+ // Compute V = W w WT
+ for (int i = 0; i < 6; i++)
+ {
+ // V[i][0] = Ww[i][0]/4.0f;
+ V[i][0] = vmulq_n_f32(Ww[i][0], 1.0f/4.0f);
+
+ // V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f;
+ V[i][1] = vmulq_n_f32(
+ vaddq_f32(
+ vaddq_f32(
+ vaddq_f32(Ww[i][1], Ww[i][0]),
+ vaddq_f32(Ww[i][3], Ww[i][2])
+ ),
+ Ww[i][4]
+ ),
+ -1.0f/6.0f
+ );
+
+ // V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f;
+ // V[i][2] = ((Ww[i][1] - Ww[i][0]) + (Ww[i][3] - Ww[i][2]) - Ww[i][4])/6.0f;
+ V[i][2] = vmulq_n_f32(
+ vsubq_f32(
+ vaddq_f32(
+ vsubq_f32(Ww[i][1], Ww[i][0]),
+ vsubq_f32(Ww[i][3], Ww[i][2])
+ ),
+ Ww[i][4]
+ ),
+ 1.0f/6.0f
+ );
+
+ // V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f;
+ V[i][3] = vmulq_n_f32(
+ vmlaq_n_f32(
+ vaddq_f32(
+ vaddq_f32(vmulq_n_f32(Ww[i][0], 1.0f/8.0f), vmulq_n_f32(Ww[i][1], 1.0f/4.0f)),
+ vaddq_f32(vmulq_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3])
+ ),
+ Ww[i][4], 2.0f
+ ),
+ 1.0f/3.0f
+ );
+
+ // V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f;
+ V[i][4] = vmulq_n_f32(
+ vmlaq_n_f32(
+ vaddq_f32(
+ vsubq_f32(vmulq_n_f32(Ww[i][0], 1.0f/8.0f), vmulq_n_f32(Ww[i][1], 1.0f/4.0f)),
+ vsubq_f32(vmulq_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3])
+ ),
+ Ww[i][4], 2.0f
+ ),
+ 1.0f/3.0f
+ );
+
+ // V[i][5] = Ww[i][4];
+ V[i][5] = Ww[i][4];
+ }
+
+ // Store the transformed weights
+ for (int i = 0, m = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++, m++)
+ {
+ vst1q_f32(outptr + m*matrix_stride, V[i][j]);
+ }
+ }
+ outptr += 4;
+ }
+#endif // __aarch64__
+#ifdef __arm_any__
+ for (; channels_remaining >= 2; channels_remaining -= 2)
+ {
+ // Matrices used and computed in this kernel
+ float32x2_t w[5][5], Ww[6][5], V[6][6];
+
+ // Read weights
+ for (int i = 0; i < 5; i++)
+ {
+ for (int j = 0; j < 5; j++)
+ {
+ w[i][j] = vld1_f32(inptrs[i][j]);
+ inptrs[i][j] += 2;
+ }
+ }
+
+ // Compute the matrix W w
+ for (int j = 0; j < 5; j++)
+ {
+ // Ww[0][j] = w[0][j]/4.0f;
+ Ww[0][j] = vmul_n_f32(w[0][j], 1.0f/4.0f);
+
+ // Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f;
+ Ww[1][j] = vmul_n_f32(
+ vadd_f32(
+ vadd_f32(
+ vadd_f32(w[1][j], w[0][j]),
+ vadd_f32(w[3][j], w[2][j])
+ ),
+ w[4][j]
+ ),
+ -1.0f/6.0f
+ );
+
+ // Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f;
+ // Ww[2][j] = ((w[1][j] - w[0][j]) + (w[3][j] - w[2][j]) - w[4][j])/6.0f;
+ Ww[2][j] = vmul_n_f32(
+ vsub_f32(
+ vadd_f32(
+ vsub_f32(w[1][j], w[0][j]),
+ vsub_f32(w[3][j], w[2][j])
+ ),
+ w[4][j]
+ ),
+ 1.0f/6.0f
+ );
+
+ // Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f;
+ Ww[3][j] = vmul_n_f32(
+ vmla_n_f32(
+ vadd_f32(
+ vadd_f32(vmul_n_f32(w[0][j], 1.0f/8.0f), vmul_n_f32(w[1][j], 1.0f/4.0f)),
+ vadd_f32(vmul_n_f32(w[2][j], 1.0f/2.0f), w[3][j])
+ ),
+ w[4][j], 2.0f
+ ),
+ 1.0f/3.0f
+ );
+
+ // Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f;
+ Ww[4][j] = vmul_n_f32(
+ vmla_n_f32(
+ vadd_f32(
+ vsub_f32(vmul_n_f32(w[0][j], 1.0f/8.0f), vmul_n_f32(w[1][j], 1.0f/4.0f)),
+ vsub_f32(vmul_n_f32(w[2][j], 1.0f/2.0f), w[3][j])
+ ),
+ w[4][j], 2.0f
+ ),
+ 1.0f/3.0f
+ );
+
+ // Ww[5][j] = w[4][j];
+ Ww[5][j] = w[4][j];
+ }
+
+ // Compute V = W w WT
+ for (int i = 0; i < 6; i++)
+ {
+ // V[i][0] = Ww[i][0]/4.0f;
+ V[i][0] = vmul_n_f32(Ww[i][0], 1.0f/4.0f);
+
+ // V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f;
+ V[i][1] = vmul_n_f32(
+ vadd_f32(
+ vadd_f32(
+ vadd_f32(Ww[i][1], Ww[i][0]),
+ vadd_f32(Ww[i][3], Ww[i][2])
+ ),
+ Ww[i][4]
+ ),
+ -1.0f/6.0f
+ );
+
+ // V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f;
+ // V[i][2] = ((Ww[i][1] - Ww[i][0]) + (Ww[i][3] - Ww[i][2]) - Ww[i][4])/6.0f;
+ V[i][2] = vmul_n_f32(
+ vsub_f32(
+ vadd_f32(
+ vsub_f32(Ww[i][1], Ww[i][0]),
+ vsub_f32(Ww[i][3], Ww[i][2])
+ ),
+ Ww[i][4]
+ ),
+ 1.0f/6.0f
+ );
+
+ // V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f;
+ V[i][3] = vmul_n_f32(
+ vmla_n_f32(
+ vadd_f32(
+ vadd_f32(vmul_n_f32(Ww[i][0], 1.0f/8.0f), vmul_n_f32(Ww[i][1], 1.0f/4.0f)),
+ vadd_f32(vmul_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3])
+ ),
+ Ww[i][4], 2.0f
+ ),
+ 1.0f/3.0f
+ );
+
+ // V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f;
+ V[i][4] = vmul_n_f32(
+ vmla_n_f32(
+ vadd_f32(
+ vsub_f32(vmul_n_f32(Ww[i][0], 1.0f/8.0f), vmul_n_f32(Ww[i][1], 1.0f/4.0f)),
+ vsub_f32(vmul_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3])
+ ),
+ Ww[i][4], 2.0f
+ ),
+ 1.0f/3.0f
+ );
+
+ // V[i][5] = Ww[i][4];
+ V[i][5] = Ww[i][4];
+ }
+
+ // Store the transformed weights
+ for (int i = 0, m = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++, m++)
+ {
+ vst1_f32(outptr + m*matrix_stride, V[i][j]);
+ }
+ }
+ outptr += 2;
+ }
+#endif // __arm_any__
+ for (; channels_remaining; channels_remaining--)
+ {
+ // Matrices used and computed in this kernel
+ float w[5][5], Ww[6][5], V[6][6];
+
+ // Read weights
+ for (int i = 0; i < 5; i++)
+ {
+ for (int j = 0; j < 5; j++)
+ {
+ w[i][j] = *(inptrs[i][j]++);
+ }
+ }
+
+ // Compute the matrix W w
+ for (int j = 0; j < 5; j++)
+ {
+ Ww[0][j] = w[0][j]/4.0f;
+ Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f;
+ Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f;
+ Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f;
+ Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f;
+ Ww[5][j] = w[4][j];
+ }
+
+ // Compute V = W w WT
+ for (int i = 0; i < 6; i++)
+ {
+ V[i][0] = Ww[i][0]/4.0f;
+ V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f;
+ V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f;
+ V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f;
+ V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f;
+ V[i][5] = Ww[i][4];
+ }
+
+ // Store the transformed weights
+ for (int i = 0, m = 0; i < 6; i++)
+ {
+ for (int j = 0; j < 6; j++, m++)
+ {
+ *(outptr + m*matrix_stride) = V[i][j];
+ }
+ }
+ outptr++;
+ }
+ }
+ }
+
+ template <>
+ template <>
+ int WinogradGEMM<2, 2, 5, 5>::WeightsTransform<float>::ops_performed(const KernelShape &shape)
+ {
+ return 0; // TODO
+ }
+
+ template class WinogradGEMM<2, 2, 5, 5>::WeightsTransform<float>;
+} // namespace winograd
diff --git a/src/core/NEON/kernels/winograd/winograd_gemm.cpp b/src/core/NEON/kernels/winograd/winograd_gemm.cpp
index b44a45367f..fcfa635232 100644
--- a/src/core/NEON/kernels/winograd/winograd_gemm.cpp
+++ b/src/core/NEON/kernels/winograd/winograd_gemm.cpp
@@ -372,6 +372,7 @@ void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>:
Convolution<TOut, TIn>::execute(
TOut* const output,
const TIn* const input,
+ const TOut* const biases,
void *working_space,
const int n_threads
)
@@ -479,7 +480,11 @@ Convolution<TOut, TIn>::execute(
kernel_matrices[0],
output_matrices[0]
);
- gemms.run(0, gemms.get_window());
+ for (unsigned int i = 0; i < gemms.get_window(); i++)
+ {
+ auto run_gemm = [&] () { gemms.run(i, i+1); };
+ prof("GEMM", run_gemm, 0, 0, 0);
+ }
// If the output tensor needs to be in NCHW form then store the NHWC output
// tensor in temporary storage and then reorder. If the output tensor needs
@@ -498,6 +503,7 @@ Convolution<TOut, TIn>::execute(
output_matrices[0],
out_matrix_stride_bytes / sizeof(TOut),
out_matrix_row_stride,
+ biases,
output_nhwc,
output_shape.n_batches,
output_shape.n_rows,
@@ -548,13 +554,16 @@ void WinogradGEMM<output_tile_rows, output_tile_cols, kernel_rows, kernel_cols>:
Convolution<TOut, TIn>::execute(
TOut* const output,
const TIn* const input,
+ const TOut* const biases,
const int n_threads
)
{
- execute(output, input, NULL, n_threads);
+ execute(output, input, biases, NULL, n_threads);
}
// Instantiate required implementations
template class WinogradGEMM<2, 2, 3, 3>::Convolution<float, float>;
template class WinogradGEMM<4, 4, 3, 3>::Convolution<float, float>;
+
+template class WinogradGEMM<2, 2, 5, 5>::Convolution<float, float>;
diff --git a/src/core/NEON/kernels/winograd/winograd_layer.cpp b/src/core/NEON/kernels/winograd/winograd_layer.cpp
index 689ecba5fb..f16d62c0ef 100644
--- a/src/core/NEON/kernels/winograd/winograd_layer.cpp
+++ b/src/core/NEON/kernels/winograd/winograd_layer.cpp
@@ -157,6 +157,7 @@ WinogradConvolutionLayer(
TIn* const winograd_weights, /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */
const TIn* const input, /** Pointer to NHWC ordered input tensor, in the spatial domain. */
TIn* const winograd_input, /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */
+ const TOut* const biases, /** Pointer to biases vector. */
TOut* const output, /** Pointer to NHWC ordered output tensor, in the spatial domain. */
TOut* const winograd_output /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */
) : _kernel_shape(n_output_channels, KernelRows, KernelCols, n_input_channels),
@@ -193,7 +194,7 @@ WinogradConvolutionLayer(
winograd_input, winograd_weights, winograd_output
),
output_transform(
- winograd_output, _output_matrix_stride, _output_matrix_row_stride,
+ winograd_output, _output_matrix_stride, _output_matrix_row_stride, biases,
output, n_batches, _n_output_rows, _n_output_cols, n_output_channels
)
{
@@ -202,3 +203,4 @@ WinogradConvolutionLayer(
// Instantiate valid implementations.
template class WinogradConvolutionLayer<2, 2, 3, 3, float, float>;
template class WinogradConvolutionLayer<4, 4, 3, 3, float, float>;
+template class WinogradConvolutionLayer<2, 2, 5, 5, float, float>;
diff --git a/src/runtime/NEON/functions/NEWinogradLayer.cpp b/src/runtime/NEON/functions/NEWinogradLayer.cpp
index cb1d8b5a48..e8c77412a2 100644
--- a/src/runtime/NEON/functions/NEWinogradLayer.cpp
+++ b/src/runtime/NEON/functions/NEWinogradLayer.cpp
@@ -28,6 +28,8 @@
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "support/ToolchainSupport.h"
+#include "arm_compute/core/NEON/kernels/winograd/winograd_gemm.hpp"
+
namespace
{
inline Tensor4DShape internal_get_input_shape(const arm_compute::ITensor *input)
@@ -43,15 +45,15 @@ inline Tensor4DShape internal_get_input_shape(const arm_compute::ITensor *input)
namespace arm_compute
{
NEWinogradLayer::NEWinogradLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _winograd_kernel(), _permute_input(), _permute_weights(), _permute_output(), _input_workspace(), _output_workspace(), _kernel_storage(), _input_nhwc(),
- _output_nhwc(), _weights_hwio(), _input(), _weights(), _output(), _reshaped_kernel(false), _conv()
+ : _memory_group(std::move(memory_manager)), _winograd_kernel(), _transform_input_kernel(), _transform_output_kernel(), _transform_weights_kernel(), _permute_input(), _permute_weights(),
+ _permute_output(), _input_workspace(), _output_workspace(), _kernel_storage(), _input_nhwc(), _output_nhwc(), _weights_hwio(), _input(), _weights(), _output(), _reshaped_kernel(false), _conv()
{
} /* arm_compute */
void NEWinogradLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, biases);
ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(1) != 3 || weights->info()->dimension(0) != 3, "Only 3x3 kernels are supported");
ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);
@@ -76,22 +78,22 @@ void NEWinogradLayer::configure(const ITensor *input, const ITensor *weights, co
const int out_channels = output->info()->dimension(2);
const Tensor4DShape in_shape(internal_get_input_shape(input));
-
+ const size_t data_type_size = input->info()->element_size();
// Get the memory required to instantiate a new Winograd operator.
constexpr size_t storage_alignment = 64;
- const size_t kernel_storage_size = NEWinogradLayerKernel::get_weight_storage_size(out_channels, in_channels) * sizeof(float);
+ const size_t kernel_storage_size = NEWinogradLayerKernel::get_weight_storage_size(out_channels, in_channels) * data_type_size;
_kernel_storage.allocator()->init(TensorInfo(TensorShape{ (kernel_storage_size + storage_alignment - 1) }, 1, DataType::U8));
_memory_group.manage(&_kernel_storage);
_memory_group.manage(&_input_nhwc);
_kernel_storage.allocator()->allocate();
// Input storage
- const size_t input_storage_size = NEWinogradLayerKernel::get_input_storage_size(in_shape.n_batches, in_shape.n_channels, in_shape.n_rows, in_shape.n_cols, false) * sizeof(float);
+ const size_t input_storage_size = NEWinogradLayerKernel::get_input_storage_size(in_shape.n_batches, in_shape.n_channels, in_shape.n_rows, in_shape.n_cols, false) * data_type_size;
_input_workspace.allocator()->init(TensorInfo(TensorShape{ (input_storage_size + storage_alignment - 1) }, 1, DataType::U8));
_memory_group.manage(&_input_workspace);
_input_workspace.allocator()->allocate();
// Output storage
- const size_t output_storage_size = NEWinogradLayerKernel::get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels, false) * sizeof(float);
+ const size_t output_storage_size = NEWinogradLayerKernel::get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels, false) * data_type_size;
_output_workspace.allocator()->init(TensorInfo(TensorShape{ (output_storage_size + storage_alignment - 1) }, 1, DataType::U8));
_memory_group.manage(&_output_workspace);
_output_workspace.allocator()->allocate();
@@ -130,7 +132,6 @@ void NEWinogradLayer::configure(const ITensor *input, const ITensor *weights, co
_permute_input.configure(input, &_input_nhwc, PermutationVector(2U, 0U, 1U));
_input_nhwc.allocator()->allocate();
-
// Create Winograd operator object
_conv = support::cpp14::make_unique<Winograd3x3F32>(
in_shape.n_batches,
@@ -148,6 +149,20 @@ void NEWinogradLayer::configure(const ITensor *input, const ITensor *weights, co
// Configure the kernel, padding not needed so it's safe to call configure after allocare
_winograd_kernel.configure(_conv.get());
+ _transform_input_kernel.configure(_conv.get());
+ _transform_weights_kernel.configure(_conv.get());
+
+ //The biases tensor has not been allocated at this point in time, the output transform will add the biases to the final result in the run() method
+ using T = winograd::WinogradGEMM<2, 2, 3, 3>::Convolution<float, float>;
+ const int weights_width = weights->info()->dimension(0);
+ const int weights_height = weights->info()->dimension(1);
+ const KernelShape kernel_shape({ out_channels, weights_height, weights_width, in_channels });
+ const int output_matrix_stride = T::get_output_matrix_stride(kernel_shape, in_shape, PADDING_VALID);
+ const auto output_shape(T::get_output_shape(kernel_shape, in_shape, PADDING_VALID));
+
+ _transform_output_kernel.configure(biases, reinterpret_cast<float *>(_output_workspace.buffer()),
+ output_matrix_stride, reinterpret_cast<float *>(_output_nhwc.buffer()),
+ in_shape.n_batches, output_shape.n_rows, output_shape.n_cols, out_channels);
// Reorder the convoluted output to ACL's ordering NCHW
_permute_output.configure(&_output_nhwc, _output, PermutationVector(1U, 2U, 0U));
@@ -160,16 +175,20 @@ void NEWinogradLayer::run()
{
_reshaped_kernel = true;
_permute_weights.run();
- _conv->transform_weights();
+ NEScheduler::get().schedule(&_transform_weights_kernel, Window::DimX);
}
+
//Bring channels to the front as Winograd code expects the tensor to be in the format NHWC
_permute_input.run();
// Transform input tensor to the winograd domain
- _conv->transform_input();
+ NEScheduler::get().schedule(&_transform_input_kernel, Window::DimX);
+
//Run 16 GEMMs in multiple threads, each kernel runs one or more GEMMs
NEScheduler::get().schedule(&_winograd_kernel, Window::DimX);
+
// Transform output tensor to the spatial domain
- _conv->transform_output();
+ NEScheduler::get().schedule(&_transform_output_kernel, Window::DimX);
+
// Reorder the convoluted output to ACL's ordering NCHW
_permute_output.run();
_memory_group.release();