aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/KernelDescriptors.h
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2020-03-12 19:34:33 +0000
committerMichele Di Giorgio <michele.digiorgio@arm.com>2020-03-16 09:42:36 +0000
commita602f03f4c66e5ee2480f1a3fc66847968fc1076 (patch)
treea2752ca0de84f7920dd7296151d14e5edc8cacc0 /arm_compute/core/KernelDescriptors.h
parent0ec53a0e54ae0be0ed9c4e4c14a5fd10ed5f48a8 (diff)
downloadComputeLibrary-a602f03f4c66e5ee2480f1a3fc66847968fc1076.tar.gz
COMPMID-3237: Extend GEMMLowpReduction kernels to multiply reductions by a scalar value
Change-Id: If2a242f52aea753591525d30a4cb64c1a766bf8d Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2881 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/KernelDescriptors.h')
-rw-r--r--arm_compute/core/KernelDescriptors.h22
1 files changed, 22 insertions, 0 deletions
diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h
index 58400b190b..d9d3e1a4d8 100644
--- a/arm_compute/core/KernelDescriptors.h
+++ b/arm_compute/core/KernelDescriptors.h
@@ -124,5 +124,27 @@ struct InstanceNormalizationLayerKernelInfo
float epsilon; /**< Lower bound value for the normalization. Defaults to 1e-12 */
bool use_mixed_precision; /**< Use mixed precision in case of FP16 execution. Defaults to true */
};
+
+struct GEMMLowpReductionKernelInfo
+{
+ /** Default constructor */
+ GEMMLowpReductionKernelInfo() = default;
+ /** Constructor
+ *
+ * @param[in] k Number of matrix columns/rows.
+ * @param[in] is_reshaped True if the input tensor has been reshaped.
+ * @param[in] scalar Scalar value to multiply each reduced column/row by.
+ * @param[in] mul_by_scalar True if each column/row reduction has to be multiplied by a scalar value.
+ */
+ GEMMLowpReductionKernelInfo(int32_t k, bool is_reshaped, int32_t scalar, bool mul_by_scalar)
+ : k(k), is_reshaped(is_reshaped), scalar(scalar), mul_by_scalar(mul_by_scalar)
+ {
+ }
+
+ int32_t k{ 0 }; /**< Number of matrix columns/rows */
+ bool is_reshaped{ false }; /**< True if the input tensor has been reshaped */
+ int32_t scalar{ 0 }; /**< Scalar value to multiply each reduced column/row by */
+ bool mul_by_scalar{ false }; /**< True if each column/row reduction has to be multiplied by a scalar value */
+};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CORE_KERNEL_DESCRIPTORS_H */