about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp 71
1 files changed, 62 insertions, 9 deletions
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
index 314fe766de..e4bfbe6783 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/kernels/a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst.hpp
@@ -28,21 +28,25 @@
#pragma once
+#if defined(__aarch64__)
+
namespace arm_conv {
namespace depthwise {
// Declarations only: the definitions of these two kernel entry points are not part of this excerpt.
void a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl(const float *const *const, float *const *const, const void *, unsigned int, const float, const float);
void a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float);
-struct a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst
+class a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst : public IDepthwiseDepthfirstStrategy
{
- typedef float bias_type;
- typedef float input_type;
- typedef float weight_type;
- typedef float return_type;
-
+ private:
typedef void (*indirect_kern_type)(const float *const *const, float *const *const, const void *, unsigned int, const float, const float); // Function-pointer type with the same parameter list as the _indirect_impl declaration.
+ indirect_kern_type m_indirect_kernel = a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl; // Default member initializer: points at the indirect implementation.
+
typedef void (*direct_kern_type)(const unsigned int, const unsigned int, const float *, int64_t, int64_t, float *, int64_t, int64_t, const void *, unsigned int, const float, const float); // Function-pointer type with the same parameter list as the _direct_impl declaration.
+ direct_kern_type m_direct_kernel = a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl; // Default member initializer: points at the direct implementation.
+
+ public:
+ typedef float return_type;
constexpr static arm_gemm::VLType vl_type = arm_gemm::VLType::None;
@@ -58,11 +62,60 @@ struct a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst
constexpr static unsigned int input_rows = 6;
constexpr static unsigned int input_cols = 6;
- indirect_kern_type indirect_kernel = a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_indirect_impl;
- direct_kern_type direct_kernel = a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst_direct_impl;
-
a64_fp32_nhwc_5x5_s1_output2x2_mla_depthfirst(const CPUInfo *) {} // Empty constructor: the CPUInfo pointer is accepted but never used.
+
+ arm_gemm::VLType get_vl_type(void) const override { return vl_type; } // Returns the class's static vl_type (arm_gemm::VLType::None).
+
+ unsigned int get_kernel_rows(void) const override { return kernel_rows; } // kernel_rows is declared outside the visible hunks.
+ unsigned int get_kernel_cols(void) const override { return kernel_cols; } // kernel_cols is declared outside the visible hunks.
+
+ unsigned int get_stride_rows(void) const override { return stride_rows; } // stride_rows is declared outside the visible hunks.
+ unsigned int get_stride_cols(void) const override { return stride_cols; } // stride_cols is declared outside the visible hunks.
+
+ unsigned int get_output_rows(void) const override { return output_rows; } // output_rows is declared outside the visible hunks.
+ unsigned int get_output_cols(void) const override { return output_cols; } // output_cols is declared outside the visible hunks.
+
+ unsigned int get_input_rows(void) const override { return input_rows; } // Returns the static input_rows constant (6).
+ unsigned int get_input_cols(void) const override { return input_cols; } // Returns the static input_cols constant (6).
+
+ void indirect_kernel( // Type-erased override: converts the void-typed arguments to float types and calls m_indirect_kernel.
+ const void *const *const input_ptrs, // Reinterpreted below as const float *const *.
+ void *const *const outptrs, // Reinterpreted below as float *const *.
+ const void *params, // Forwarded unchanged.
+ unsigned int n_channels, // Forwarded unchanged.
+ const void *activation_min, // Dereferenced below as a const float, so it must point to a readable float.
+ const void *activation_max // Dereferenced below as a const float, so it must point to a readable float.
+ ) const override
+ {
+ m_indirect_kernel(
+ reinterpret_cast<const float *const *>(input_ptrs),
+ reinterpret_cast<float *const *>(outptrs),
+ params, n_channels,
+ *static_cast<const float *>(activation_min), // The activation bounds are passed to the kernel by value.
+ *static_cast<const float *>(activation_max)
+ );
+ }
+
+ void direct_kernel( // Type-erased override: converts the void-typed arguments to float types and calls m_direct_kernel.
+ const unsigned int n_tile_rows, const unsigned int n_tile_cols, // Forwarded unchanged.
+ const void *inptr, int64_t ld_input_row, int64_t ld_input_col, // inptr is cast to const float *; the ld_* values are forwarded unchanged.
+ void *outptr, int64_t ld_output_row, int64_t ld_output_col, // outptr is cast to float *; the ld_* values are forwarded unchanged.
+ const void *params, unsigned int n_channels, // Forwarded unchanged.
+ const void *activation_min, const void *activation_max // Both are dereferenced below as const float, so each must point to a readable float.
+ ) const override
+ {
+ m_direct_kernel(
+ n_tile_rows, n_tile_cols,
+ static_cast<const float *>(inptr), ld_input_row, ld_input_col,
+ static_cast<float *>(outptr), ld_output_row, ld_output_col,
+ params, n_channels,
+ *static_cast<const float *>(activation_min), // The activation bounds are passed to the kernel by value.
+ *static_cast<const float *>(activation_max)
+ );
+ }
};
} // namespace depthwise
} // namespace arm_conv
+
+#endif // defined(__aarch64__)