aboutsummaryrefslogtreecommitdiff
path: root/arm_compute
diff options
context:
space:
mode:
authorRadu Salavat <radu.salavat@arm.com>2024-02-27 18:32:26 +0000
committerRadu Salavat <radu.salavat@arm.com>2024-04-11 08:47:50 +0000
commitf1f1f87132690a8061801ef1a4638d637c780df7 (patch)
tree8ad4c3739217b3bc6281f4e0b9a7a63fe6c3f9bb /arm_compute
parent1322065a3fbd15b00dbfb0969d6b438b5ba15530 (diff)
downloadComputeLibrary-f1f1f87132690a8061801ef1a4638d637c780df7.tar.gz
Add in place summation to CPU GEMM kernels
Instead of dispatching the sum postop for GEMM kernels to a separate kernel + add, that requires an extra destination sized allocation, plus 3 extra load/stores per element, just do it in the GEMM kernel. Resolves: ONCPUML-1442 Signed-off-by: Radu Salavat <radu.salavat@arm.com> Co-authored-by: Milos Puzovic <milos.puzovic@arm.com> Change-Id: I7a1f2da3300875fa1ac88b705a34390969518077 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11298 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute')
-rw-r--r--arm_compute/function_info/GEMMInfo.h31
1 files changed, 25 insertions, 6 deletions
diff --git a/arm_compute/function_info/GEMMInfo.h b/arm_compute/function_info/GEMMInfo.h
index a827c79fda..74fe30454e 100644
--- a/arm_compute/function_info/GEMMInfo.h
+++ b/arm_compute/function_info/GEMMInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -85,7 +85,8 @@ public:
_pretranspose_B(false),
_activation_info(),
_fixed_format(false),
- _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
+ _weight_format(arm_compute::WeightFormat::UNSPECIFIED),
+ _accumulate(false)
{
}
/** Constructor
@@ -106,6 +107,7 @@ public:
* @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in amemory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
* @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
* @param[in] pretranspose_B (Optional) Pretranspose matrix B (transposition of its lowest 2 dimensions), in addition to and before, any further transformations of B
+ * @param[in] accumulate (Optional) Whether to accumulate in destination or not
*/
GEMMInfo(bool is_a_reshaped,
bool is_b_reshaped,
@@ -120,7 +122,8 @@ public:
const ActivationLayerInfo &activation_info = ActivationLayerInfo(),
bool fixed_format = false,
arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED,
- bool pretranspose_B = false) noexcept
+ bool pretranspose_B = false,
+ bool accumulate = false) noexcept
: _is_a_reshaped(is_a_reshaped),
_is_b_reshaped(is_b_reshaped),
_reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -135,7 +138,8 @@ public:
_pretranspose_B(pretranspose_B),
_activation_info(activation_info),
_fixed_format(fixed_format),
- _weight_format(weight_format)
+ _weight_format(weight_format),
+ _accumulate(accumulate)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -294,7 +298,14 @@ public:
{
return _fixed_format;
}
-
+ /** Flag which specifies if GEMM should accumulate the result in destination or not.
+ *
+ * @return True if GEMM is accumulating the result.
+ */
+ bool accumulate() const
+ {
+ return _accumulate;
+ }
/** Set fixed-format flag
*
* @param[in] fixed_format sets whether or not to use fixed-format kernels
@@ -303,12 +314,19 @@ public:
{
_fixed_format = fixed_format;
}
+ /** Set accumulate flag
+ *
+ * @param[in] accumulate sets whether or not to use accumulation
+ */
+ void set_accumulate(bool accumulate)
+ {
+ _accumulate = accumulate;
+ }
arm_compute::WeightFormat weight_format() const
{
return _weight_format;
}
-
/** Set weight format to be used
*
* @param[in] weight_format arm_compute::WeightFormat enumeration
@@ -334,6 +352,7 @@ private:
ActivationLayerInfo _activation_info;
bool _fixed_format;
arm_compute::WeightFormat _weight_format;
+ bool _accumulate;
};
} //namespace arm_compute
#endif // ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H