aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels
diff options
context:
space:
mode:
authorFrancesco Petrogalli <francesco.petrogalli@arm.com>2022-06-30 10:22:01 +0000
committerFrancesco Petrogalli <francesco.petrogalli@arm.com>2022-07-19 09:26:27 +0000
commit553f6953fe3bdfad53c11c25f305a16d79d83b24 (patch)
tree73642b948b79662096f593458c6138d2f7f48ec6 /src/core/NEON/kernels
parent99c46475daf277aa53e6747f9e41209f418fed33 (diff)
downloadComputeLibrary-553f6953fe3bdfad53c11c25f305a16d79d83b24.tar.gz
[ONCPUML-951] Variable weight support for Convolution.
API changes for NEGEMMConvolutionLayer and CpuGemmConv2d Built with: scons neon=1 opencl=0 os=linux arch=armv8.2-a multi_isa=1 \ build=native -j32 Werror=false validation_tests=1 build_dir=opt \ standalone=1 asserts=1 experimental_fixed_format_kernels=1 . Tested with: ./build/opt/tests/arm_compute_validation Hardware where the test executable was run: Neoverse N1 Test coverage: * NEGEMMConvolutionLayer, CpuGemmConv2d * NHWC (the only one supported by the fixed-format kernels) * F16, F32 * Shapes: RunSmall Change-Id: I4fd3e495a7cbf61210ea02d37440ba9652934e99 Signed-off-by: Francesco Petrogalli <francesco.petrogalli@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7632 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp18
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp18
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp18
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp7
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int16.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int8.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/utils.hpp2
12 files changed, 40 insertions, 37 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
index 50fc5bdb8a..58e4861bc0 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
@@ -33,21 +33,21 @@
#include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp"
#include "kernels/a64_hybrid_bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp"
#include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp"
#include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp"
@@ -89,7 +89,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"sve_ffinterleaved_bf16fp32_mmla_8x3VL",
@@ -106,7 +106,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_HYBRID,
@@ -136,7 +136,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"a64_ffinterleaved_bf16fp32_mmla_8x12",
@@ -161,7 +161,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"a64_sgemm_8x12",
@@ -197,7 +197,7 @@ const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, fl
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<bfloat16, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 2796b0d204..d749dce98d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -34,17 +34,17 @@
#include "gemm_interleaved.hpp"
#include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp"
#include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hgemm_8x24.hpp"
#include "kernels/a64_hybrid_fp16_mla_6x32.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp"
#include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp"
#include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp"
@@ -66,7 +66,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<__fp16, __fp16>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"sve_ffinterleaved_fp16_mla_8x3VL",
@@ -83,7 +83,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
#if defined(__aarch64__)
GemmImplementation<__fp16, __fp16>::with_estimate(
@@ -100,7 +100,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<__fp16, __fp16>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"a64_ffinterleaved_fp16_mla_8x24",
@@ -117,7 +117,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
{
GemmMethod::GEMM_INTERLEAVED,
"a64_sgemm_8x12",
@@ -152,7 +152,7 @@ const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp1
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 4f7e191fb3..0fc9e8b912 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -31,12 +31,12 @@
#include "gemv_pretransposed.hpp"
#include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp"
#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp"
#include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_hybrid_fp32_mla_4x24.hpp"
@@ -48,12 +48,12 @@
#include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp"
#include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp"
#include "kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp"
#include "kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp"
#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp"
#include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp"
@@ -165,7 +165,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>(args); }
),
- #ifdef ENABLE_FIXED_FORMAT_KERNELS
+ #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_BF16
GemmImplementation<float, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
@@ -200,7 +200,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
// Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases.
{
@@ -253,7 +253,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, float, float>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_BF16
// "fast mode" (BF16) kernels
GemmImplementation<float, float>::with_estimate(
@@ -289,7 +289,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // __aarch64__
#ifdef __arm__
@@ -318,7 +318,7 @@ const GemmImplementation<float, float> *gemm_implementation_list<float, float>()
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<float, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index c41b0a5b3e..90e2f07607 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -450,7 +450,7 @@ public:
}
/* Make sure we've been set up correctly. */
- assert(_B_transposed);
+ assert(FixedFormat || _B_transposed);
static_assert(std::is_same<To, Tloi>::value, "gemm_native: Operand types must be the same.");
// static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 75fb1cb306..19c8fcadd3 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -306,9 +306,12 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons
}
template<typename Top, typename Tret, class OutputStage>
-bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) {
+bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) {
const GemmImplementation<Top, Tret, OutputStage> *impl;
- return find_implementation<Top, Tret, OutputStage>(args, os, impl);
+ const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl);
+ if (success)
+ wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format;
+ return success;
}
template<typename Top, typename Tret, class OutputStage>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index 3915861112..18d8fc9312 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -56,7 +56,7 @@ const GemmImplementation<int16_t, int32_t> *gemm_implementation_list<int16_t, in
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int16_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<int16_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 0c68e4dd99..24507486ac 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -159,7 +159,7 @@ const GemmImplementation<int8_t, int32_t> *gemm_implementation_list<int8_t, int3
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int8_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index 6b813c7974..1d7b9c5b73 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -230,7 +230,7 @@ const GemmImplementation<int8_t, int8_t, Requantize32> *gemm_implementation_list
}
template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
-template bool has_opt_gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<int8_t, int8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os);
template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index 95139c2bf6..be7a4ee570 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -197,7 +197,7 @@ const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_li
}
template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
-template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index 20cee556f0..fc836f9790 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -56,7 +56,7 @@ const GemmImplementation<uint16_t, uint32_t> *gemm_implementation_list<uint16_t,
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index a2d2cc86f0..03e9cd6c1f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -157,7 +157,7 @@ const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, u
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index 18e124b83e..d7b5398488 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -24,7 +24,7 @@
#pragma once
-#include "arm_gemm.hpp"
+#include "src/cpu/kernels/assembly/arm_gemm.hpp"
#include <cstddef>
#include <limits>