aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
diff options
context:
space:
mode:
authorMohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>2022-10-21 11:15:54 +0100
committerSuhail Munshi <MohammedSuhail.Munshi@arm.com>2022-11-01 10:58:52 +0000
commit4b5f6efef15efd79727a58c520c92c9e7a084256 (patch)
treeb8acdadcd397b992f2e7148a2f4da088cb7ae7d5 /src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
parentf44bbc5c697de841dce97c0f2fa39bae391a8174 (diff)
downloadComputeLibrary-4b5f6efef15efd79727a58c520c92c9e7a084256.tar.gz
Add check for Batch Matmul in GemmAssemblyDispatch
Relates to : COMPMID-5507 Change-Id: Ia2c4ea153ac2524ffa6b2a9c10f3a0318a8a67a1 Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8509 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp17
1 files changed, 9 insertions, 8 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index ab668681ad..8ff81afe54 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -157,8 +157,8 @@ public:
const std::vector<int32_t> &multipliers);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
experimental::MemoryRequirements workspace() const override;
bool isVarWeightsKernel() const override
@@ -211,12 +211,12 @@ private:
/** Indirect buffer */
std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{ Count };
- bool _B_pretranspose_required{ false };
- bool _is_b_constant{ true };
- bool _is_c_constant{ true };
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{ Count };
+ bool _B_pretranspose_required{ false };
+ bool _is_b_constant{ true };
+ bool _is_c_constant{ true };
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -767,6 +767,7 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run), "Assembly kernel will not be executed when reshape_b_only_on_first_run is false");
#ifndef __aarch64__
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64");