diff options
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 84 |
1 files changed, 75 insertions, 9 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 657f3b8e35..496b55ec1c 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -156,8 +156,8 @@ public: const std::vector<int32_t> &multipliers); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; bool is_configured() const override; experimental::MemoryRequirements workspace() const override; @@ -203,12 +203,12 @@ private: /** Indirect buffer */ std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{}; std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{}; - std::vector<TypeInput> _indirect_pad{}; - arm_gemm::ConvolutionParameters _cp{}; - experimental::MemoryRequirements _aux_mem{ Count }; - bool _B_pretranspose_required{ false }; - bool _is_b_constant{ true }; - bool _is_c_constant{ true }; + std::vector<TypeInput> _indirect_pad{}; + arm_gemm::ConvolutionParameters _cp{}; + experimental::MemoryRequirements _aux_mem{ Count }; + bool _B_pretranspose_required{ false }; + bool _is_b_constant{ true }; + bool _is_c_constant{ true }; }; template <typename TypeInput, typename TypeOutput, class OutputStage> @@ -635,6 +635,72 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch() { } +Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); + ARM_COMPUTE_UNUSED(c); + arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info); + Params p = extract_parameters(a, b, d, info); + const CPUInfo &ci = NEScheduler::get().cpu_info(); + unsigned int num_threads = NEScheduler::get().num_threads(); + + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fast_mode); + switch(a->data_type()) + { + case DataType::F32: + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(args, {})), + "We could not find an optimized kernel for F32 input"); + break; +#ifdef __aarch64__ + case DataType::U8: + case DataType::QASYMM8: + if(d->data_type() == DataType::S32) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::gemm<uint8_t, uint32_t, arm_gemm::Nothing>(args, {})), + "We could not find an optimized kernel for U8/QASYMM8 input and S32 output"); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(args, {})), + "We could not find an optimized kernel for U8 input and U8 output"); + } + break; + case DataType::S8: + case DataType::QASYMM8_SIGNED: + if(d->data_type() == DataType::S32) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(args, {})), + "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output"); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(args, {})), + "We could not find an optimized kernel for S8 input and S32 output"); + } + break; +#endif /* __aarch64__ */ +#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) + case DataType::BFLOAT16: + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat, float, arm_gemm::Nothing>(args, {})), + "We could not find an optimized kernel for BFLOAT16 input and F32 output"); + break; + } +#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(args, {})), + "We could not find an optimized kernel for BFLOAT16 input and F32 output"); + break; +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + default: + ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel"); + break; + } + + return Status{}; +} + Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) { ARM_COMPUTE_UNUSED(c, info); @@ -663,7 +729,7 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input"); - return Status{}; + return CpuGemmAssemblyDispatch::has_opt_impl(a, b, c, d, info); } bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation) |