From 4b5f6efef15efd79727a58c520c92c9e7a084256 Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Fri, 21 Oct 2022 11:15:54 +0100 Subject: Add check for Batch Matmul in GemmAssemblyDispatch Relates to : COMPMID-5507 Change-Id: Ia2c4ea153ac2524ffa6b2a9c10f3a0318a8a67a1 Signed-off-by: Mohammed Suhail Munshi Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8509 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: SiCong Li Comments-Addressed: Arm Jenkins --- src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp | 15 ++++++++------- src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 17 +++++++++-------- src/cpu/operators/internal/CpuGemmAssemblyDispatch.h | 1 + 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp index 8faa3c217a..aec9da193b 100644 --- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp +++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -59,12 +59,13 @@ namespace cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) { cpu::AsmGemmInfo asm_info; - asm_info.method = cpu::AsmConvMethod::Im2Col; - asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d(); - asm_info.depth_output_gemm3d = info.depth_output_gemm3d(); - asm_info.activation_info = info.activation_info(); - asm_info.output_stage = info.gemmlowp_output_stage(); - asm_info.fast_mode = info.fast_math(); + asm_info.method = cpu::AsmConvMethod::Im2Col; + asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d(); + asm_info.depth_output_gemm3d = info.depth_output_gemm3d(); + asm_info.activation_info = info.activation_info(); + asm_info.output_stage = info.gemmlowp_output_stage(); + asm_info.fast_mode = info.fast_math(); + asm_info.reshape_b_only_on_first_run = info.reshape_b_only_on_first_run(); return asm_info; } diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index ab668681ad..8ff81afe54 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -157,8 +157,8 @@ public: const std::vector &multipliers); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; bool is_configured() const override; experimental::MemoryRequirements workspace() const override; bool isVarWeightsKernel() const override @@ -211,12 +211,12 @@ private: /** Indirect buffer */ std::unique_ptr _indirect_arg{}; std::unique_ptr _indirect_buf{}; - std::vector _indirect_pad{}; - arm_gemm::ConvolutionParameters _cp{}; - experimental::MemoryRequirements _aux_mem{ Count }; - bool _B_pretranspose_required{ false }; - bool _is_b_constant{ true }; - bool _is_c_constant{ true }; + std::vector _indirect_pad{}; + arm_gemm::ConvolutionParameters _cp{}; + experimental::MemoryRequirements _aux_mem{ Count }; + bool _B_pretranspose_required{ false }; + bool _is_b_constant{ true }; + bool _is_c_constant{ true }; }; template @@ -767,6 +767,7 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run), "Assembly kernel will not be executed when reshape_b_only_on_first_run is false"); #ifndef __aarch64__ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64"); diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 691eeff8d2..0c51c92359 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -54,6 +54,7 @@ struct AsmGemmInfo bool fast_mode{ false }; bool fixed_format{ false }; arm_compute::WeightFormat weight_format{ arm_compute::WeightFormat::UNSPECIFIED }; + bool reshape_b_only_on_first_run{ true }; }; /** Assembly kernel glue */ -- cgit v1.2.1