diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h')
-rw-r--r-- | arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h | 38 |
1 files changed, 34 insertions, 4 deletions
diff --git a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h index 1e7ca64b8c..3ab3aa792b 100644 --- a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h @@ -25,17 +25,34 @@ #define __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__ #include "arm_compute/core/NEON/INEKernel.h" - -#include "arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp" +#include "arm_compute/core/NEON/kernels/winograd/tensor.hpp" namespace arm_compute { class ITensor; +class NEWinogradLayerKernel; +class Winograd3x3F32 +{ +public: + friend class NEWinogradLayerKernel; + Winograd3x3F32(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage); + ~Winograd3x3F32(); + std::pair<void *, void *> get_nhwc_ptrs(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space); + void transform_weights(const void *const kernel, void *transform_working_space); + void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const void *const input, void *working_space); + void reshape_output(const Tensor4DShape &input_shape, const PaddingType padding_type, void *const output); + void nchw2nhwc(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space, const void *const input); + void nhwc2nchw(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space, void *const output); + +private: + class Private; + std::unique_ptr<Private> _pimpl; +}; class NEWinogradLayerKernel : public INEKernel { public: - using Winograd3x3F32 = winograd_shim_nchw::Winograd2x2_3x3GEMM<float, float>; + // using Winograd3x3F32 = winograd_shim_nchw::Winograd2x2_3x3GEMM<float, float>; /** Constructor */ NEWinogradLayerKernel(); @@ -61,9 +78,22 @@ public: // Inherited methods overridden: void run(const Window &window, const ThreadInfo &info) override; + /* Get the memory required to instantiate a new Winograd operator. + */ + static size_t get_kernel_storage_size(const KernelShape &shape); + + /* Get the memory required to apply a Winograd operator to some input. + */ + static size_t get_working_space_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, const PaddingType padding); + + /* Get the memory required to transform the kernel. + */ + static size_t get_kernel_transform_working_size(const KernelShape &shape); + protected: Winograd3x3F32 *_convolver; - ITensor *_output; + // std::unique_ptr<Winograd3x3F32> _conv; + ITensor *_output; }; } // namespace arm_compute |