diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp | 71 |
1 files changed, 60 insertions, 11 deletions
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp index 6f1f187818..751874ffbf 100644 --- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp +++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp @@ -28,7 +28,7 @@ #pragma once -#if defined(ARM_COMPUTE_ENABLE_SVE) +#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) namespace arm_conv { namespace depthwise { @@ -36,15 +36,17 @@ namespace depthwise { void sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float); void sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float); -struct sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst +class sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy { - typedef float bias_type; - typedef float input_type; - typedef float weight_type; - typedef float return_type; - + private: typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float); + indirect_kern_type m_indirect_kernel = sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl; + typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float); + direct_kern_type m_direct_kernel = sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl; + + public: + typedef float return_type; constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::SVE; @@ -60,13 +62,60 @@ struct sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst constexpr static unsigned int input_rows = 6; constexpr static unsigned int input_cols = 6; - indirect_kern_type indirect_kernel = sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl; - direct_kern_type direct_kernel = sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl; - sve_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {} + + arm_gemm::VLType get_vl_type(void) const override { return vl_type; } + + unsigned int get_kernel_rows(void) const override { return kernel_rows; } + unsigned int get_kernel_cols(void) const override { return kernel_cols; } + + unsigned int get_stride_rows(void) const override { return stride_rows; } + unsigned int get_stride_cols(void) const override { return stride_cols; } + + unsigned int get_output_rows(void) const override { return output_rows; } + unsigned int get_output_cols(void) const override { return output_cols; } + + unsigned int get_input_rows(void) const override { return input_rows; } + unsigned int get_input_cols(void) const override { return input_cols; } + + void indirect_kernel( + const void *const *const input_ptrs, + void *const *const outptrs, + const void *params, + unsigned int n_channels, + const void *activation_min, + const void *activation_max + ) const override + { + m_indirect_kernel( + reinterpret_cast<const float *const *>(input_ptrs), + reinterpret_cast<float *const *>(outptrs), + params, n_channels, + *static_cast<const float *>(activation_min), + *static_cast<const float *>(activation_max) + ); + } + + void direct_kernel( + const unsigned int n_tile_rows, const unsigned int n_tile_cols, + const void *inptr, int64_t ld_input_row, int64_t ld_input_col, + void *outptr, int64_t ld_output_row, int64_t ld_output_col, + const void *params, unsigned int n_channels, + const void *activation_min, const void *activation_max + ) const override + { + m_direct_kernel( + n_tile_rows, n_tile_cols, + static_cast<const float *>(inptr), ld_input_row, ld_input_col, + static_cast<float *>(outptr), ld_output_row, ld_output_col, + params, n_channels, + *static_cast<const float *>(activation_min), + *static_cast<const float *>(activation_max) + ); + } }; } // namespace depthwise } // namespace arm_conv -#endif // defined(ARM_COMPUTE_ENABLE_SVE) +#endif // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) |