aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-01-30 17:17:16 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-02-04 12:45:51 +0000
commit3dbfd23d68b9450fc3e8bf97d92bc211e78e8979 (patch)
tree5ed9fa0b99ddaac47688af280b9f5415fffb0fbd /src/runtime/NEON
parent62c3639b086d768661edc04b9b7e01a54edf486b (diff)
downloadComputeLibrary-3dbfd23d68b9450fc3e8bf97d92bc211e78e8979.tar.gz
COMPMID-1710: Introduce GEMM strategy name in GEMMAssemblyWrapper.
Change-Id: I0fd1a313c051849572367e46e7aa64b1adee5763 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/604 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'src/runtime/NEON')
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp23
1 files changed, 20 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index cd614ba582..470e9220ae 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -77,7 +77,17 @@ template <typename TypeInput, typename TypeOutput>
class Fallback : public NEGEMMAssemblyDispatch::IFallback
{
public:
- void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group);
+ /** Initialise the functions's input and output.
+ *
+ * @param[in] a Input tensor containing the Matrix A.
+ * @param[in] b Input tensor containing the Matrix B.
+ * @param[out] d Output tensor to store the result of matrix multiplication.
+ * @param[in] args Matrix multiplication information.
+ * @param[in] memory_group Memory group to be used by the function.
+ */
+ void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group);
+
+ // Inherited methods overridden:
void run() override;
void prepare() override;
bool is_configured() const override;
@@ -116,8 +126,15 @@ private:
};
template <typename TypeInput, typename TypeOutput>
-void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
+void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group)
{
+ arm_gemm::GemmConfig gemm_cfg;
+ const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
+ if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
+ {
+ gemm_cfg.filter = gemm_kernel_info.name;
+ args._cfg = &gemm_cfg;
+ }
_gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args);
if(_gemm_kernel_asm == nullptr)
{
@@ -128,7 +145,7 @@ void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor
// arm_compute wrapper for the Gemm object (see above)
std::unique_ptr<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>> acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapperKernel<TypeInput, TypeOutput>>();
ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
- acl_gemm_wrapper->configure(_gemm_kernel_asm.get());
+ acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
const size_t workspace_size = _gemm_kernel_asm->get_working_size();
if(workspace_size > 0)
{