diff options
Diffstat (limited to 'src/core/NEON/kernels/convolution/winograd/transforms/output_6_3_fp32.cpp')
-rw-r--r-- | src/core/NEON/kernels/convolution/winograd/transforms/output_6_3_fp32.cpp | 63 |
1 files changed, 28 insertions, 35 deletions
diff --git a/src/core/NEON/kernels/convolution/winograd/transforms/output_6_3_fp32.cpp b/src/core/NEON/kernels/convolution/winograd/transforms/output_6_3_fp32.cpp index 16667ccdb6..58bed71a47 100644 --- a/src/core/NEON/kernels/convolution/winograd/transforms/output_6_3_fp32.cpp +++ b/src/core/NEON/kernels/convolution/winograd/transforms/output_6_3_fp32.cpp @@ -23,38 +23,32 @@ */ #include "arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp" -#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_gemm.hpp" +#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_output_transform.hpp" #include "arm_compute/core/NEON/kernels/convolution/common/arm.hpp" -namespace winograd -{ - -using Transform = WinogradGEMM<1, 6, 1, 3>::OutputTransform<float>; -using TransformTransposed = WinogradGEMM<6, 1, 3, 1>::OutputTransform<float>; - -template <> -template <> -int Transform::ops_performed(const Tensor4DShape &shape) +namespace { - (void) shape; - return 0; // TODO -} -template <> -template <> -template <int pad_bottom, int pad_right> -void Transform::process_tile( +template <bool Specialized, int PadRight=0> +void winograd_output_transform_6_3_fp32_process_tile( const int n_channels, const float* const matrix_base, const int matrix_stride, const float* const biases, float* const output, const int output_row_stride, - const int output_col_stride + const int output_col_stride, + const int _pad_bottom, + const int _pad_right ) { (void) output_row_stride; - constexpr int cells_j = output_tile_cols - pad_right; + (void) _pad_bottom; + constexpr int output_tile_cols = 6; + constexpr int inner_tile_cols = 8; + + const int pad_right = Specialized ? PadRight : _pad_right; + const int cells_j = output_tile_cols - pad_right; // Construct a map to the output cells float *outptrs[cells_j]; @@ -162,25 +156,24 @@ void Transform::process_tile( } } -template <> -template <> -const Transform::TileFn Transform::tile_fns[max_pad_bottom][max_pad_right] = +} // namespace (anonymous) + +namespace winograd { - { - Transform::template process_tile<0, 0>, - Transform::template process_tile<0, 1>, - Transform::template process_tile<0, 2>, - Transform::template process_tile<0, 3>, - Transform::template process_tile<0, 4>, - Transform::template process_tile<0, 5>, - }, -}; +using Tiles = OutputTransformImplTiles<1, 3, 1, 8, float>; template <> -template <> -const TransformTransposed::TileFn TransformTransposed::tile_fns[max_pad_bottom][max_pad_right] = {}; +const Tiles::TileFn Tiles::tilefn_unpadded = winograd_output_transform_6_3_fp32_process_tile<true>; +template <> +const Tiles::TileFn Tiles::tilefn_right_padded[n_pad_right] = { + winograd_output_transform_6_3_fp32_process_tile<true, 1>, + winograd_output_transform_6_3_fp32_process_tile<true, 2>, + winograd_output_transform_6_3_fp32_process_tile<true, 3>, + winograd_output_transform_6_3_fp32_process_tile<true, 4>, + winograd_output_transform_6_3_fp32_process_tile<true, 5>, +}; -template struct WinogradGEMM<1, 6, 1, 3>::OutputTransform<float>; -template struct WinogradGEMM<6, 1, 3, 1>::OutputTransform<float>; +template class OutputTransform<1, 3, 1, 8, float>; +template class OutputTransform<3, 1, 8, 1, float>; } // namespace winograd |