aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/NEON/functions/NEGEMM.h
diff options
context:
space:
mode:
authorMilos Puzovic <milos.puzovic@arm.com>2022-07-27 17:53:21 +0000
committerGunes Bayir <gunes.bayir@arm.com>2022-08-03 16:47:58 +0000
commit13b623e575ed2f1096c70560a2db4a9e03cf22f9 (patch)
treea517d94c55cfc803c7e13dc89090bf2b3be4dc41 /arm_compute/runtime/NEON/functions/NEGEMM.h
parent3c4d085da54c3d9727cb31718c5b407c18ff646a (diff)
downloadComputeLibrary-13b623e575ed2f1096c70560a2db4a9e03cf22f9.tar.gz
[ONCPUML-968] Fixed format kernel support in additional APIs
Implements required plumbing in order to be able to ask and execute fixed format kernels from NEFullyConnected, NEGEMM and NEGEMMConv2d. These APIs are used to accelerate oneDNN primitives (inner product, matrix multiplication and indirect GEMM respectively) and without changes it would not be possible to call fixed format kernels from those oneDNN primitives. Change-Id: I27534f0491ce28d0ccb98c19f318bd33dcdf2ff5 Signed-off-by: Milos Puzovic <milos.puzovic@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7999 Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEGEMM.h')
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMM.h11
1 files changed, 10 insertions, 1 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h
index ce68a61923..7ce2521148 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMM.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMM.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -84,6 +84,15 @@ public:
*/
static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
+ /** Static function that queries whether there exists fixed-format kernel and if it exists it will return in the first argument in what format
+ * weights are expected to be reshaped as defined by WeightFormat class. Apart from the first argument the rest of the arguments are the same
+ * as in @ref NEGEMM::validate() except that all arguments are required.
+ *
+ * @return a status
+ */
+ static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output,
+ float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
+
// Inherited methods overridden:
void run() override;
void prepare() override;