aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
diff options
context:
space:
mode:
authorFrancesco.Petrogalli@arm.com <francesco.petrogalli@arm.com>2022-03-31 17:55:35 +0000
committerFrancesco Petrogalli <francesco.petrogalli@arm.com>2022-04-06 13:41:41 +0000
commite33c55640b5e6e6af193a41a3376e2a01a321a35 (patch)
tree3a8282410d5f8e7f57276cbca9e2dfa6d14115d7 /src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
parent4c17ba951b76e97102f101a88edbf012b722c732 (diff)
downloadComputeLibrary-e33c55640b5e6e6af193a41a3376e2a01a321a35.tar.gz
[arm_gemm] Use static validate to find arm_gemm kernels.
The static method `CpuGemmAssemblyDispatch::validate` should look into the list of the available kernels to make sure the one requested by the user was found. Formatting changes in the files touched by the patch have been automatically inserted by the formatting script. Resolves: ONCPUML-840 Change-Id: Icd650a30e142284a942c64f8a2b72441ee7b3f4e Signed-off-by: Francesco.Petrogalli@arm.com <francesco.petrogalli@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7375 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp84
1 files changed, 75 insertions, 9 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 657f3b8e35..496b55ec1c 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -156,8 +156,8 @@ public:
const std::vector<int32_t> &multipliers);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
experimental::MemoryRequirements workspace() const override;
@@ -203,12 +203,12 @@ private:
/** Indirect buffer */
std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{ Count };
- bool _B_pretranspose_required{ false };
- bool _is_b_constant{ true };
- bool _is_c_constant{ true };
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{ Count };
+ bool _B_pretranspose_required{ false };
+ bool _is_b_constant{ true };
+ bool _is_c_constant{ true };
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -635,6 +635,72 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
{
}
+Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
+ ARM_COMPUTE_UNUSED(c);
+ arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info);
+ Params p = extract_parameters(a, b, d, info);
+ const CPUInfo &ci = NEScheduler::get().cpu_info();
+ unsigned int num_threads = NEScheduler::get().num_threads();
+
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fast_mode);
+ switch(a->data_type())
+ {
+ case DataType::F32:
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(args, {})),
+ "We could not find an optimized kernel for F32 input");
+ break;
+#ifdef __aarch64__
+ case DataType::U8:
+ case DataType::QASYMM8:
+ if(d->data_type() == DataType::S32)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::gemm<uint8_t, uint32_t, arm_gemm::Nothing>(args, {})),
+ "We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(args, {})),
+ "We could not find an optimized kernel for U8 input and U8 output");
+ }
+ break;
+ case DataType::S8:
+ case DataType::QASYMM8_SIGNED:
+ if(d->data_type() == DataType::S32)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(args, {})),
+ "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(args, {})),
+ "We could not find an optimized kernel for S8 input and S32 output");
+ }
+ break;
+#endif /* __aarch64__ */
+#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
+ case DataType::BFLOAT16:
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat, float, arm_gemm::Nothing>(args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ break;
+ }
+#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+ default:
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel");
+ break;
+ }
+
+ return Status{};
+}
+
Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_UNUSED(c, info);
@@ -663,7 +729,7 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
- return Status{};
+ return CpuGemmAssemblyDispatch::has_opt_impl(a, b, c, d, info);
}
bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)