From d216f570750b8ccde3754c4aef53fc20a90cb32d Mon Sep 17 00:00:00 2001 From: Freddie Liardet Date: Tue, 3 Aug 2021 15:57:32 +0100 Subject: Update cpu depthwise kernels Resolves: COMPMID-4688 Signed-off-by: Freddie Liardet Change-Id: I9e22f967f5b7ccaebff2fc49f0253f621d62d820 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6030 Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez Tello Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- ...e_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp | 71 ++++++++++++++++++---- 1 file changed, 60 insertions(+), 11 deletions(-) (limited to 'src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp') diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp index 236f9bf43a..4a9bd33a1e 100644 --- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp +++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp @@ -28,7 +28,7 @@ #pragma once -#if defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS) +#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS) namespace arm_conv { namespace depthwise { @@ -36,15 +36,17 @@ namespace depthwise { void sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16); void sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16); -struct sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst +class sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy { - typedef __fp16 bias_type; 
- typedef __fp16 input_type; - typedef __fp16 weight_type; - typedef __fp16 return_type; - + private: typedef void (*indirect_kern_type)(const __fp16 *const *const, __fp16 *const *const, const void *, unsigned int, const __fp16, const __fp16); + indirect_kern_type m_indirect_kernel = sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl; + typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const __fp16 *, int64_t, int64_t, __fp16 *, int64_t, int64_t, const void *, unsigned int, const __fp16, const __fp16); + direct_kern_type m_direct_kernel = sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl; + + public: + typedef __fp16 return_type; constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE; @@ -60,13 +62,60 @@ struct sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst constexpr static unsigned int input_rows = 6; constexpr static unsigned int input_cols = 6; - indirect_kern_type indirect_kernel = sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl; - direct_kern_type direct_kernel = sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl; - sve_fp16_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {} + + arm_gemm::VLType get_vl_type(void) const override { return vl_type; } + + unsigned int get_kernel_rows(void) const override { return kernel_rows; } + unsigned int get_kernel_cols(void) const override { return kernel_cols; } + + unsigned int get_stride_rows(void) const override { return stride_rows; } + unsigned int get_stride_cols(void) const override { return stride_cols; } + + unsigned int get_output_rows(void) const override { return output_rows; } + unsigned int get_output_cols(void) const override { return output_cols; } + + unsigned int get_input_rows(void) const override { return input_rows; } + unsigned int get_input_cols(void) const override { return input_cols; } + + void indirect_kernel( + const void *const *const input_ptrs, + void *const *const outptrs, + const void *params, + unsigned int 
n_channels, + const void *activation_min, + const void *activation_max + ) const override + { + m_indirect_kernel( + reinterpret_cast<const __fp16 *const *>(input_ptrs), + reinterpret_cast<__fp16 *const *>(outptrs), + params, n_channels, + *static_cast<const __fp16 *>(activation_min), + *static_cast<const __fp16 *>(activation_max) + ); + } + + void direct_kernel( + const unsigned int n_tile_rows, const unsigned int n_tile_cols, + const void *inptr, int64_t ld_input_row, int64_t ld_input_col, + void *outptr, int64_t ld_output_row, int64_t ld_output_col, + const void *params, unsigned int n_channels, + const void *activation_min, const void *activation_max + ) const override + { + m_direct_kernel( + n_tile_rows, n_tile_cols, + static_cast<const __fp16 *>(inptr), ld_input_row, ld_input_col, + static_cast<__fp16 *>(outptr), ld_output_row, ld_output_col, + params, n_channels, + *static_cast<const __fp16 *>(activation_min), + *static_cast<const __fp16 *>(activation_max) + ); + } }; } // namespace depthwise } // namespace arm_conv -#endif // defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS) +#endif // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) && defined(__ARM_FP16_ARGS) -- cgit v1.2.1