aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
diff options
context:
space:
mode:
authorDavid Mansell <David.Mansell@arm.com>2018-07-06 17:53:35 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:10 +0000
commite39334c15c7fd141bb8173d5017ea5ca157fca2c (patch)
treefffa2f7b136525037c4d99586bc194374e5bd3dc /src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
parente8bd2c729546e59aa0adc241976ea91fc6f25b52 (diff)
downloadComputeLibrary-e39334c15c7fd141bb8173d5017ea5ca157fca2c.tar.gz
COMPMID-1271: New system for GEMM heuristics
This patch implements a system for separating the "validity" from "preferred" aspect of the current heuristics in gemm_*.cpp. Now, each gemm_*.cpp defines a list of candidate implementations, each of which supplies an is_valid() function (to check for validity), an is_preferred() function (the "heuristic" part), and an instantiate() function which actually produces the GemmCommon object pointer. The actual gemm() function is now templated and uses this list to select an implementation. This patch also implements a mechanism to identify the preferred implementation, and override it via the GemmConfig structure. Change-Id: Id49ab7af8bf2e3e9fd951a9698883ade234d40e1 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139120 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp52
1 files changed, 36 insertions, 16 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index 1ca92f9d4e..d97dd5c3de 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -25,32 +25,52 @@
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
-#include "kernels/a64_gemm_u8_4x4.hpp"
+#include "kernels/a64_gemm_u16_12x8.hpp"
#include "kernels/a64_gemm_u8_12x8.hpp"
+#include "kernels/a64_gemm_u8_4x4.hpp"
namespace arm_gemm {
-template<>
-UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const unsigned int nbatches, const unsigned int nmulti,
- const bool trA, const bool trB, const uint32_t alpha, const uint32_t beta,
- const int maxthreads, const bool pretransposed_hint) {
- if (ci.has_dotprod()) {
- // Dot product supporting CPUs. This family has a special version for A55r1.
- return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_12x8, uint8_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+class GemmImpl_gemm_u8_interleaved_dot : public GemmImplementation<uint8_t, uint32_t> {
+public:
+ bool is_supported(const GemmArgs<uint32_t> &args) override {
+ return args._ci->has_dotprod();
}
- // Non dot-product code.
- return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ UniqueGemmCommon<uint8_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_12x8, uint8_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u8_interleaved_dot() : GemmImplementation<uint8_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED_DOT) { }
+};
- // TODO: There's a better approach for A53, but it doesn't work
- // well on heterogeneous systems as the required data formats
- // are different. Figure out how to enable this:
- // gemm = new GemmInterleaved<gemm_s16_12x8, int8_t, int32_t>(ci, M, N, K, trA, trB);
+class GemmImpl_gemm_u8_interleaved : public GemmImplementation<uint8_t, uint32_t> {
+public:
+ UniqueGemmCommon<uint8_t, uint32_t> instantiate(const GemmArgs<uint32_t> &args) override {
+ return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t>(args));
+ }
+
+ GemmImpl_gemm_u8_interleaved() : GemmImplementation<uint8_t, uint32_t>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+static std::vector<GemmImplementation<uint8_t, uint32_t> *> gemm_u8_methods = {
+ new GemmImpl_gemm_u8_interleaved_dot(),
+ new GemmImpl_gemm_u8_interleaved()
+};
+
+template<>
+std::vector<GemmImplementation<uint8_t, uint32_t> *> &gemm_implementation_list<uint8_t, uint32_t>() {
+ return gemm_u8_methods;
}
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t>(GemmArgs<uint32_t> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<uint8_t, uint32_t>(GemmArgs<uint32_t> &args);
+template bool method_is_compatible<uint8_t, uint32_t>(GemmMethod method, GemmArgs<uint32_t> &args);
+
} // namespace arm_gemm
-#endif // aarch64
+#endif // __aarch64__