about summary refs log tree commit diff
path: root/src/runtime/NEON
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON')
-rw-r--r-- src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 23
1 files changed, 20 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index cd614ba582..470e9220ae 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -77,7 +77,17 @@ template <typename TypeInput, typename TypeOutput>
class Fallback : public NEGEMMAssemblyDispatch::IFallback
{
public:
- void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group);
+ /** Initialise the function's input and output.
+ *
+ * @param[in] a Input tensor containing the Matrix A.
+ * @param[in] b Input tensor containing the Matrix B.
+ * @param[out] d Output tensor to store the result of matrix multiplication.
+ * @param[in] args Matrix multiplication information.
+ * @param[in] memory_group Memory group to be used by the function.
+ */
+ void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group);
+
+ // Inherited methods overridden:
void run() override;
void prepare() override;
bool is_configured() const override;
@@ -116,8 +126,15 @@ private:
};
template <typename TypeInput, typename TypeOutput>
-void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
+void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group)
{
+ arm_gemm::GemmConfig gemm_cfg;
+ const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
+ if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
+ {
+ gemm_cfg.filter = gemm_kernel_info.name;
+ args._cfg = &gemm_cfg;
+ }
_gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args);
if(_gemm_kernel_asm == nullptr)
{
@@ -128,7 +145,7 @@ void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor
// arm_compute wrapper for the Gemm object (see above)
std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>();
ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
- acl_gemm_wrapper->configure(_gemm_kernel_asm.get());
+ acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
const size_t workspace_size = _gemm_kernel_asm->get_working_size();
if(workspace_size > 0)
{