diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index 75fb1cb306..19c8fcadd3 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -306,9 +306,12 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons } template<typename Top, typename Tret, class OutputStage> -bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) { +bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) { const GemmImplementation<Top, Tret, OutputStage> *impl; - return find_implementation<Top, Tret, OutputStage>(args, os, impl); + const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl); + if (success) + wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format; + return success; } template<typename Top, typename Tret, class OutputStage> |