aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h')
-rw-r--r--arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h38
1 files changed, 34 insertions, 4 deletions
diff --git a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
index 1e7ca64b8c..3ab3aa792b 100644
--- a/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NEWinogradLayerKernel.h
@@ -25,17 +25,34 @@
#define __ARM_COMPUTE_NEGEMMWINOGRADLAYERKERNEL_H__
#include "arm_compute/core/NEON/INEKernel.h"
-
-#include "arm_compute/core/NEON/kernels/winograd/winograd_shim_nchw.hpp"
+#include "arm_compute/core/NEON/kernels/winograd/tensor.hpp"
namespace arm_compute
{
class ITensor;
+class NEWinogradLayerKernel;
+class Winograd3x3F32
+{
+public:
+ friend class NEWinogradLayerKernel;
+ Winograd3x3F32(const KernelShape &kernel_shape, const Tensor4DShape input_shape, const PaddingType padding_type, void *kernel_storage);
+ ~Winograd3x3F32();
+ std::pair<void *, void *> get_nhwc_ptrs(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space);
+ void transform_weights(const void *const kernel, void *transform_working_space);
+ void reshape_input(const Tensor4DShape &input_shape, const PaddingType padding_type, const void *const input, void *working_space);
+ void reshape_output(const Tensor4DShape &input_shape, const PaddingType padding_type, void *const output);
+ void nchw2nhwc(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space, const void *const input);
+ void nhwc2nchw(const Tensor4DShape &input_shape, const PaddingType padding_type, void *working_space, void *const output);
+
+private:
+ class Private;
+ std::unique_ptr<Private> _pimpl;
+};
class NEWinogradLayerKernel : public INEKernel
{
public:
- using Winograd3x3F32 = winograd_shim_nchw::Winograd2x2_3x3GEMM<float, float>;
+ // using Winograd3x3F32 = winograd_shim_nchw::Winograd2x2_3x3GEMM<float, float>;
/** Constructor */
NEWinogradLayerKernel();
@@ -61,9 +78,22 @@ public:
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
+ /* Get the memory required to instantiate a new Winograd operator.
+ */
+ static size_t get_kernel_storage_size(const KernelShape &shape);
+
+ /* Get the memory required to apply a Winograd operator to some input.
+ */
+ static size_t get_working_space_size(const Tensor4DShape &input_shape, const KernelShape &k_shape, const PaddingType padding);
+
+ /* Get the memory required to transform the kernel.
+ */
+ static size_t get_kernel_transform_working_size(const KernelShape &shape);
+
protected:
Winograd3x3F32 *_convolver;
- ITensor *_output;
+ // std::unique_ptr<Winograd3x3F32> _conv;
+ ITensor *_output;
};
} // namespace arm_compute