diff options
author | Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> | 2022-10-21 11:15:54 +0100 |
---|---|---|
committer | Suhail Munshi <MohammedSuhail.Munshi@arm.com> | 2022-11-01 10:58:52 +0000 |
commit | 4b5f6efef15efd79727a58c520c92c9e7a084256 (patch) | |
tree | b8acdadcd397b992f2e7148a2f4da088cb7ae7d5 /src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | |
parent | f44bbc5c697de841dce97c0f2fa39bae391a8174 (diff) | |
download | ComputeLibrary-4b5f6efef15efd79727a58c520c92c9e7a084256.tar.gz |
Add check for Batch Matmul in GemmAssemblyDispatch
Relates to : COMPMID-5507
Change-Id: Ia2c4ea153ac2524ffa6b2a9c10f3a0318a8a67a1
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8509
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 17 |
1 files changed, 9 insertions, 8 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index ab668681ad..8ff81afe54 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -157,8 +157,8 @@ public: const std::vector<int32_t> &multipliers); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; bool is_configured() const override; experimental::MemoryRequirements workspace() const override; bool isVarWeightsKernel() const override @@ -211,12 +211,12 @@ private: /** Indirect buffer */ std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{}; std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{}; - std::vector<TypeInput> _indirect_pad{}; - arm_gemm::ConvolutionParameters _cp{}; - experimental::MemoryRequirements _aux_mem{ Count }; - bool _B_pretranspose_required{ false }; - bool _is_b_constant{ true }; - bool _is_c_constant{ true }; + std::vector<TypeInput> _indirect_pad{}; + arm_gemm::ConvolutionParameters _cp{}; + experimental::MemoryRequirements _aux_mem{ Count }; + bool _B_pretranspose_required{ false }; + bool _is_b_constant{ true }; + bool _is_c_constant{ true }; }; template <typename TypeInput, typename TypeOutput, class OutputStage> @@ -767,6 +767,7 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run), "Assembly kernel will not be executed when reshape_b_only_on_first_run is false"); #ifndef __aarch64__ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64"); |