diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-01-30 17:17:16 +0000 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-02-04 12:45:51 +0000 |
commit | 3dbfd23d68b9450fc3e8bf97d92bc211e78e8979 (patch) | |
tree | 5ed9fa0b99ddaac47688af280b9f5415fffb0fbd /src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | |
parent | 62c3639b086d768661edc04b9b7e01a54edf486b (diff) | |
download | ComputeLibrary-3dbfd23d68b9450fc3e8bf97d92bc211e78e8979.tar.gz |
COMPMID-1710: Introduce GEMM strategy name in GEMMAssemblyWrapper.
Change-Id: I0fd1a313c051849572367e46e7aa64b1adee5763
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/604
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index cd614ba582..470e9220ae 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -77,7 +77,17 @@ template <typename TypeInput, typename TypeOutput>
 class Fallback : public NEGEMMAssemblyDispatch::IFallback
 {
 public:
-    void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group);
+    /** Initialise the functions's input and output.
+     *
+     * @param[in] a Input tensor containing the Matrix A.
+     * @param[in] b Input tensor containing the Matrix B.
+     * @param[out] d Output tensor to store the result of matrix multiplication.
+     * @param[in] args Matrix multiplication information.
+     * @param[in] memory_group Memory group to be used by the function.
+     */
+    void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group);
+
+    // Inherited methods overridden:
     void run() override;
     void prepare() override;
     bool is_configured() const override;
@@ -116,8 +126,15 @@ private:
 };
 
 template <typename TypeInput, typename TypeOutput>
-void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
+void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group)
 {
+    arm_gemm::GemmConfig gemm_cfg;
+    const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
+    if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
+    {
+        gemm_cfg.filter = gemm_kernel_info.name;
+        args._cfg = &gemm_cfg;
+    }
     _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args);
     if(_gemm_kernel_asm == nullptr)
     {
@@ -128,7 +145,7 @@ void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor
     // arm_compute wrapper for the Gemm object (see above)
     std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>();
     ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
-    acl_gemm_wrapper->configure(_gemm_kernel_asm.get());
+    acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
     const size_t workspace_size = _gemm_kernel_asm->get_working_size();
     if(workspace_size > 0)
     {