aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_conv/pooling/pooling_s8q.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_conv/pooling/pooling_s8q.cpp')
-rw-r--r--src/core/NEON/kernels/arm_conv/pooling/pooling_s8q.cpp74
1 files changed, 55 insertions, 19 deletions
diff --git a/src/core/NEON/kernels/arm_conv/pooling/pooling_s8q.cpp b/src/core/NEON/kernels/arm_conv/pooling/pooling_s8q.cpp
index fd4e045035..dcb3c8f57c 100644
--- a/src/core/NEON/kernels/arm_conv/pooling/pooling_s8q.cpp
+++ b/src/core/NEON/kernels/arm_conv/pooling/pooling_s8q.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,13 +25,17 @@
#include "arm_gemm_local.hpp"
#include "pooling_implementation.hpp"
-#include "pooling_depthfirst_generic_quantized.hpp"
+#include "pooling_depthfirst_generic.hpp"
#if defined(__aarch64__)
-#if defined(__ARM_FEATURE_SVE) && defined(SVE2)
+#if defined(ARM_COMPUTE_ENABLE_SME)
+#include "kernels/sme_s8q_nhwc_avg_generic_depthfirst.hpp"
+#include "kernels/sme_s8q_nhwc_max_generic_depthfirst.hpp"
+#endif // defined(ARM_COMPUTE_ENABLE_SME)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
#include "kernels/sve_s8q_nhwc_avg_generic_depthfirst.hpp"
#include "kernels/sve_s8q_nhwc_max_generic_depthfirst.hpp"
-#endif // defined(__ARM_FEATURE_SVE) && defined(SVE2)
+#endif // defined(ARM_COMPUTE_ENABLE_SVE)
#include "kernels/a64_s8q_nhwc_avg_generic_depthfirst.hpp"
#include "kernels/a64_s8q_nhwc_max_generic_depthfirst.hpp"
#endif // defined(__aarch64__)
@@ -41,30 +45,60 @@
namespace arm_conv {
namespace pooling {
-static const PoolingImplementation<int8_t, int8_t, Requantize32> pooling_u8_methods[] = {
+static const PoolingImplementation<int8_t, int8_t, Requantize32> pooling_s8q_methods[] = {
#if defined(__aarch64__)
-#if defined(__ARM_FEATURE_SVE) && defined(SVE2)
+#if defined(ARM_COMPUTE_ENABLE_SME)
+ {
+ PoolingMethod::DEPTHFIRST,
+ "sme_s8q_nhwc_avg_generic_depthfirst",
+ [] (const PoolingArgs &args, const Requantize32 &) -> bool {
+ return args.cpu_info->has_sme2() && args.pool_type == PoolingType::AVERAGE;
+ },
+ nullptr,
+ [] (const PoolingArgs &args, const Requantize32 &rq) -> PoolingCommon<int8_t, int8_t> * {
+ auto strat = new sme_s8q_nhwc_avg_generic_depthfirst(args.cpu_info);
+ return new PoolingDepthfirstGeneric<int8_t, int8_t, Requantize32>(strat, args, rq);
+ },
+ },
+ {
+ PoolingMethod::DEPTHFIRST,
+ "sme_s8q_nhwc_max_generic_depthfirst",
+ [] (const PoolingArgs &args, const Requantize32 &) -> bool {
+ return args.cpu_info->has_sme2() && args.pool_type == PoolingType::MAX;
+ },
+ nullptr,
+ [] (const PoolingArgs &args, const Requantize32 &rq) -> PoolingCommon<int8_t, int8_t> * {
+ auto strat = new sme_s8q_nhwc_max_generic_depthfirst(args.cpu_info);
+ return new PoolingDepthfirstGeneric<int8_t, int8_t, Requantize32>(strat, args, rq);
+ },
+ },
+#endif // defined(ARM_COMPUTE_ENABLE_SME)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
{
PoolingMethod::DEPTHFIRST,
"sve_s8q_nhwc_avg_generic_depthfirst",
[] (const PoolingArgs &args, const Requantize32 &) -> bool {
- return args.pool_type == PoolingType::AVERAGE;
+ return args.cpu_info->has_sve2() && args.pool_type == PoolingType::AVERAGE;
},
nullptr,
- [] (const PoolingArgs &args, const Requantize32 &rq) -> PoolingCommon<int8_t, int8_t, Requantize32> * {
- return new PoolingDepthfirstGenericQuantized<sve_s8q_nhwc_avg_generic_depthfirst>(args, rq);
+ [] (const PoolingArgs &args, const Requantize32 &rq) -> PoolingCommon<int8_t, int8_t> * {
+ auto strat = new sve_s8q_nhwc_avg_generic_depthfirst(args.cpu_info);
+ return new PoolingDepthfirstGeneric<int8_t, int8_t, Requantize32>(strat, args, rq);
},
},
{
PoolingMethod::DEPTHFIRST,
"sve_s8q_nhwc_max_generic_depthfirst",
- [] (const PoolingArgs &args, const Requantize32 &) -> bool { return args.pool_type == PoolingType::MAX; },
+ [] (const PoolingArgs &args, const Requantize32 &) -> bool {
+ return args.cpu_info->has_sve2() && args.pool_type == PoolingType::MAX;
+ },
nullptr,
- [] (const PoolingArgs &args, const Requantize32 &rq) -> PoolingCommon<int8_t, int8_t, Requantize32> * {
- return new PoolingDepthfirstGenericQuantized<sve_s8q_nhwc_max_generic_depthfirst>(args, rq);
+ [] (const PoolingArgs &args, const Requantize32 &rq) -> PoolingCommon<int8_t, int8_t> * {
+ auto strat = new sve_s8q_nhwc_max_generic_depthfirst(args.cpu_info);
+ return new PoolingDepthfirstGeneric<int8_t, int8_t, Requantize32>(strat, args, rq);
},
},
-#endif // defined(__ARM_FEATURE_SVE) && defined(SVE2)
+#endif // defined(ARM_COMPUTE_ENABLE_SVE)
{
PoolingMethod::DEPTHFIRST,
"a64_s8q_nhwc_avg_generic_depthfirst",
@@ -72,8 +106,9 @@ static const PoolingImplementation<int8_t, int8_t, Requantize32> pooling_u8_meth
return args.pool_type == PoolingType::AVERAGE;
},
nullptr,
- [] (const PoolingArgs &args, const Requantize32 &rq) -> PoolingCommon<int8_t, int8_t, Requantize32> * {
- return new PoolingDepthfirstGenericQuantized<a64_s8q_nhwc_avg_generic_depthfirst>(args, rq);
+ [] (const PoolingArgs &args, const Requantize32 &rq) -> PoolingCommon<int8_t, int8_t> * {
+ auto strat = new a64_s8q_nhwc_avg_generic_depthfirst(args.cpu_info);
+ return new PoolingDepthfirstGeneric<int8_t, int8_t, Requantize32>(strat, args, rq);
},
},
{
@@ -81,8 +116,9 @@ static const PoolingImplementation<int8_t, int8_t, Requantize32> pooling_u8_meth
"a64_s8q_nhwc_max_generic_depthfirst",
[] (const PoolingArgs &args, const Requantize32 &) -> bool { return args.pool_type == PoolingType::MAX; },
nullptr,
- [] (const PoolingArgs &args, const Requantize32 &rq) -> PoolingCommon<int8_t, int8_t, Requantize32> * {
- return new PoolingDepthfirstGenericQuantized<a64_s8q_nhwc_max_generic_depthfirst>(args, rq);
+ [] (const PoolingArgs &args, const Requantize32 &rq) -> PoolingCommon<int8_t, int8_t> * {
+ auto strat = new a64_s8q_nhwc_max_generic_depthfirst(args.cpu_info);
+ return new PoolingDepthfirstGeneric<int8_t, int8_t, Requantize32>(strat, args, rq);
},
},
#endif // defined(__aarch64__)
@@ -92,10 +128,10 @@ static const PoolingImplementation<int8_t, int8_t, Requantize32> pooling_u8_meth
template <>
const PoolingImplementation<int8_t, int8_t, Requantize32> *pooling_implementation_list()
{
- return pooling_u8_methods;
+ return pooling_s8q_methods;
}
-template UniquePoolingCommon<int8_t, int8_t, Requantize32> pooling(const PoolingArgs &, const Requantize32 &);
+template UniquePoolingCommon<int8_t, int8_t> pooling(const PoolingArgs &, const Requantize32 &);
} // namespace pooling
} // namespace arm_conv