aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>2022-10-21 11:15:54 +0100
committerSuhail Munshi <MohammedSuhail.Munshi@arm.com>2022-11-01 10:58:52 +0000
commit4b5f6efef15efd79727a58c520c92c9e7a084256 (patch)
treeb8acdadcd397b992f2e7148a2f4da088cb7ae7d5
parentf44bbc5c697de841dce97c0f2fa39bae391a8174 (diff)
downloadComputeLibrary-4b5f6efef15efd79727a58c520c92c9e7a084256.tar.gz
Add check for Batch Matmul in GemmAssemblyDispatch
Relates to : COMPMID-5507 Change-Id: Ia2c4ea153ac2524ffa6b2a9c10f3a0318a8a67a1 Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8509 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp15
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp17
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.h1
3 files changed, 18 insertions, 15 deletions
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index 8faa3c217a..aec9da193b 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -59,12 +59,13 @@ namespace
cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
{
cpu::AsmGemmInfo asm_info;
- asm_info.method = cpu::AsmConvMethod::Im2Col;
- asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d();
- asm_info.depth_output_gemm3d = info.depth_output_gemm3d();
- asm_info.activation_info = info.activation_info();
- asm_info.output_stage = info.gemmlowp_output_stage();
- asm_info.fast_mode = info.fast_math();
+ asm_info.method = cpu::AsmConvMethod::Im2Col;
+ asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d();
+ asm_info.depth_output_gemm3d = info.depth_output_gemm3d();
+ asm_info.activation_info = info.activation_info();
+ asm_info.output_stage = info.gemmlowp_output_stage();
+ asm_info.fast_mode = info.fast_math();
+ asm_info.reshape_b_only_on_first_run = info.reshape_b_only_on_first_run();
return asm_info;
}
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index ab668681ad..8ff81afe54 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -157,8 +157,8 @@ public:
const std::vector<int32_t> &multipliers);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
experimental::MemoryRequirements workspace() const override;
bool isVarWeightsKernel() const override
@@ -211,12 +211,12 @@ private:
/** Indirect buffer */
std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{ Count };
- bool _B_pretranspose_required{ false };
- bool _is_b_constant{ true };
- bool _is_c_constant{ true };
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{ Count };
+ bool _B_pretranspose_required{ false };
+ bool _is_b_constant{ true };
+ bool _is_c_constant{ true };
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -767,6 +767,7 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run), "Assembly kernel will not be executed when reshape_b_only_on_first_run is false");
#ifndef __aarch64__
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64");
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 691eeff8d2..0c51c92359 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -54,6 +54,7 @@ struct AsmGemmInfo
bool fast_mode{ false };
bool fixed_format{ false };
arm_compute::WeightFormat weight_format{ arm_compute::WeightFormat::UNSPECIFIED };
+ bool reshape_b_only_on_first_run{ true };
};
/** Assembly kernel glue */