diff options
Diffstat (limited to 'src/cpu/operators/CpuGemm.cpp')
-rw-r--r-- | src/cpu/operators/CpuGemm.cpp | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp index 61cd11ece0..f3fff608dc 100644 --- a/src/cpu/operators/CpuGemm.cpp +++ b/src/cpu/operators/CpuGemm.cpp @@ -51,6 +51,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) asm_info.activation_info = info.activation_info(); asm_info.fast_mode = info.fast_math(); asm_info.fixed_format = info.fixed_format(); + asm_info.weight_format = info.weight_format(); return asm_info; } @@ -177,7 +178,8 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens if(d->total_size() != 0) { - ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != d->dimension(0)); + // For fixed format we are expecting some kind of blocked format for B/RHS so the dimension won't necessarily match the result matrix any more. + ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.fixed_format() && b->dimension(0) != d->dimension(0)); if(gemm_info.depth_output_gemm3d() != 0) { if(gemm_info.reinterpret_input_as_3d()) @@ -277,7 +279,7 @@ void CpuGemm::run(ITensorPack &tensors) auto c = tensors.get_const_tensor(ACL_SRC_2); auto d = tensors.get_tensor(ACL_DST); - if(_asm_glue->is_configured()) + if(_asm_glue && _asm_glue->is_configured()) { // Pass c to asm dispatch only if it's the bias tensor ITensorPack asm_pack = tensors; @@ -343,7 +345,7 @@ void CpuGemm::prepare(ITensorPack &tensors) { if(!_is_prepared) { - if(_asm_glue->is_configured()) + if(_asm_glue && _asm_glue->is_configured()) { _asm_glue->prepare(tensors); } @@ -365,5 +367,18 @@ experimental::MemoryRequirements CpuGemm::workspace() const { return _aux_mem; } + +Status CpuGemm::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, + const GEMMInfo &gemm_info) +{ + const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); + + return CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, asm_info); +} + +bool CpuGemm::isVarWeightsKernel() const +{ + return _asm_glue && _asm_glue->isVarWeightsKernel(); +} } // namespace cpu } // namespace arm_compute |