aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp1
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp1
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp1
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int16.cpp1
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int8.cpp1
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp1
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp1
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp1
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp1
-rw-r--r--src/cpu/kernels/assembly/arm_gemm.hpp5
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp84
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.h28
13 files changed, 114 insertions, 18 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
index f4af587898..dd72fb5901 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
@@ -144,6 +144,7 @@ const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, fl
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index a502262de3..42f4528066 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -108,6 +108,7 @@ const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp1
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 8b855ab07c..69a2803903 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -232,6 +232,7 @@ const GemmImplementation<float, float> *gemm_implementation_list<float, float>()
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 4d7f7983f8..cb3ff7aa29 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -236,6 +236,12 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons
}
template<typename Top, typename Tret, class OutputStage>
+bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) {
+ const GemmImplementation<Top, Tret, OutputStage> *impl;
+ return find_implementation<Top, Tret, OutputStage>(args, os, impl);
+}
+
+template<typename Top, typename Tret, class OutputStage>
UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) {
const GemmImplementation<Top, Tret, OutputStage> *impl;
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index d650116bf7..3915861112 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -56,6 +56,7 @@ const GemmImplementation<int16_t, int32_t> *gemm_implementation_list<int16_t, in
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<int16_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index a1134559df..0c68e4dd99 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -159,6 +159,7 @@ const GemmImplementation<int8_t, int32_t> *gemm_implementation_list<int8_t, int3
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index 1532816b74..6b813c7974 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -230,6 +230,7 @@ const GemmImplementation<int8_t, int8_t, Requantize32> *gemm_implementation_list
}
template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index a80766bad6..95139c2bf6 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -197,6 +197,7 @@ const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_li
}
template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index d459df8126..20cee556f0 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -56,6 +56,7 @@ const GemmImplementation<uint16_t, uint32_t> *gemm_implementation_list<uint16_t,
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index f2d46d5415..a2d2cc86f0 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -157,6 +157,7 @@ const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, u
/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &);
} // namespace arm_gemm
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index e38cc09202..200e04f9a8 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -187,4 +187,7 @@ UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {})
template <typename Top, typename Tret, class OutputStage = Nothing>
std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
+template <typename Top, typename Tret, class OutputStage = Nothing>
+bool has_opt_gemm(const GemmArgs &args, const OutputStage & = {});
+
} // namespace arm_gemm
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 657f3b8e35..496b55ec1c 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -156,8 +156,8 @@ public:
const std::vector<int32_t> &multipliers);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
experimental::MemoryRequirements workspace() const override;
@@ -203,12 +203,12 @@ private:
/** Indirect buffer */
std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{ Count };
- bool _B_pretranspose_required{ false };
- bool _is_b_constant{ true };
- bool _is_c_constant{ true };
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{ Count };
+ bool _B_pretranspose_required{ false };
+ bool _is_b_constant{ true };
+ bool _is_c_constant{ true };
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -635,6 +635,72 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
{
}
+Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
+ ARM_COMPUTE_UNUSED(c);
+ arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info);
+ Params p = extract_parameters(a, b, d, info);
+ const CPUInfo &ci = NEScheduler::get().cpu_info();
+ unsigned int num_threads = NEScheduler::get().num_threads();
+
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fast_mode);
+ switch(a->data_type())
+ {
+ case DataType::F32:
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(args, {})),
+ "We could not find an optimized kernel for F32 input");
+ break;
+#ifdef __aarch64__
+ case DataType::U8:
+ case DataType::QASYMM8:
+ if(d->data_type() == DataType::S32)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::gemm<uint8_t, uint32_t, arm_gemm::Nothing>(args, {})),
+ "We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(args, {})),
+ "We could not find an optimized kernel for U8 input and U8 output");
+ }
+ break;
+ case DataType::S8:
+ case DataType::QASYMM8_SIGNED:
+ if(d->data_type() == DataType::S32)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(args, {})),
+ "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(args, {})),
+ "We could not find an optimized kernel for S8 input and S32 output");
+ }
+ break;
+#endif /* __aarch64__ */
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+ case DataType::BFLOAT16:
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat, float, arm_gemm::Nothing>(args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ break;
+ }
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ default:
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel");
+ break;
+ }
+
+ return Status{};
+}
+
Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_UNUSED(c, info);
@@ -663,7 +729,7 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
- return Status{};
+ return CpuGemmAssemblyDispatch::has_opt_impl(a, b, c, d, info);
}
bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index a50f3634c2..74359eee72 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -68,11 +68,11 @@ public:
class IFallback
{
public:
- virtual void run(ITensorPack &tensors) = 0;
- virtual void prepare(ITensorPack &tensors) = 0;
- virtual experimental::MemoryRequirements workspace() const = 0;
- virtual bool is_configured() const = 0;
- virtual ~IFallback() = default;
+ virtual void run(ITensorPack &tensors) = 0;
+ virtual void prepare(ITensorPack &tensors) = 0;
+ virtual experimental::MemoryRequirements workspace() const = 0;
+ virtual bool is_configured() const = 0;
+ virtual ~IFallback() = default;
};
public:
@@ -97,6 +97,18 @@ public:
* @return a status.
*/
static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
+
+ /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
+ *
+ * @param[in] a Input tensor info (Matrix A)
+ * @param[in] b Input tensor info (Matrix B)
+ * @param[in] c Input tensor info (Matrix C) used to pass the bias for quantized calculations
+ * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+ * @param[in] info GEMM meta-data
+ *
+ * @return a status.
+ */
+ static Status has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info);
/** Checks if activation is supported by the gemm assembly dispatcher
*
* @param[in] activation Activation to check
@@ -111,8 +123,8 @@ public:
bool is_configured() const;
// Inherited methods overridden:
- void prepare(ITensorPack &tensors) override;
- void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
experimental::MemoryRequirements workspace() const override;
private: