diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index cd614ba582..470e9220ae 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -77,7 +77,17 @@ template <typename TypeInput, typename TypeOutput> class Fallback : public NEGEMMAssemblyDispatch::IFallback { public: - void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group); + /** Initialise the functions's input and output. + * + * @param[in] a Input tensor containing the Matrix A. + * @param[in] b Input tensor containing the Matrix B. + * @param[out] d Output tensor to store the result of matrix multiplication. + * @param[in] args Matrix multiplication information. + * @param[in] memory_group Memory group to be used by the function. + */ + void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group); + + // Inherited methods overridden: void run() override; void prepare() override; bool is_configured() const override; @@ -116,8 +126,15 @@ private: }; template <typename TypeInput, typename TypeOutput> -void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group) +void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group) { + arm_gemm::GemmConfig gemm_cfg; + const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args); + if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED) + { + gemm_cfg.filter = gemm_kernel_info.name; + args._cfg = &gemm_cfg; + } _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args); if(_gemm_kernel_asm == nullptr) { @@ -128,7 +145,7 @@ void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor // arm_compute wrapper for the Gemm object (see above) std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>(); ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr); - acl_gemm_wrapper->configure(_gemm_kernel_asm.get()); + acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter); const size_t workspace_size = _gemm_kernel_asm->get_working_size(); if(workspace_size > 0) { |