From d216f570750b8ccde3754c4aef53fc20a90cb32d Mon Sep 17 00:00:00 2001 From: Freddie Liardet Date: Tue, 3 Aug 2021 15:57:32 +0100 Subject: Update cpu depthwise kernels Resolves: COMPMID-4688 Signed-off-by: Freddie Liardet Change-Id: I9e22f967f5b7ccaebff2fc49f0253f621d62d820 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6030 Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez Tello Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- ...4_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp | 71 +++++++++++++++++++--- 1 file changed, 62 insertions(+), 9 deletions(-) (limited to 'src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp') diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp index 88f20bb125..a888eb5776 100644 --- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp +++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst.hpp @@ -28,21 +28,25 @@ #pragma once +#if __aarch64__ + namespace arm_conv { namespace depthwise { void a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float); void a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float); -struct a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst +class a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy { - typedef float bias_type; - typedef float input_type; - typedef float weight_type; - typedef float return_type; - + private: typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float); + indirect_kern_type m_indirect_kernel = a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl; + typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float); + direct_kern_type m_direct_kernel = a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl; + + public: + typedef float return_type; constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None; @@ -58,11 +62,60 @@ struct a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst constexpr static unsigned int input_rows = 4; constexpr static unsigned int input_cols = 4; - indirect_kern_type indirect_kernel = a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_indirect_impl; - direct_kern_type direct_kernel = a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst_direct_impl; - a64_fp32_nhwc_3x3_s1_output2x2_mla_depthfirst(const CPUInfo *) {} + + arm_gemm::VLType get_vl_type(void) const override { return vl_type; } + + unsigned int get_kernel_rows(void) const override { return kernel_rows; } + unsigned int get_kernel_cols(void) const override { return kernel_cols; } + + unsigned int get_stride_rows(void) const override { return stride_rows; } + unsigned int get_stride_cols(void) const override { return stride_cols; } + + unsigned int get_output_rows(void) const override { return output_rows; } + unsigned int get_output_cols(void) const override { return output_cols; } + + unsigned int get_input_rows(void) const override { return input_rows; } + unsigned int get_input_cols(void) const override { return input_cols; } + + void indirect_kernel( + const void *const *const input_ptrs, + void *const *const outptrs, + const void *params, + unsigned int n_channels, + const void *activation_min, + const void *activation_max + ) const override + { + m_indirect_kernel( + reinterpret_cast(input_ptrs), + reinterpret_cast(outptrs), + params, n_channels, + *static_cast(activation_min), + *static_cast(activation_max) + ); + } + + void direct_kernel( + const unsigned int n_tile_rows, const unsigned int n_tile_cols, + const void *inptr, int64_t ld_input_row, int64_t ld_input_col, + void *outptr, int64_t ld_output_row, int64_t ld_output_col, + const void *params, unsigned int n_channels, + const void *activation_min, const void *activation_max + ) const override + { + m_direct_kernel( + n_tile_rows, n_tile_cols, + static_cast(inptr), ld_input_row, ld_input_col, + static_cast(outptr), ld_output_row, ld_output_col, + params, n_channels, + *static_cast(activation_min), + *static_cast(activation_max) + ); + } }; } // namespace depthwise } // namespace arm_conv + +#endif // __aarch64__ -- cgit v1.2.1