author     Francesco Petrogalli <francesco.petrogalli@arm.com>  2022-06-30 10:22:01 +0000
committer  Francesco Petrogalli <francesco.petrogalli@arm.com>  2022-07-19 09:26:27 +0000
commit     553f6953fe3bdfad53c11c25f305a16d79d83b24 (patch)
tree       73642b948b79662096f593458c6138d2f7f48ec6 /src
parent     99c46475daf277aa53e6747f9e41209f418fed33 (diff)
download   ComputeLibrary-553f6953fe3bdfad53c11c25f305a16d79d83b24.tar.gz
[ONCPUML-951] Variable weight support for Convolution.
API changes for NEGEMMConvolutionLayer and CpuGemmConv2d

Built with:

    scons neon=1 opencl=0 os=linux arch=armv8.2-a multi_isa=1 \
        build=native -j32 Werror=false validation_tests=1 build_dir=opt \
        standalone=1 asserts=1 experimental_fixed_format_kernels=1 .

Tested with:

    ./build/opt/tests/arm_compute_validation

Hardware where the test executable was run: Neoverse N1

Test coverage:

* NEGEMMConvolutionLayer, CpuGemmConv2d
* NHWC (the only one supported by the fixed-format kernels)
* F16, F32
* Shapes: RunSmall

Change-Id: I4fd3e495a7cbf61210ea02d37440ba9652934e99
Signed-off-by: Francesco Petrogalli <francesco.petrogalli@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7632
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
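For illustration only (not part of the commit message), a minimal sketch of how a caller might use the new query API, assuming it is exposed as a static member like its CpuGemmConv2d counterpart. The tensor-info variables and the way weights_info is set up to request arm_gemm::WeightFormat::ANY are assumptions; the patch below only shows the WeightsInfo::weight_format() accessor.

    // Hypothetical usage sketch; src_info, weights_tensor_info, biases_info, dst_info,
    // conv_info and weights_info are placeholders supplied by the caller.
    arm_gemm::WeightFormat expected_wf = arm_gemm::WeightFormat::UNSPECIFIED;
    const Status status = NEGEMMConvolutionLayer::has_opt_impl(expected_wf, src_info, weights_tensor_info, biases_info, dst_info,
                                                               conv_info, weights_info, Size2D(1U, 1U),
                                                               ActivationLayerInfo(), /* enable_fast_math */ false);
    if(bool(status) && arm_gemm::is_fixed_format(expected_wf))
    {
        // Reorder the OHWI weights into the layout named by expected_wf (e.g. OHWIo4i2)
        // and hand them to configure()/run() unmodified: in variable weights mode the
        // operator uses the weights exactly as given by the user.
    }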
Diffstat (limited to 'src')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp                      18
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp                      18
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp                      18
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp            2
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp             7
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_int16.cpp                      2
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_int8.cpp                       2
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp                      2
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp                     2
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp                     2
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp                      2
-rw-r--r--  src/core/NEON/kernels/arm_gemm/utils.hpp                           2
-rw-r--r--  src/cpu/kernels/assembly/arm_gemm.hpp                            116
-rw-r--r--  src/cpu/operators/CpuGemm.cpp                                     21
-rw-r--r--  src/cpu/operators/CpuGemm.h                                       24
-rw-r--r--  src/cpu/operators/CpuGemmConv2d.cpp                               82
-rw-r--r--  src/cpu/operators/CpuGemmConv2d.h                                 24
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp            69
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.h              21
-rw-r--r--  src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp             10
20 files changed, 363 insertions, 81 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
index 50fc5bdb8a..58e4861bc0 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
@@ -33,21 +33,21 @@
#include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp"
#include "kernels/a64_hybrid_bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp"
#include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp"
#include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp"
@@ -89,7 +89,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"sve_ffinterleaved_bf16fp32_mmla_8x3VL",
@@ -106,7 +106,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_HYBRID,
@@ -136,7 +136,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"a64_ffinterleaved_bf16fp32_mmla_8x12",
@@ -161,7 +161,7 @@ GemmImplementation<bfloat16, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
[](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"a64_sgemm_8x12",
@@ -197,7 +197,7 @@ const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, fl
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<bfloat16, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 2796b0d204..d749dce98d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -34,17 +34,17 @@
#include "gemm_interleaved.hpp"
#include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp"
#include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hgemm_8x24.hpp"
#include "kernels/a64_hybrid_fp16_mla_6x32.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp"
#include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp"
#include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp"
@@ -66,7 +66,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<__fp16, __fp16>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"sve_ffinterleaved_fp16_mla_8x3VL",
@@ -83,7 +83,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
#if defined(__aarch64__)
GemmImplementation<__fp16, __fp16>::with_estimate(
@@ -100,7 +100,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<__fp16, __fp16>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
"a64_ffinterleaved_fp16_mla_8x24",
@@ -117,7 +117,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
{
GemmMethod::GEMM_INTERLEAVED,
"a64_sgemm_8x12",
@@ -152,7 +152,7 @@ const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp1
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 4f7e191fb3..0fc9e8b912 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -31,12 +31,12 @@
#include "gemv_pretransposed.hpp"
#include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp"
#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp"
#include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_hybrid_fp32_mla_4x24.hpp"
@@ -48,12 +48,12 @@
#include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp"
#include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp"
#include "kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp"
#include "kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp"
#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp"
#include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp"
@@ -165,7 +165,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>(args); }
),
- #ifdef ENABLE_FIXED_FORMAT_KERNELS
+ #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_BF16
GemmImplementation<float, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
@@ -200,7 +200,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
// Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases.
{
@@ -253,7 +253,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, float, float>(args); }
),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_BF16
// "fast mode" (BF16) kernels
GemmImplementation<float, float>::with_estimate(
@@ -289,7 +289,7 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>(args); }
),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // __aarch64__
#ifdef __arm__
@@ -318,7 +318,7 @@ const GemmImplementation<float, float> *gemm_implementation_list<float, float>()
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<float, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &);
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index c41b0a5b3e..90e2f07607 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -450,7 +450,7 @@ public:
}
/* Make sure we've been set up correctly. */
- assert(_B_transposed);
+ assert(FixedFormat || _B_transposed);
static_assert(std::is_same<To, Tloi>::value, "gemm_native: Operand types must be the same.");
// static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 75fb1cb306..19c8fcadd3 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -306,9 +306,12 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons
}
template<typename Top, typename Tret, class OutputStage>
-bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) {
+bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) {
const GemmImplementation<Top, Tret, OutputStage> *impl;
- return find_implementation<Top, Tret, OutputStage>(args, os, impl);
+ const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl);
+ if (success)
+ wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format;
+ return success;
}
template<typename Top, typename Tret, class OutputStage>
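As context for the new signature, a hedged sketch of how a caller inside arm_gemm might query the expected format, assuming a populated GemmArgs (mirroring the call sites in CpuGemmAssemblyDispatch further down):

    // Sketch only: ask the implementation list whether an optimised FP32 kernel exists
    // and, if so, which weight layout it expects.
    arm_gemm::WeightFormat wf = arm_gemm::WeightFormat::UNSPECIFIED;
    if(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(wf, args, {}))
    {
        // wf now holds the weight format reported by the selected kernel's GemmConfig.
    }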
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index 3915861112..18d8fc9312 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -56,7 +56,7 @@ const GemmImplementation<int16_t, int32_t> *gemm_implementation_list<int16_t, in
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int16_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<int16_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 0c68e4dd99..24507486ac 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -159,7 +159,7 @@ const GemmImplementation<int8_t, int32_t> *gemm_implementation_list<int8_t, int3
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int8_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index 6b813c7974..1d7b9c5b73 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -230,7 +230,7 @@ const GemmImplementation<int8_t, int8_t, Requantize32> *gemm_implementation_list
}
template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
-template bool has_opt_gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<int8_t, int8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os);
template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index 95139c2bf6..be7a4ee570 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -197,7 +197,7 @@ const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_li
}
template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
-template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index 20cee556f0..fc836f9790 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -56,7 +56,7 @@ const GemmImplementation<uint16_t, uint32_t> *gemm_implementation_list<uint16_t,
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index a2d2cc86f0..03e9cd6c1f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -157,7 +157,7 @@ const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, u
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index 18e124b83e..d7b5398488 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -24,7 +24,7 @@
#pragma once
-#include "arm_gemm.hpp"
+#include "src/cpu/kernels/assembly/arm_gemm.hpp"
#include <cstddef>
#include <limits>
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 247cb1d470..48fd7c6b43 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -47,6 +47,57 @@ enum class GemmMethod
GEMM_HYBRID_QUANTIZED
};
+/** Memory layouts for the weights tensor.
+ *
+ * * UNSPECIFIED is used to select kernels that do not run in
+ * variable weights mode.
+ *
+ * * ANY is used to query the kernel database to retrieve any of the
+ * kernels that runs in variable weights mode. Once a kernel is
+ * found, the specific format expected by the kernel can be
+ * retrieved by the user for reordering the weights tensor
+ * accordingly.
+ *
+ * The other values OHWIo{interleave_by}i{block_by} describe the
+ * memory layout of a 4D tensor with layout OHWI that has been
+ * transformed into a 4D tensor with dimensions O'HWI' where:
+ *
+ * O' = first multiple of {interleave_by} s.t. O<=O'
+ * I' = first multiple of {block_by} s.t. I<=I'
+ *
+ * The total size of the dst tensor is O' x H x W x I'
+ *
+ * The access function of the tensor with layout
+ * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter
+ * access function, where the 6 parameters are computed as follows:
+ *
+ * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by}
+ *
+ * x4 = h RANGE [0, H-1] SIZE: H
+ * x3 = w RANGE [0, W-1] SIZE: W
+ * x2 = floor(i/{block_by}) RANGE [0, I'/{block_by} -1] SIZE: I'/{block_by}
+ * x1 = o%{interleave_by} RANGE [0, {interleave_by} -1] SIZE: {interleave_by}
+ * x0 = i%{block_by} RANGE [0, {block_by} -1] SIZE: {block_by}
+ * TOTAL SIZE: O' * H * W * I'
+ *
+ * 4D 6D
+ * ----------------- -----------------------------------
+ * value(o, h, w, i) = x5 * H * W * I' * {interleave_by}
+ * + x4 * W * I' * {interleave_by}
+ * + x3 * I' * {interleave_by}
+ * + x2 * {interleave_by} * {block_by}
+ * + x1 * {block_by}
+ * + x0
+ *
+ * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created
+ * for the OHWIo{interleave_by}i{block_by} format is in reality seen
+ * as a 2D tensor, where the number of rows is O'/{interleave_by}
+ * and the number of columns is {interleave_by} * H * W * I'.
+ *
+ * The postfix *_bf16 is for the memory layout needed for the
+ * fast-mode kernels, in which the weights are passed in bfloat16
+ * format.
+ */
enum class WeightFormat
{
UNSPECIFIED = 0x1,
@@ -87,6 +138,69 @@ enum class WeightFormat
OHWIo64i8 = 0x804000
};
+// OHWIo<interleave_by>i<block_by>
+inline int interleave_by(const WeightFormat wf)
+{
+ return ((int)wf >> 8) & 0xFFF;
+}
+inline int block_by(const WeightFormat wf)
+{
+ return ((int)wf >> 20) & 0xF;
+}
+inline bool is_fixed_format(const WeightFormat wf)
+{
+ return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY;
+}
+
+inline std::string to_string(WeightFormat wf)
+{
+#define __CASE_WEIGHT_FORMAT(wf) \
+case WeightFormat::wf: \
+ return #wf;
+ switch(wf)
+ {
+ __CASE_WEIGHT_FORMAT(UNSPECIFIED)
+ __CASE_WEIGHT_FORMAT(ANY)
+ __CASE_WEIGHT_FORMAT(OHWI)
+ __CASE_WEIGHT_FORMAT(OHWIo2)
+ __CASE_WEIGHT_FORMAT(OHWIo4)
+ __CASE_WEIGHT_FORMAT(OHWIo8)
+ __CASE_WEIGHT_FORMAT(OHWIo16)
+ __CASE_WEIGHT_FORMAT(OHWIo32)
+ __CASE_WEIGHT_FORMAT(OHWIo64)
+ __CASE_WEIGHT_FORMAT(OHWIo128)
+ __CASE_WEIGHT_FORMAT(OHWIo4i2)
+ __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo8i2)
+ __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo16i2)
+ __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo32i2)
+ __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo64i2)
+ __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo4i4)
+ __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo8i4)
+ __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo16i4)
+ __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo32i4)
+ __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo64i4)
+ __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo2i8)
+ __CASE_WEIGHT_FORMAT(OHWIo4i8)
+ __CASE_WEIGHT_FORMAT(OHWIo8i8)
+ __CASE_WEIGHT_FORMAT(OHWIo16i8)
+ __CASE_WEIGHT_FORMAT(OHWIo32i8)
+ __CASE_WEIGHT_FORMAT(OHWIo64i8)
+ default:
+ return "invalid value";
+ }
+#undef __CASE_WEIGHT_FORMAT
+}
+
struct KernelDescription
{
GemmMethod method = GemmMethod::DEFAULT;
@@ -230,6 +344,6 @@ template <typename Top, typename Tret, class OutputStage = Nothing>
std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
template <typename Top, typename Tret, class OutputStage = Nothing>
-bool has_opt_gemm(const GemmArgs &args, const OutputStage & = {});
+bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
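To make the OHWIo<interleave_by>i<block_by> description above concrete, a small standalone sketch (an editor's illustration, not part of the patch) that decodes the enum encoding and evaluates the documented access function:

    #include <cstddef>

    // Bits [19:8] of the enum value carry interleave_by and bits [23:20] carry block_by,
    // as in the helpers above; OHWIo64i8 = 0x804000 decodes to 64 and 8.
    static_assert(((0x804000 >> 8) & 0xFFF) == 64, "interleave_by(OHWIo64i8) == 64");
    static_assert(((0x804000 >> 20) & 0xF) == 8, "block_by(OHWIo64i8) == 8");

    // Element offset of (o, h, w, i) in a buffer of padded size O' x H x W x I',
    // following the 6-parameter access function in the header comment
    // (ib = interleave_by, bb = block_by, Ip = I').
    inline std::size_t ohwio_offset(int o, int h, int w, int i, int H, int W, int Ip, int ib, int bb)
    {
        const int x5 = o / ib, x1 = o % ib; // split the output-channel index
        const int x2 = i / bb, x0 = i % bb; // split the input-channel index
        return static_cast<std::size_t>(x5) * H * W * Ip * ib
             + static_cast<std::size_t>(h) * W * Ip * ib
             + static_cast<std::size_t>(w) * Ip * ib
             + static_cast<std::size_t>(x2) * ib * bb
             + static_cast<std::size_t>(x1) * bb
             + static_cast<std::size_t>(x0);
    }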
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index 61cd11ece0..f3fff608dc 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -51,6 +51,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.activation_info = info.activation_info();
asm_info.fast_mode = info.fast_math();
asm_info.fixed_format = info.fixed_format();
+ asm_info.weight_format = info.weight_format();
return asm_info;
}
@@ -177,7 +178,8 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
if(d->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != d->dimension(0));
+ // For fixed format we are expecting some kind of blocked format for B/RHS so the dimension won't necessarily match the result matrix any more.
+ ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.fixed_format() && b->dimension(0) != d->dimension(0));
if(gemm_info.depth_output_gemm3d() != 0)
{
if(gemm_info.reinterpret_input_as_3d())
@@ -277,7 +279,7 @@ void CpuGemm::run(ITensorPack &tensors)
auto c = tensors.get_const_tensor(ACL_SRC_2);
auto d = tensors.get_tensor(ACL_DST);
- if(_asm_glue->is_configured())
+ if(_asm_glue && _asm_glue->is_configured())
{
// Pass c to asm dispatch only if it's the bias tensor
ITensorPack asm_pack = tensors;
@@ -343,7 +345,7 @@ void CpuGemm::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
- if(_asm_glue->is_configured())
+ if(_asm_glue && _asm_glue->is_configured())
{
_asm_glue->prepare(tensors);
}
@@ -365,5 +367,18 @@ experimental::MemoryRequirements CpuGemm::workspace() const
{
return _aux_mem;
}
+
+Status CpuGemm::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ const GEMMInfo &gemm_info)
+{
+ const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
+
+ return CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, asm_info);
+}
+
+bool CpuGemm::isVarWeightsKernel() const
+{
+ return _asm_glue && _asm_glue->isVarWeightsKernel();
+}
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h
index 334ab6c647..b37ab73485 100644
--- a/src/cpu/operators/CpuGemm.h
+++ b/src/cpu/operators/CpuGemm.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -101,11 +101,29 @@ public:
static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
+ /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
+ *
+ * This method has the same use as @ref
+ * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
+ * the value of arm_gemm::WeightFormat needs to be passed via the
+ * parameter gemm_info.
+ */
+ static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ const GEMMInfo &gemm_info = GEMMInfo());
+
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &constants) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
experimental::MemoryRequirements workspace() const override;
+ /** Indicates if the convolution executes in variable weights mode.
+ *
+ * When ACL executes convolution in variable weights mode, it does
+ * not perform any processing of the weights tensor. Instead, it
+ * utilizes the data as it is given by the user.
+ */
+ bool isVarWeightsKernel() const;
+
private:
enum AuxTensorIdx
{
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index c021d31059..0174d0eed3 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -99,15 +99,15 @@ CpuGemmConv2d::CpuGemmConv2d()
CpuGemmConv2d::~CpuGemmConv2d() = default;
void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act_info,
- bool enable_fast_math, int gemm_3d_depth)
+ bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_gemm::WeightFormat weight_format)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights);
- ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col, fixed_format, weight_format));
// Create GEMMInfo structure
const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
- false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info);
+ false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format);
// Supported activations in GEMM
const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
@@ -156,7 +156,8 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig
quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info);
_mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>();
- _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info));
+ _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info,
+ experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format));
auto mm_mem_req = _mm_gemmlowp->workspace();
for(unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
@@ -178,7 +179,7 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig
}
Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
- const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col)
+ const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_gemm::WeightFormat weight_format)
{
const DataType data_type = src->data_type();
const bool is_quantized = is_data_type_quantized_asymmetric(data_type);
@@ -187,7 +188,7 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei
// Create GEMMInfo structure
const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
- false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info);
+ false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format);
if(is_quantized)
{
@@ -227,6 +228,7 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei
std::unique_ptr<ITensorInfo> weights_qa = weights->clone();
input_qa->set_quantization_info(QuantizationInfo(iqinfo.uniform().scale, -iqinfo.uniform().offset));
weights_qa->set_quantization_info(QuantizationInfo(wqinfo.uniform().scale, -wqinfo.uniform().offset));
+
return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false, output_info, false, enable_fast_math,
false, act_info));
}
@@ -294,6 +296,7 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
kernel_height,
conv_info,
dilation);
+
ARM_COMPUTE_ERROR_ON_MSG((dst->dimension(idx_width) != conv_w) || (dst->dimension(idx_height) != conv_h),
"Output shape does not match the expected one");
@@ -357,7 +360,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
// Configure GEMM
// In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
- configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth);
+ const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth, fixed_format, weights_info.weight_format());
if(!_skip_col2im && _data_layout == DataLayout::NCHW)
{
@@ -384,6 +388,38 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights
_aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size());
}
+Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+ const PadStrideInfo &conv_info,
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math)
+{
+ const DataLayout data_layout = src->data_layout();
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const unsigned int kernel_width = weights->dimension(idx_width);
+ const unsigned int kernel_height = weights->dimension(idx_height);
+ unsigned int conv_w = 0;
+ unsigned int conv_h = 0;
+ std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
+ src->dimension(idx_height),
+ kernel_width,
+ kernel_height,
+ conv_info,
+ dilation);
+
+ const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
+ dilation, act_info);
+
+ const bool skip_im2col = skip_info.skip_im2col;
+ const bool skip_col2im = skip_info.skip_col2im;
+ const unsigned int gemm_3d_depth = skip_col2im ? conv_h : 0;
+ const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+ const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
+ gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
+ false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weights_info.weight_format());
+
+ return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info);
+}
+
Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info,
const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
{
@@ -450,7 +486,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases);
}
- ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels));
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != dst->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
@@ -472,7 +508,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type);
im2col_reshaped_info.set_quantization_info(src->quantization_info());
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation, 1));
gemm_input_to_use = &im2col_reshaped_info;
}
@@ -490,8 +526,11 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight
info_gemm = TensorInfo(dst->tensor_shape(), 1, output_data_type);
}
info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout());
- gemm_output_to_use = &info_gemm;
- ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col));
+ gemm_output_to_use = &info_gemm;
+ const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED;
+
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format,
+ weights_info.weight_format()));
// Validate Col2Im/ReshapeLayer
if(!skip_col2im && (data_layout == DataLayout::NCHW))
@@ -548,7 +587,10 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
// Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions
ITensorPack pack_mm = tensors;
pack_mm.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use);
- pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+ if(!this->isVarWeightsKernel())
+ {
+ pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
+ }
pack_mm.add_tensor(TensorType::ACL_DST, gemm_output_to_use);
if(_is_quantized)
{
@@ -598,6 +640,15 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
+ // Variable weights executions that use fixed-format kernels
+ // need no reshaping of the weights.
+ if(this->isVarWeightsKernel())
+ {
+ _is_quantized ? _mm_gemmlowp->prepare(tensors) : _mm_gemm->prepare(tensors);
+ _is_prepared = true;
+ return;
+ }
+
// Run weights reshaping and mark original weights tensor as unused
CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -608,12 +659,9 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors)
};
NEScheduler::get().schedule_op(_weights_reshape_kernel.get(), 3, _weights_reshape_kernel->window(), pack);
weights->mark_as_unused();
-
- // Prepare GEMM
ITensorPack gemm_pack = tensors;
gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, weights_reshaped.get());
_is_quantized ? _mm_gemmlowp->prepare(gemm_pack) : _mm_gemm->prepare(gemm_pack);
-
_is_prepared = true;
}
}
@@ -621,5 +669,9 @@ experimental::MemoryRequirements CpuGemmConv2d::workspace() const
{
return _aux_mem;
}
+bool CpuGemmConv2d::isVarWeightsKernel() const
+{
+ return _mm_gemm && _mm_gemm->isVarWeightsKernel();
+}
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index aec4a2ffa5..f8f0bce048 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -117,6 +117,17 @@ public:
const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
bool enable_fast_math = false, unsigned int num_groups = 1);
+ /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
+ *
+ * The parameter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl
+ *
+ * @return a status.
+ */
+ static Status has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+ const PadStrideInfo &conv_info,
+ const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
+ const bool enable_fast_math = false);
+
// Inherited methods overridden:
void run(ITensorPack &tensors) override;
void prepare(ITensorPack &tensors) override;
@@ -135,9 +146,11 @@ private:
* @param[in] enable_fast_math (Optional) Enable fast math computation. In case this flag were set, the function could dispatch the fastest implementation
* available which may introduce a drop of accuracy as well. Default is false
* @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
+ * @param[in] fixed_format (Optional) Select GEMM execution with variable weights.
+ * @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights.
*/
void configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
/** Static function to check if given info will lead to a valid configuration of @ref NEGEMMConvolutionLayer matrix multiply routines
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
@@ -151,11 +164,13 @@ private:
* available which may introduce a drop of accuracy as well. Default is false
* @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
* @param[in] skip_im2col (Optional) Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout. (Default to false)
+ * @param[in] fixed_format (Optional) Select GEMM execution with variable weights.
+ * @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights.
*
* @return a status
*/
static Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
- bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false);
+ bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED);
/** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref CpuGemmMLowpMatrixMultiplyCore
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32.
@@ -187,6 +202,11 @@ private:
static SkipInfo skip_im_col_info(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info,
const Size2D &dilation, const ActivationLayerInfo &act_info);
+ /** Indicates if the convolution executes in variable weights mode.
+ *
+ * Similar to @ref CpuGemm::isVarWeightsKernel
+ */
+ bool isVarWeightsKernel() const;
enum AuxTensorIdx
{
// CpuGemmLowpMatrixMultiplyCore has up to 8 internal tensors
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 787ea95372..5694a3d9ee 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -160,6 +160,13 @@ public:
void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
experimental::MemoryRequirements workspace() const override;
+ bool isVarWeightsKernel() const override
+ {
+ if(!_gemm_kernel_asm)
+ return false;
+ const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
+ return wf != arm_gemm::WeightFormat::UNSPECIFIED && wf != arm_gemm::WeightFormat::ANY;
+ }
private:
enum AuxTensorIdx
@@ -420,6 +427,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
// Pretranspose B if required
if(_gemm_kernel_asm->B_pretranspose_required())
{
+ // Fixed format kernels need no pretranspose.
+ ARM_COMPUTE_ERROR_ON(arm_gemm::is_fixed_format(_gemm_kernel_asm->get_config().weight_format));
const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -483,7 +492,24 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Check if B is pre-transposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
+ if(is_fixed_format(wf))
+ {
+ // The 4D tensor of dimension O'HWI' created for the
+ // OHWIo<interleave_by>i<block_by> format is in reality seen
+ // as a 2D tensor at arm_gemm level, where the rows are
+ // O'/<interleave_by> and the columns are <interleave_by> *
+ // H * W * I'.
+ ITensorInfo *tensor_info = b->info();
+ const DataLayout data_layout = tensor_info->data_layout();
+ const TensorShape tensor_shape = tensor_info->tensor_shape();
+ const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+ const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+ const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
+ const int interleave_by = arm_gemm::interleave_by(wf);
+ ldb = (interleave_by * H * W * Ip);
+ }
multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}
@@ -576,7 +602,9 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode);
+ arm_gemm::GemmConfig cfg;
+ cfg.weight_format = info.weight_format;
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
@@ -594,7 +622,9 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
const CPUInfo &ci = NEScheduler::get().cpu_info();
const unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode);
+ arm_gemm::GemmConfig cfg;
+ cfg.weight_format = info.weight_format;
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
@@ -635,7 +665,8 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
{
}
-Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
+Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+ const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_UNUSED(c);
@@ -643,12 +674,14 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor
Params p = extract_parameters(a, b, d, info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
+ arm_gemm::GemmConfig cfg;
+ cfg.weight_format = info.weight_format;
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
switch(a->data_type())
{
case DataType::F32:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(expected_weight_format, args, {})),
"We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
@@ -656,12 +689,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor
case DataType::QASYMM8:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
"We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
"We could not find an optimized kernel for U8 input and U8 output");
}
break;
@@ -669,12 +702,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor
case DataType::QASYMM8_SIGNED:
if(d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
"We could not find an optimized kernel for S8 input and S32 output");
}
break;
@@ -689,7 +722,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor
#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(args, {})),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
"We could not find an optimized kernel for BFLOAT16 input and F32 output");
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -729,7 +762,17 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
- return CpuGemmAssemblyDispatch::has_opt_impl(a, b, c, d, info);
+ arm_gemm::WeightFormat expected_weight_format;
+ const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
+ if((bool)ret && expected_weight_format != arm_gemm::WeightFormat::ANY)
+ {
+ // Correctness check: if the format expected by the kernel is
+ // not "any", make sure that the one found matches the format
+ // intended by the caller.
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((expected_weight_format != info.weight_format),
+ "The format expected by the kernel does not correspond with the one requested by the user.");
+ }
+ return ret;
}
bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
@@ -801,7 +844,7 @@ void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors)
bool CpuGemmAssemblyDispatch::is_configured() const
{
- return _arm_gemm != nullptr && _arm_gemm->is_configured();
+ return _arm_gemm && _arm_gemm->is_configured();
}
void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 3c25866f25..4ef108d430 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -53,6 +53,7 @@ struct AsmGemmInfo
float padding_value{ 0.f };
bool fast_mode{ false };
bool fixed_format{ false };
+ arm_gemm::WeightFormat weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
};
/** Assembly kernel glue */
@@ -73,6 +74,7 @@ public:
virtual void prepare(ITensorPack &tensors) = 0;
virtual experimental::MemoryRequirements workspace() const = 0;
virtual bool is_configured() const = 0;
+ virtual bool isVarWeightsKernel() const = 0;
virtual ~IFallback() = default;
};
@@ -101,15 +103,14 @@ public:
/** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
*
- * @param[in] a Input tensor info (Matrix A)
- * @param[in] b Input tensor info (Matrix B)
- * @param[in] c Input tensor info (Matrix C) used to pass the bias for quantized calculations
- * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
- * @param[in] info GEMM meta-data
+ * This method has the same use as @ref
+ * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that
+ * the value of arm_gemm::WeightFormat needs to be passed via the
+ * parameter info.
*
* @return a status.
*/
- static Status has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
+ static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
/** Checks if activation is supported by the gemm assembly dispatcher
*
* @param[in] activation Activation to check
@@ -122,6 +123,14 @@ public:
* @return True if the function is configured and ready to run
*/
bool is_configured() const;
+ /** Indicates if the convolution executes in variable weights mode.
+ *
+ * Similar to @ref CpuGemm::isVarWeightsKernel
+ */
+ bool isVarWeightsKernel() const
+ {
+ return _arm_gemm && _arm_gemm->isVarWeightsKernel();
+ }
// Inherited methods overridden:
void prepare(ITensorPack &tensors) override;
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index c780d63763..13635c6e34 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,6 +58,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+
_impl->weights = weights;
_impl->op = std::make_unique<cpu::CpuGemmConv2d>();
_impl->op->configure(input->info(), weights->info(), (biases != nullptr ? biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups);
@@ -79,6 +80,13 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
return cpu::CpuGemmConv2d::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups);
}
+Status NEGEMMConvolutionLayer::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+ const PadStrideInfo &conv_info,
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math)
+{
+ return cpu::CpuGemmConv2d::has_opt_impl(expected_weight_format, src, weights, biases, dst, conv_info, weights_info, dilation, act_info, enable_fast_math);
+}
+
void NEGEMMConvolutionLayer::run()
{
prepare();