aboutsummaryrefslogtreecommitdiff
path: root/src/core/cpu
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2021-07-01 12:20:56 +0100
committerMichalis Spyrou <michalis.spyrou@arm.com>2021-07-13 13:42:25 +0000
commit96f977e43f452a75f2658b820791cb3d3da9c0a3 (patch)
treefe279f0573d871c051bb49acf4b83f50b29a1647 /src/core/cpu
parent04b39e8e56112dabf6f5746117354680a9985841 (diff)
downloadComputeLibrary-96f977e43f452a75f2658b820791cb3d3da9c0a3.tar.gz
Port NEWinogradConvolutionLayer
Rename to CpuWinogradConv2d Allow memory to be injected externally Change-Id: I1f0a26ea533e326a7c63df86e708895c31752a39 Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5926 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'src/core/cpu')
-rw-r--r--src/core/cpu/kernels/CpuWinogradConv2dKernel.cpp552
-rw-r--r--src/core/cpu/kernels/CpuWinogradConv2dKernel.h575
2 files changed, 1127 insertions, 0 deletions
diff --git a/src/core/cpu/kernels/CpuWinogradConv2dKernel.cpp b/src/core/cpu/kernels/CpuWinogradConv2dKernel.cpp
new file mode 100644
index 0000000000..74b031b226
--- /dev/null
+++ b/src/core/cpu/kernels/CpuWinogradConv2dKernel.cpp
@@ -0,0 +1,552 @@
+/*
+ * Copyright (c) 2017-2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/core/cpu/kernels/CpuWinogradConv2dKernel.h"
+
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "src/core/NEON/kernels/convolution/common/utils.hpp"
+#include "src/core/NEON/kernels/convolution/winograd/winograd_layer.hpp"
+#include "src/core/helpers/AutoConfiguration.h"
+#include "src/core/helpers/WindowHelpers.h"
+
+#include <memory>
+
+namespace arm_compute
+{
+namespace cpu
+{
+//Batched Gemms
+
+namespace
+{
+inline bool is_kernel_size_supported(DataType data_type, Size2D size)
+{
+ const std::array<Size2D, 8> f32_support = { { Size2D(1, 3), Size2D(3, 1), Size2D(5, 5), Size2D(3, 3), Size2D(1, 5), Size2D(5, 1), Size2D(7, 1), Size2D(1, 7) } };
+ const std::array<Size2D, 8> f16_support = { { Size2D(3, 3) } };
+
+ switch(data_type)
+ {
+ case DataType::F16:
+ return std::end(f16_support) != std::find(std::begin(f16_support), std::end(f16_support), size);
+ case DataType::F32:
+ return std::end(f32_support) != std::find(std::begin(f32_support), std::end(f32_support), size);
+ default:
+ return false;
+ }
+}
+
+Status validate_arguments_winograd_weight_trans(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+
+ const size_t idx_width = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+ const size_t idx_height = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+ const auto input_width = input->dimension(idx_width);
+ const auto input_height = input->dimension(idx_height);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(input_width, input_height)),
+ "Only 1x3, 3x1, 1x5, 5x1, 7x1, 1x7, 3x3 and 5x5 kernels are supported");
+ ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
+ const Size2D &output_tile = winograd_info.output_tile_size;
+ const std::array<Size2D, 8> supported_tile_sizes = { { Size2D(2U, 2U), Size2D(4U, 4U), Size2D(1U, 6U), Size2D(6U, 1U), Size2D(4, 1), Size2D(1, 4), Size2D(2, 1), Size2D(1, 2) } };
+ ARM_COMPUTE_RETURN_ERROR_ON(std::end(supported_tile_sizes) == std::find(std::begin(supported_tile_sizes), std::end(supported_tile_sizes), output_tile));
+
+ // Checks performed when output is configured
+ if(output->total_size() != 0)
+ {
+ const TensorInfo tensor_info_output = input->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_winograd_filter_transform_shape(*input, winograd_info));
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
+
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window_winograd_weight_trans(ITensorInfo *input, ITensorInfo *output, const WinogradInfo &winograd_info)
+{
+ // Output tensor auto inizialitation if not yet initialized
+ auto_init_if_empty(*output, input->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_winograd_filter_transform_shape(*input, winograd_info)));
+ const Window win = calculate_max_window(*input, Steps(), true /* skip border*/);
+ return std::make_pair(Status{}, win);
+}
+
+Status validate_arguments_winograd_input_trans(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info)
+{
+ const Size2D &kernel_dims = winograd_info.kernel_size;
+ const PadStrideInfo &conv_info = winograd_info.convolution_info;
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1, "Winograd input transform only supports unit strides");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(kernel_dims.width, kernel_dims.height)),
+ "Only 1x3, 3x1, 3x3 and 5x5 kernels are supported");
+
+ // Validate configured output
+ if(output->total_size() != 0)
+ {
+ const TensorShape output_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*input, winograd_info);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
+
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window_winograd_input_trans(ITensorInfo *input, ITensorInfo *output, const WinogradInfo &winograd_info)
+{
+ const TensorShape output_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*input, winograd_info);
+ // Output auto inizialitation if not yet initialized
+ auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape));
+ return std::make_pair(Status{}, calculate_max_window(*input, Steps(), true));
+}
+
+Status validate_arguments_winograd_output_trans(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const WinogradInfo &winograd_info)
+{
+ const PadStrideInfo &conv_info = winograd_info.convolution_info;
+ const Size2D kernel_dims = winograd_info.kernel_size;
+
+ // Number of tiles along the X and Y direction
+ const unsigned int num_tiles_x = std::ceil((winograd_info.input_dimensions.x() - (kernel_dims.width - 1) + conv_info.pad_left() + conv_info.pad_right()) / static_cast<float>
+ (winograd_info.output_tile_size.width));
+ const unsigned int num_tiles_y = std::ceil((winograd_info.input_dimensions.y() - (kernel_dims.height - 1) + conv_info.pad_top() + conv_info.pad_bottom()) / static_cast<float>
+ (winograd_info.output_tile_size.height));
+ const Size2D num_tiles = Size2D(num_tiles_x, num_tiles_y);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != num_tiles.area());
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(kernel_dims.width, kernel_dims.height)),
+ "Only 1x3, 3x1, 3x3 and 5x5 kernels are supported");
+
+ const std::array<unsigned int, 3> supported_gemm_sizes = { { 8U, 16U, 36U } };
+ ARM_COMPUTE_RETURN_ERROR_ON(std::end(supported_gemm_sizes) == std::find(std::begin(supported_gemm_sizes), std::end(supported_gemm_sizes), input->dimension(2)));
+ ARM_COMPUTE_UNUSED(kernel_dims);
+ if(bias != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != bias->dimension(0));
+ ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() != size_t(1));
+ }
+
+ // Checks performed when output is configured
+ if(output->total_size() != 0)
+ {
+ const TensorInfo tensor_info_output = input->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_winograd_output_transform_shape(*input, winograd_info));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window_winograd_output_trans(ITensorInfo *input, ITensorInfo *output, const WinogradInfo &winograd_info)
+{
+ // Output tensor auto initialization if not yet initialized
+ auto_init_if_empty(*output, input->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_winograd_output_transform_shape(*input, winograd_info)));
+
+ return std::make_pair(Status{}, calculate_max_window(*input, Steps(), true));
+}
+} // namespace
+
+Status ICpuWinogradConv2dTransformWeightsKernel::validate(const ITensorInfo *input, const ITensorInfo *weights)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+ const DataLayout data_layout = input->data_layout();
+ const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(weights->dimension(width_idx), weights->dimension(height_idx))),
+ "Only 1x3, 3x1, 3x3 and 5x5 kernels are supported");
+ ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
+ return Status{};
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+unsigned int CpuWinogradConv2dTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_weight_storage_size(int num_output_channels, int num_input_channels) const
+{
+ const KernelShape shape(num_output_channels, KernelRows, KernelCols, num_input_channels);
+ return static_cast<unsigned int>(
+ // WinogradConv returns the size in bytes, we divide by `sizeof(T)` to express that in units of T
+ WinogradConv::get_kernel_storage_size(num_input_channels, num_output_channels) / sizeof(T));
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+CpuWinogradConv2dTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::CpuWinogradConv2dTransformWeightsKernel()
+ : _transform(nullptr), _num_output_channels(0), _matrix_stride(0)
+{
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+int CpuWinogradConv2dTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_matrix_stride(int num_output_channels, int num_input_channels) const
+{
+ return WinogradConv::get_kernel_matrix_stride(num_input_channels, num_output_channels);
+}
+
+#ifndef DOXYGEN_SKIP_THIS
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+void CpuWinogradConv2dTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::configure(
+ const ITensorInfo *weights_hwio,
+ ITensorInfo *output,
+ const int matrix_stride, /** Stride across matrices in the output. */
+ const int num_output_channels, /** Number of filters. */
+ const int num_input_channels) /** Number of channels in each filter. */
+{
+ ARM_COMPUTE_UNUSED(weights_hwio, output);
+
+ _transform = std::make_unique<WeightsTransform>(num_output_channels, num_input_channels);
+ _num_output_channels = num_output_channels;
+ _matrix_stride = matrix_stride;
+
+ Window win;
+ auto win_last = _transform->get_window();
+ win.set(Window::DimX, Window::Dimension(0, win_last, 1));
+ ICpuKernel::configure(win);
+}
+#endif /* DOXYGEN_SKIP_THIS */
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+void CpuWinogradConv2dTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON(tensors.empty());
+
+ const size_t fst = window.x().start();
+ const size_t lst = window.x().end();
+
+ const ITensor *weights_hwio = tensors.get_const_tensor(TensorType::ACL_SRC);
+ ITensor *output = tensors.get_tensor(TensorType::ACL_DST);
+
+ _transform->set_weight_tensor(weights_hwio->buffer());
+ const int matrix_row_stride = roundup(_num_output_channels, WinogradConv::N_BLOCK);
+ _transform->set_output_matrices(output->buffer(), _matrix_stride, matrix_row_stride);
+ _transform->set_working_space(output->buffer());
+
+ _transform->run(fst, lst);
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+bool CpuWinogradConv2dTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::is_parallelisable() const
+{
+ return false;
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+Status CpuWinogradConv2dTransformWeightsKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::validate(const ITensorInfo *input, const ITensorInfo *output,
+ const WinogradInfo &winograd_info)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_winograd_weight_trans(input, output, winograd_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_winograd_weight_trans(input->clone().get(), output->clone().get(), winograd_info).first);
+ return Status{};
+}
+
+template class CpuWinogradConv2dTransformWeightsKernel<float, 2, 2, 3, 3>;
+template class CpuWinogradConv2dTransformWeightsKernel<float, 4, 4, 3, 3>;
+template class CpuWinogradConv2dTransformWeightsKernel<float, 2, 2, 5, 5>;
+template class CpuWinogradConv2dTransformWeightsKernel<float, 1, 6, 1, 3>;
+template class CpuWinogradConv2dTransformWeightsKernel<float, 6, 1, 3, 1>;
+
+template class CpuWinogradConv2dTransformWeightsKernel<float, 1, 4, 1, 5>;
+template class CpuWinogradConv2dTransformWeightsKernel<float, 4, 1, 5, 1>;
+template class CpuWinogradConv2dTransformWeightsKernel<float, 1, 2, 1, 7>;
+template class CpuWinogradConv2dTransformWeightsKernel<float, 2, 1, 7, 1>;
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template class CpuWinogradConv2dTransformWeightsKernel<__fp16, 4, 4, 3, 3>;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+// Input transform
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+unsigned int CpuWinogradConv2dTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_input_storage_size(
+ int num_batches, /* Number of batches in the input tensor. */
+ int num_channels, /* Number of feature maps in the input tensor. */
+ int num_rows, /* Number of rows in each feature map. */
+ int num_cols, /* Number of columns in each feature map. */
+ bool same_padding /* Use "SAME" padding, otherwise use "VALID". */
+) const
+{
+ // Construct shapes for the input and kernel tensors.
+ const Tensor4DShape input_shape(num_batches, num_rows, num_cols, num_channels);
+ const KernelShape kern_shape(1, KernelRows, KernelCols, num_channels);
+ // Return the size, converted into units of TIn
+ return static_cast<unsigned int>(WinogradConv::get_input_storage_size(num_batches, num_rows, num_cols, num_channels, same_padding) / sizeof(T));
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+unsigned int CpuWinogradConv2dTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_working_space_size(unsigned int num_threads) const
+{
+ return _transform->get_working_space_size(num_threads) / sizeof(T);
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+int CpuWinogradConv2dTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_matrix_stride(
+ int num_batches, /* Number of batches in the input tensor. */
+ int num_channels, /* Number of feature maps in the input tensor. */
+ int num_rows, /* Number of rows in each feature map. */
+ int num_cols, /* Number of columns in each feature map. */
+ bool same_padding /* Use "SAME" padding, otherwise use "VALID". */) const
+{
+ return WinogradConv::get_input_matrix_stride(num_batches, num_rows, num_cols, num_channels, same_padding);
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+CpuWinogradConv2dTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::CpuWinogradConv2dTransformInputKernel()
+ : _transform(nullptr), _num_channels(0), _matrix_stride(0)
+{
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+void CpuWinogradConv2dTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::configure(
+ const ITensorInfo *input_nhwc,
+ const int num_batches, /* Number of batches in input tensor. */
+ const int num_rows, /* Number of rows in input tensor. */
+ const int num_cols, /* Number of columns in input tensor. */
+ const int num_channels, /* Number of channels in input tensor. */
+ const PaddingType padding, /* Padding type. */
+ ITensorInfo *output, /* Base of output matrices. */
+ const int matrix_stride, /* Stride between output matrices. */
+ ITensorInfo *workspace)
+{
+ ARM_COMPUTE_UNUSED(input_nhwc, output, matrix_stride, workspace);
+
+ _num_channels = num_channels;
+ _matrix_stride = matrix_stride;
+
+ const int padding_top = (padding == PADDING_SAME) ? (KernelRows - 1) / 2 : 0;
+ const int padding_left = (padding == PADDING_SAME) ? (KernelCols - 1) / 2 : 0;
+ const int padding_bottom = (padding == PADDING_SAME) ? iceildiv(KernelRows - 1, 2) : 0;
+ const int padding_right = (padding == PADDING_SAME) ? iceildiv(KernelCols - 1, 2) : 0;
+
+ _transform = std::make_unique<InputTransform>(
+ KernelRows,
+ KernelCols,
+ num_batches,
+ num_rows,
+ num_cols,
+ num_channels,
+ padding_top, /**< Padding to apply to the top of the image. */
+ padding_left, /**< Padding to apply to the left of the image. */
+ padding_bottom, /**< Padding to apply to the bottom of the image. */
+ padding_right /**< Padding to apply to the right of the image. */
+ );
+
+ Window win;
+ auto win_last = _transform->get_window();
+ win.set(Window::DimX, Window::Dimension(0, win_last, 1));
+ ICpuKernel::configure(win);
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+void CpuWinogradConv2dTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON(tensors.empty());
+
+ const ITensor *input_nhwc = tensors.get_const_tensor(TensorType::ACL_SRC);
+ const ITensor *workspace = tensors.get_const_tensor(TensorType::ACL_INT);
+ ITensor *output = tensors.get_tensor(TensorType::ACL_DST);
+
+ const int element_size_in_bytes = input_nhwc->info()->element_size();
+ const int input_col_stride = input_nhwc->info()->strides_in_bytes().y() / element_size_in_bytes;
+ const int input_row_stride = input_nhwc->info()->strides_in_bytes().z() / element_size_in_bytes;
+ const int input_batch_stride = input_nhwc->info()->strides_in_bytes()[3] / element_size_in_bytes;
+ const auto input_nhwc_ptr = reinterpret_cast<const T *>(input_nhwc->buffer() + input_nhwc->info()->offset_first_element_in_bytes());
+ auto output_ptr = reinterpret_cast<T *>(output->buffer() + output->info()->offset_first_element_in_bytes());
+ ARM_COMPUTE_ERROR_ON_NULLPTR(output_ptr);
+
+ _transform->set_input_tensor(input_nhwc_ptr, input_batch_stride, input_row_stride, input_col_stride);
+ _transform->set_output_matrices(output_ptr, _matrix_stride, _num_channels);
+
+ _transform->set_working_space(workspace->buffer());
+
+ // The code below cannot be moved to configure because biases hasn't been allocated at that point
+ const size_t fst = window.x().start();
+ const size_t lst = window.x().end();
+ _transform->run(fst, lst, info.thread_id);
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+Status CpuWinogradConv2dTransformInputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::validate(const ITensorInfo *input, const ITensorInfo *output,
+ const WinogradInfo &winograd_info)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_winograd_input_trans(input, output, winograd_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_winograd_input_trans(input->clone().get(), output->clone().get(), winograd_info).first);
+
+ return Status{};
+}
+
+template class CpuWinogradConv2dTransformInputKernel<float, 2, 2, 3, 3>;
+template class CpuWinogradConv2dTransformInputKernel<float, 4, 4, 3, 3>;
+template class CpuWinogradConv2dTransformInputKernel<float, 2, 2, 5, 5>;
+template class CpuWinogradConv2dTransformInputKernel<float, 1, 6, 1, 3>;
+template class CpuWinogradConv2dTransformInputKernel<float, 6, 1, 3, 1>;
+
+template class CpuWinogradConv2dTransformInputKernel<float, 1, 4, 1, 5>;
+template class CpuWinogradConv2dTransformInputKernel<float, 4, 1, 5, 1>;
+template class CpuWinogradConv2dTransformInputKernel<float, 1, 2, 1, 7>;
+template class CpuWinogradConv2dTransformInputKernel<float, 2, 1, 7, 1>;
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template class CpuWinogradConv2dTransformInputKernel<__fp16, 4, 4, 3, 3>;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+// Output transform
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+unsigned int CpuWinogradConv2dTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_output_storage_size(
+ int num_batches, /* Number of batches in the output tensor. */
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ int num_output_channels /* Number of feature maps in the output tensor. */
+) const
+{
+ // Construct shapes for the input and kernel tensors.
+ const Tensor4DShape input_shape(num_batches, num_rows, num_cols, 1);
+ const KernelShape kern_shape(num_output_channels, KernelRows, KernelCols, 1);
+ // Return the size, converted into units of TOut
+ return static_cast<unsigned int>(
+ WinogradConv::get_output_storage_size(num_batches, num_rows, num_cols, num_output_channels) / sizeof(T));
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+CpuWinogradConv2dTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::CpuWinogradConv2dTransformOutputKernel()
+ : _transform(nullptr), _matrix_stride(0), _matrix_row_stride(0)
+{
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+unsigned int CpuWinogradConv2dTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_working_space_size(unsigned int num_threads) const
+{
+ return _transform->get_working_space_size(num_threads) / sizeof(T);
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+int CpuWinogradConv2dTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_matrix_stride(
+ int num_batches, /* Number of batches in the output tensor. */
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ int num_output_channels /* Number of feature maps in the output tensor. */
+) const
+{
+ return WinogradConv::get_output_matrix_stride(num_batches, num_rows, num_cols, num_output_channels);
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+std::pair<unsigned int, unsigned int> CpuWinogradConv2dTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::get_output_shape(
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ bool padding_same) const
+{
+ return WinogradConv::get_output_shape(std::make_pair<unsigned int, unsigned int>(num_rows, num_cols), padding_same);
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+void CpuWinogradConv2dTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::configure(
+ const ITensorInfo *biases,
+ const ITensorInfo *transformed_output,
+ const int matrix_stride,
+ ITensorInfo *output_nhwc,
+ const int num_batches,
+ const int num_rows,
+ const int num_cols,
+ const int num_channels,
+ ITensorInfo *workspace,
+ const arm_gemm::Activation &activation)
+{
+ ARM_COMPUTE_UNUSED(biases, transformed_output, output_nhwc, num_batches, num_rows, num_cols, workspace, activation);
+
+ _matrix_stride = matrix_stride;
+ _matrix_row_stride = roundup(num_channels, WinogradConv::N_BLOCK);
+
+ // We don't have the biases buffer at this stage as it hasn't been allocated, we pass in nullptr OutputTransform is only used here to compute the window
+ _transform = std::make_unique<OutputTransform>(num_batches, num_rows, num_cols, num_channels, activation);
+ Window win;
+ auto win_last = _transform->get_window();
+ win.set(Window::DimX, Window::Dimension(0, win_last, 1));
+
+ ICpuKernel::configure(win);
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+void CpuWinogradConv2dTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON(tensors.empty());
+
+ const ITensor *biases = tensors.get_const_tensor(TensorType::ACL_SRC_0);
+ const ITensor *transformed_output = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+ ITensor *workspace = tensors.get_tensor(TensorType::ACL_INT);
+ ITensor *dst_nhwc = tensors.get_tensor(TensorType::ACL_DST);
+
+ const int out_batch_stride = dst_nhwc->info()->strides_in_bytes()[3] / sizeof(T);
+ const int out_row_stride = dst_nhwc->info()->strides_in_bytes()[2] / sizeof(T);
+ const int out_col_stride = dst_nhwc->info()->strides_in_bytes()[1] / sizeof(T);
+
+ _transform->set_input_matrices(transformed_output->buffer(), _matrix_stride, _matrix_row_stride);
+ _transform->set_bias((biases ? reinterpret_cast<T *>(biases->buffer() + biases->info()->offset_first_element_in_bytes()) : nullptr));
+ _transform->set_output_tensor(dst_nhwc->buffer() + dst_nhwc->info()->offset_first_element_in_bytes(), out_batch_stride, out_row_stride, out_col_stride);
+ _transform->set_working_space(workspace->buffer());
+
+ // The code below cannot be moved to configure because biases hasn't been allocated at that point
+ const size_t fst = window.x().start();
+ const size_t lst = window.x().end();
+ _transform->run(fst, lst, info.thread_id);
+}
+
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+Status CpuWinogradConv2dTransformOutputKernel<T, OutputTileRows, OutputTileCols, KernelRows, KernelCols>::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
+ const WinogradInfo &winograd_info)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_winograd_output_trans(input, (bias != nullptr ? bias->clone().get() : nullptr), output, winograd_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_winograd_output_trans(input->clone().get(), output->clone().get(), winograd_info).first);
+
+ return Status{};
+}
+
+template class CpuWinogradConv2dTransformOutputKernel<float, 2, 2, 3, 3>;
+template class CpuWinogradConv2dTransformOutputKernel<float, 4, 4, 3, 3>;
+template class CpuWinogradConv2dTransformOutputKernel<float, 2, 2, 5, 5>;
+template class CpuWinogradConv2dTransformOutputKernel<float, 1, 6, 1, 3>;
+template class CpuWinogradConv2dTransformOutputKernel<float, 6, 1, 3, 1>;
+
+template class CpuWinogradConv2dTransformOutputKernel<float, 1, 4, 1, 5>;
+template class CpuWinogradConv2dTransformOutputKernel<float, 4, 1, 5, 1>;
+template class CpuWinogradConv2dTransformOutputKernel<float, 1, 2, 1, 7>;
+template class CpuWinogradConv2dTransformOutputKernel<float, 2, 1, 7, 1>;
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+template class CpuWinogradConv2dTransformOutputKernel<__fp16, 4, 4, 3, 3>;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/core/cpu/kernels/CpuWinogradConv2dKernel.h b/src/core/cpu/kernels/CpuWinogradConv2dKernel.h
new file mode 100644
index 0000000000..b5a29ffd02
--- /dev/null
+++ b/src/core/cpu/kernels/CpuWinogradConv2dKernel.h
@@ -0,0 +1,575 @@
+/*
+ * Copyright (c) 2017-2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CPUWINOGRADCONV2DKERNEL_H
+#define ARM_COMPUTE_CPUWINOGRADCONV2DKERNEL_H
+
+#include "src/core/NEON/kernels/convolution/common/convolution.hpp"
+#include "src/core/NEON/kernels/convolution/common/tensor.hpp"
+#include "src/core/cpu/ICpuKernel.h"
+
+#include "src/core/NEON/kernels/convolution/winograd/winograd_layer.hpp"
+
+namespace arm_compute
+{
+namespace cpu
+{
+/** Interface for the kernel to perform Winograd input transform. */
+class ICpuWinogradConv2dTransformInputKernel : public ICpuKernel
+{
+public:
+ /** Get the working space required to perform the transformation.
+ *
+ * Note, the working space is only required when performing the
+ * transformation - hence it can be reused whenever the transformation is
+ * not running.
+ *
+ * @param num_threads The greatest number of threads that will be used to execute the transform.
+ * @return Size of working space required in bytes.
+ */
+ virtual unsigned int get_working_space_size(unsigned int num_threads) const = 0;
+
+ /** Determine how much memory (in units of TIn) to allocate for the
+ * transformed input.
+ *
+ * @param[in] num_batches Number of batches in the input tensor.
+ * @param[in] num_channels Number of feature maps in the input tensor.
+ * @param[in] num_rows Number of rows in each feature map.
+ * @param[in] num_cols Number of columns in each feature map.
+ * @param[in] same_padding Use "SAME" padding, otherwise use "VALID".
+ *
+ * @return Storage size (in units of TIn) required.
+ */
+ virtual unsigned int get_input_storage_size(int num_batches, int num_channels, int num_rows, int num_cols, bool same_padding) const = 0;
+
+ /** Gets the stride between matrices in the input worspace
+ *
+ * @param[in] num_batches Number of batches in the input tensor.
+ * @param[in] num_channels Number of feature maps in the input tensor.
+ * @param[in] num_rows Number of rows in each feature map.
+ * @param[in] num_cols Number of columns in each feature map.
+ * @param[in] same_padding Use "SAME" padding, otherwise use "VALID".
+ *
+ * @return Stride expressed in bytes.
+ */
+ virtual int get_matrix_stride(int num_batches, int num_channels, int num_rows, int num_cols, bool same_padding) const = 0;
+
+ /** Configure the output transform kernel.
+ *
+ * @param[in] input_nhwc Input tensor in NHWC data layout format.
+ * @param[in] num_batches Number of batches in input tensor.
+ * @param[in] num_rows Number of rows in input tensor.
+ * @param[in] num_cols Number of columns in input tensor.
+ * @param[in] num_channels Number of channels in input tensor.
+ * @param[in] padding Padding type.
+ * @param[out] output Base of output matrices.
+ * @param[in] matrix_stride Stride between output matrices.
+ * @param[in] workspace Tensor to be used as the working space during the computation.
+ */
+ virtual void configure(const ITensorInfo *input_nhwc, const int num_batches, const int num_rows, const int num_cols, const int num_channels,
+ const PaddingType padding, ITensorInfo *output, const int matrix_stride, ITensorInfo *workspace) = 0;
+
+ /** Destructor */
+ virtual ~ICpuWinogradConv2dTransformInputKernel()
+ {
+ }
+};
+
+/** Kernel to perform Winograd input transform. */
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+class CpuWinogradConv2dTransformInputKernel : public ICpuWinogradConv2dTransformInputKernel
+{
+public:
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CpuWinogradConv2dTransformInputKernel(const CpuWinogradConv2dTransformInputKernel &) = delete;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CpuWinogradConv2dTransformInputKernel &operator=(const CpuWinogradConv2dTransformInputKernel &) = delete;
+ /** Allow instances of this class to be moved */
+ CpuWinogradConv2dTransformInputKernel(CpuWinogradConv2dTransformInputKernel &&) = default;
+ /** Allow instances of this class to be moved */
+ CpuWinogradConv2dTransformInputKernel &operator=(CpuWinogradConv2dTransformInputKernel &&) = default;
+ /** Default destructor */
+ ~CpuWinogradConv2dTransformInputKernel() = default;
+
+ /** Determine how much memory (in units of TIn) to allocate for the
+ * transformed input.
+ *
+ * @param[in] num_batches Number of batches in the input tensor.
+ * @param[in] num_channels Number of feature maps in the input tensor.
+ * @param[in] num_rows Number of rows in each feature map.
+ * @param[in] num_cols Number of columns in each feature map.
+ * @param[in] same_padding Use "SAME" padding, otherwise use "VALID".
+ *
+ * @return Storage size (in units of TIn) required.
+ */
+ unsigned int get_input_storage_size(
+ int num_batches,
+ int num_channels,
+ int num_rows,
+ int num_cols,
+ bool same_padding) const override;
+
+ /** Get the working space required to perform the transformation.
+ *
+ * Note, the working space is only required when performing the
+ * transformation - hence it can be reused whenever the transformation is
+ * not running.
+ *
+ * @param[in] num_threads The greatest number of threads that will be used to execute the transform.
+ *
+ * @return Size of working space required in bytes.
+ */
+ unsigned int get_working_space_size(unsigned int num_threads) const override;
+
+ /** Gets the stride between matrices in the input worspace
+ *
+ * @param[in] num_batches Number of batches in the input tensor.
+ * @param[in] num_channels Number of feature maps in the input tensor.
+ * @param[in] num_rows Number of rows in each feature map.
+ * @param[in] num_cols Number of columns in each feature map.
+ * @param[in] same_padding Use "SAME" padding, otherwise use "VALID".
+ *
+ * @return Stride expressed in bytes.
+ */
+ int get_matrix_stride(
+ int num_batches,
+ int num_channels,
+ int num_rows,
+ int num_cols,
+ bool same_padding) const override;
+
+ /** Default constructor */
+ CpuWinogradConv2dTransformInputKernel();
+
+ const char *name() const override
+ {
+ return "CpuWinogradConv2dTransformInputKernel";
+ }
+
+ /** Configure the output transform kernel.
+ *
+ * @param[in] input_nhwc Input tensor. Data types supported: F16/F32. Layout supported NHWC.
+ * @param[in] num_batches Number of batches in input tensor.
+ * @param[in] num_rows Number of rows in input tensor.
+ * @param[in] num_cols Number of columns in input tensor.
+ * @param[in] num_channels Number of channels in input tensor.
+ * @param[in] padding Padding type.
+ * @param[out] output Base of output matrices.
+ * @param[in] matrix_stride Stride between output matrices.
+ * @param[in] workspace Tensor to be used as the working space during the computation.
+ */
+ void configure(
+ const ITensorInfo *input_nhwc,
+ const int num_batches,
+ const int num_rows,
+ const int num_cols,
+ const int num_channels,
+ const PaddingType padding,
+ ITensorInfo *output,
+ const int matrix_stride,
+ ITensorInfo *workspace) override;
+
+ // Inherited methods overridden:
+ void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+
+ /** Winograd base kernel */
+ using WinogradBase = winograd::WinogradGEMM<OutputTileRows, OutputTileCols, KernelRows, KernelCols, winograd::WinogradRoots::Integers>;
+ /** Winograd convolution kernel */
+ using WinogradConv = typename WinogradBase::template Convolution<T, T>;
+
+ /** Static function to check if given info will lead to a valid configuration of @ref CpuWinogradConv2dTransformInputKernel
+ *
+ * @param[in] input First tensor input info. Data types supported: F16/F32.
+ * @param[in] output Output tensor info. Data types supported: same as @p input.
+ * @param[in] winograd_info Contains Winograd's information described in @ref WinogradInfo
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info);
+
+private:
+ using InputTransform = typename WinogradBase::template InputTransform<T, T>;
+
+ std::unique_ptr<InputTransform> _transform{ nullptr };
+ int _num_channels; /**< Number of channels in input tensor. */
+ int _matrix_stride; /**< Stride between output matrices. */
+};
+
+/** Interface for the kernel to perform Winograd output transform. */
+class ICpuWinogradConv2dTransformOutputKernel : public ICpuKernel
+{
+public:
+ /** Get the working space required to perform the transformation.
+ *
+ * Note, the working space is only required when performing the
+ * transformation - hence it can be reused whenever the transformation is
+ * not running.
+ *
+ * @param[in] num_threads The greatest number of threads that will be used to execute the transform.
+ *
+ * @return Size of working space required in bytes.
+ */
+ virtual unsigned int get_working_space_size(unsigned int num_threads) const = 0;
+
+ /** Determine how much memory (in units of TOut) to allocate for the
+ * (Winograd domain) output.
+ *
+ * @param[in] num_batches Number of batches in the output tensor.
+ * @param[in] num_rows Number of rows in each feature map of the input tensor.
+ * @param[in] num_cols Number of columns in each feature map of the input tensor.
+ * @param[in] num_output_channels Number of feature maps in the output tensor.
+ *
+ * @return Storage size (in units of TOut) required.
+ */
+ virtual unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels) const = 0;
+
+ /** Gets the stride between matrices in the output worspace
+ *
+ * @param[in] num_batches Number of batches in the output tensor.
+ * @param[in] num_rows Number of rows in each feature map of the input tensor.
+ * @param[in] num_cols Number of columns in each feature map of the input tensor.
+ * @param[in] num_output_channels Number of feature maps in the output tensor.
+ *
+ * @return Stride expressed in bytes.
+ */
+ virtual int get_matrix_stride(int num_batches, int num_rows, int num_cols, int num_output_channels) const = 0;
+
+ /** Get the output shape of a convolution.
+ *
+ * @param[in] num_rows Number of rows in each feature map of the input tensor.
+ * @param[in] num_cols Number of columns in each feature map of the input tensor.
+ * @param[in] padding_same True if padding is SAME, false otherwise
+ *
+ * @return Shape of the output tensor
+ */
+ virtual std::pair<unsigned int, unsigned int> get_output_shape(
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ bool padding_same /* True if padding is SAME, false otherwise */
+ ) const = 0;
+
+ /** Configure the output transform kernel.
+ *
+ * @param[in] biases Pointer to the biases tensor.
+ * @param[in] transformed_output Pointer to working space for the output tensor in the Winograd domain.
+ * @param[in] matrix_stride Output matrix stride, can be computed with winograd::WinogradGEMM<2, 2, 3, 3>::Convolution<float, float>::get_output_matrix_stride()
+ * @param[out] output_nhwc Pointer to a tensor in NHWC data layout ordered output tensor, in the spatial domain.
+ * @param[in] num_batches Number of batches in the input tensor.
+ * @param[in] num_rows Number of rows in output tensor.
+ * @param[in] num_cols Number of columns in output tensor.
+ * @param[in] num_channels Number of feature maps in the output tensor.
+ * @param[in] workspace Tensor to be used as the working space during the computation.
+ * @param[in] activation Activation to be used
+ */
+ virtual void configure(
+ const ITensorInfo *biases,
+ const ITensorInfo *transformed_output,
+ const int matrix_stride,
+ ITensorInfo *output_nhwc,
+ const int num_batches,
+ const int num_rows,
+ const int num_cols,
+ const int num_channels,
+ ITensorInfo *workspace,
+ const arm_gemm::Activation &activation) = 0;
+
+ virtual ~ICpuWinogradConv2dTransformOutputKernel()
+ {
+ }
+};
+
+/** Kernel to perform Winograd output transform. */
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+class CpuWinogradConv2dTransformOutputKernel : public ICpuWinogradConv2dTransformOutputKernel
+{
+public:
+ const char *name() const override
+ {
+ return "CpuWinogradConv2dTransformOutputKernel";
+ }
+ /** Constructor */
+ CpuWinogradConv2dTransformOutputKernel();
+
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CpuWinogradConv2dTransformOutputKernel(const CpuWinogradConv2dTransformOutputKernel &) = delete;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CpuWinogradConv2dTransformOutputKernel &operator=(const CpuWinogradConv2dTransformOutputKernel &) = delete;
+ /** Allow instances of this class to be moved */
+ CpuWinogradConv2dTransformOutputKernel(CpuWinogradConv2dTransformOutputKernel &&) = default;
+ /** Allow instances of this class to be moved */
+ CpuWinogradConv2dTransformOutputKernel &operator=(CpuWinogradConv2dTransformOutputKernel &&) = default;
+ /** Default destructor */
+ ~CpuWinogradConv2dTransformOutputKernel() = default;
+
+ // Inherited methods overridden:
+ /** Determine how much memory (in units of TOut) to allocate for the
+ * (Winograd domain) output.
+ *
+ * @param[in] num_batches Number of batches in the output tensor.
+ * @param[in] num_rows Number of rows in each feature map of the input tensor.
+ * @param[in] num_cols Number of columns in each feature map of the input tensor.
+ * @param[in] num_output_channels Number of feature maps in the output tensor.
+ *
+ * @return Storage size (in units of TOut) required.
+ */
+ unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels) const override;
+
+ /** Gets the stride between matrices in the output worspace
+ *
+ * @param[in] num_batches Number of batches in the output tensor.
+ * @param[in] num_rows Number of rows in each feature map of the input tensor.
+ * @param[in] num_cols Number of columns in each feature map of the input tensor.
+ * @param[in] num_output_channels Number of feature maps in the output tensor.
+ *
+ * @return Stride expressed in bytes.
+ */
+ int get_matrix_stride(int num_batches, int num_rows, int num_cols, int num_output_channels) const override;
+ /** Get the output shape of a convolution.
+ *
+ * @param[in] num_rows Number of rows in each feature map of the input tensor.
+ * @param[in] num_cols Number of columns in each feature map of the input tensor.
+ * @param[in] padding_same True if padding is SAME, false otherwise
+ *
+ * @return Shape of the output tensor
+ */
+ std::pair<unsigned int, unsigned int> get_output_shape(
+ int num_rows, /* Number of rows in each feature map of the input tensor. */
+ int num_cols, /* Number of columns in each feature map of the input tensor. */
+ bool padding_same) const override;
+
+ /** Get the working space required to perform the transformation.
+ *
+ * Note, the working space is only required when performing the
+ * transformation - hence it can be reused whenever the transformation is
+ * not running.
+ *
+ * @param[in] num_threads The greatest number of threads that will be used to execute the transform.
+ *
+ * @return Size of working space required in bytes.
+ */
+ unsigned int get_working_space_size(unsigned int num_threads) const override;
+
+ /** Configure the output transform kernel.
+ *
+ * @param[in] biases Pointer to the biases tensor.
+ * @param[in] transformed_output Pointer to working space for the output tensor in the Winograd domain.
+ * @param[in] matrix_stride Output matrix stride, can be computed with winograd::WinogradGEMM<2, 2, 3, 3>::Convolution<float, float>::get_output_matrix_stride()
+ * @param[out] output_nhwc Pointer to a tensor with NHWC data layout, in the spatial domain.
+ * @param[in] num_batches Number of batches in the input tensor.
+ * @param[in] num_rows Number of rows in output tensor.
+ * @param[in] num_cols Number of columns in output tensor.
+ * @param[in] num_channels Number of feature maps in the output tensor.
+ * @param[in] workspace Tensor to be used as the working space during the computation.
+ * @param[in] activation Activation to be used
+ */
+ void configure(
+ const ITensorInfo *biases,
+ const ITensorInfo *transformed_output,
+ const int matrix_stride,
+ ITensorInfo *output_nhwc,
+ const int num_batches,
+ const int num_rows,
+ const int num_cols,
+ const int num_channels,
+ ITensorInfo *workspace,
+ const arm_gemm::Activation &activation) override;
+
+ void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+
+ /** Static function to check if given info will lead to a valid configuration of @ref CpuWinogradConv2dTransformOutputKernel
+ *
+ * @param[in] input Source tensor info with shape [C, N, 16, batches] or [C, N, 36, batches]. Data types supported: F16/F32.
+ * @param[in] bias Biases tensor info. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. It can be a nullptr. Data type supported: as @p input
+ * @param[in] output Destination tensor info with shape [output_convolved_dims.width, output_convolved_dims.height, C, batches]. Data type supported: same as @p input
+ * @param[in] winograd_info Contains Winograd's information described in @ref WinogradInfo
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const WinogradInfo &winograd_info);
+
+private:
+ using WinogradBase = winograd::WinogradGEMM<OutputTileRows, OutputTileCols, KernelRows, KernelCols, winograd::WinogradRoots::Integers>;
+ using WinogradConv = typename WinogradBase::template Convolution<T, T>;
+ using OutputTransform = typename WinogradBase::template OutputTransform<T, T>;
+
+ std::unique_ptr<OutputTransform> _transform{ nullptr };
+ int _matrix_stride;
+ int _matrix_row_stride;
+};
+
+/** Interface for the kernel to perform Winograd weights transform. */
+class ICpuWinogradConv2dTransformWeightsKernel : public ICpuKernel
+{
+public:
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ ICpuWinogradConv2dTransformWeightsKernel(const ICpuWinogradConv2dTransformWeightsKernel &) = default;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ ICpuWinogradConv2dTransformWeightsKernel &operator=(const ICpuWinogradConv2dTransformWeightsKernel &) = default;
+ /** Allow instances of this class to be moved */
+ ICpuWinogradConv2dTransformWeightsKernel(ICpuWinogradConv2dTransformWeightsKernel &&) = default;
+ /** Allow instances of this class to be moved */
+ ICpuWinogradConv2dTransformWeightsKernel &operator=(ICpuWinogradConv2dTransformWeightsKernel &&) = default;
+
+ ICpuWinogradConv2dTransformWeightsKernel()
+ {
+ }
+ virtual ~ICpuWinogradConv2dTransformWeightsKernel()
+ {
+ }
+ /** Determine how much memory (in units of T) to allocate for the
+ * transformed weights.
+ *
+ * @param[in] num_output_channels Number of output feature maps.
+ * @param[in] num_input_channels Number of input feature maps.
+ *
+ * @return Storage size (in units of T) required.
+ */
+ virtual unsigned int get_weight_storage_size(int num_output_channels, int num_input_channels) const = 0;
+ /** Gets the stride between matrices in the kernel worspace
+ *
+ * @param[in] num_output_channels Number of output feature maps.
+ * @param[in] num_input_channels Number of input feature maps.
+ *
+ * @return Stride expressed in bytes.
+ */
+ virtual int get_matrix_stride(int num_output_channels, int num_input_channels) const = 0;
+
+ /** Configure the weights transform kernel.
+ *
+ * @param[in] weights_hwio Pointer to the weights tensor info
+ * @param[out] output Pointer to working space for the output tensor in the Winograd domain.
+ * @param[in] matrix_stride Stride across matrices in the output workspace.
+ * @param[in] num_output_channels Number of filters.
+ * @param[in] num_input_channels Number of channels in each filter.
+ */
+
+ virtual void configure(const ITensorInfo *weights_hwio, ITensorInfo *output, const int matrix_stride, const int num_output_channels, const int num_input_channels) = 0;
+
+ /** Static function to check if given info will lead to a valid configuration of @ref CpuWinogradConv2dTransformWeightsKernel
+ *
+ * @param[in] input First tensor input info. Data types supported: F16/F32.
+ * @param[in] weights Weights tensor info. Data types supported: same as @p input.
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *input, const ITensorInfo *weights);
+};
+
+/** Kernel to perform Winograd weights transform. */
+template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+class CpuWinogradConv2dTransformWeightsKernel final : public ICpuWinogradConv2dTransformWeightsKernel
+{
+public:
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CpuWinogradConv2dTransformWeightsKernel(const CpuWinogradConv2dTransformWeightsKernel &) = delete;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CpuWinogradConv2dTransformWeightsKernel &operator=(const CpuWinogradConv2dTransformWeightsKernel &) = delete;
+ /** Allow instances of this class to be moved */
+ CpuWinogradConv2dTransformWeightsKernel(CpuWinogradConv2dTransformWeightsKernel &&) = default;
+ /** Allow instances of this class to be moved */
+ CpuWinogradConv2dTransformWeightsKernel &operator=(CpuWinogradConv2dTransformWeightsKernel &&) = default;
+ /** Default destructor */
+ ~CpuWinogradConv2dTransformWeightsKernel() = default;
+
+ /** Default constructor. */
+ CpuWinogradConv2dTransformWeightsKernel();
+ const char *name() const override
+ {
+ return "CpuWinogradConv2dTransformWeightsKernel";
+ }
+
+ /** Static function to check if given info will lead to a valid configuration of @ref CpuWinogradConv2dTransformWeightsKernel
+ *
+ * @param[in] input Source tensor info. The input is a 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM] (NCHW data layout).
+ * kernel_x must be 3 and equal to kernel_y. Data types supported: F16/F32.
+ * @param[in] output Destination tensor info. The output is a 3D tensor with dimensions [OFM, IFM, 16] or [OFM, IFM, 36]. Data type supported: same as @p input
+ * @param[in] winograd_info Contains Winograd's information described in @ref WinogradInfo
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info);
+
+ // Inherited methods overridden:
+
+#ifndef DOXYGEN_SKIP_THIS
+ /** Configure the weights transform kernel.
+ *
+ * @param[in] weights_hwio Pointer to the weights tensor info
+ * @param[out] output Pointer to working space for the output tensor in the Winograd domain.
+ * @param[in] matrix_stride Stride across matrices in the output workspace.
+ * @param[in] num_output_channels Number of filters.
+ * @param[in] num_input_channels Number of channels in each filter.
+ */
+ void configure(const ITensorInfo *weights_hwio, ITensorInfo *output, const int matrix_stride, const int num_output_channels, const int num_input_channels) override;
+#endif /* DOXYGEN_SKIP_THIS */
+
+ /** Determine how much memory (in units of T) to allocate for the
+ * transformed weights.
+ *
+ * @param[in] num_output_channels Number of output feature maps.
+ * @param[in] num_input_channels Number of input feature maps.
+ *
+ * @return Storage size (in units of T) required.
+ */
+ unsigned int get_weight_storage_size(int num_output_channels, int num_input_channels) const override;
+
+ /** Gets the stride between matrices in the input worspace
+ *
+ * @param[in] num_output_channels Number of output feature maps.
+ * @param[in] num_input_channels Number of input feature maps.
+ *
+ * @return Stride expressed in bytes.
+ */
+ int get_matrix_stride(int num_output_channels, int num_input_channels) const override;
+ void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+ bool is_parallelisable() const override;
+
+private:
+ using WinogradBase = winograd::WinogradGEMM<OutputTileRows, OutputTileCols, KernelRows, KernelCols, winograd::WinogradRoots::Integers>;
+ using WinogradConv = typename WinogradBase::template Convolution<T, T>;
+ using WeightsTransform = typename WinogradBase::template WeightsTransform<T, T>;
+
+ std::unique_ptr<WeightsTransform> _transform{ nullptr };
+ int _num_output_channels;
+ int _matrix_stride;
+};
+
+/** Kernel to perform Winograd. */
+template <typename TIn, typename TOut, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols>
+class CpuWinogradConv2dConfiguration
+{
+public:
+ /** Winograd base kernel */
+ using WinogradBase = winograd::WinogradGEMM<OutputTileRows, OutputTileCols, KernelRows, KernelCols, winograd::WinogradRoots::Integers>;
+ /** Winograd convolution kernel */
+
+ using WinogradConv = typename WinogradBase::template Convolution<TIn, TOut>;
+
+ using TransformInputKernel = CpuWinogradConv2dTransformInputKernel<TIn, OutputTileRows, OutputTileCols, KernelRows, KernelCols>;
+ using TransformWeightsKernel = CpuWinogradConv2dTransformWeightsKernel<TIn, OutputTileRows, OutputTileCols, KernelRows, KernelCols>;
+ using TransformOutputKernel = CpuWinogradConv2dTransformOutputKernel<TOut, OutputTileRows, OutputTileCols, KernelRows, KernelCols>;
+};
+
+} // namespace cpu
+} // namespace arm_compute
+#endif /*ARM_COMPUTE_CPUWINOGRADCONV2DKERNEL_H*/