diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2020-03-12 19:34:33 +0000 |
---|---|---|
committer | Michele Di Giorgio <michele.digiorgio@arm.com> | 2020-03-16 09:42:36 +0000 |
commit | a602f03f4c66e5ee2480f1a3fc66847968fc1076 (patch) | |
tree | a2752ca0de84f7920dd7296151d14e5edc8cacc0 /arm_compute/core/KernelDescriptors.h | |
parent | 0ec53a0e54ae0be0ed9c4e4c14a5fd10ed5f48a8 (diff) | |
download | ComputeLibrary-a602f03f4c66e5ee2480f1a3fc66847968fc1076.tar.gz |
COMPMID-3237: Extend GEMMLowpReduction kernels to multiply reductions by a scalar value
Change-Id: If2a242f52aea753591525d30a4cb64c1a766bf8d
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2881
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/KernelDescriptors.h')
-rw-r--r-- | arm_compute/core/KernelDescriptors.h | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h index 58400b190b..d9d3e1a4d8 100644 --- a/arm_compute/core/KernelDescriptors.h +++ b/arm_compute/core/KernelDescriptors.h @@ -124,5 +124,27 @@ struct InstanceNormalizationLayerKernelInfo float epsilon; /**< Lower bound value for the normalization. Defaults to 1e-12 */ bool use_mixed_precision; /**< Use mixed precision in case of FP16 execution. Defaults to true */ }; + +struct GEMMLowpReductionKernelInfo +{ + /** Default constructor */ + GEMMLowpReductionKernelInfo() = default; + /** Constructor + * + * @param[in] k Number of matrix columns/rows. + * @param[in] is_reshaped True if the input tensor has been reshaped. + * @param[in] scalar Scalar value to multiply each reduced column/row by. + * @param[in] mul_by_scalar True if each column/row reduction has to be multiplied by a scalar value. + */ + GEMMLowpReductionKernelInfo(int32_t k, bool is_reshaped, int32_t scalar, bool mul_by_scalar) + : k(k), is_reshaped(is_reshaped), scalar(scalar), mul_by_scalar(mul_by_scalar) + { + } + + int32_t k{ 0 }; /**< Number of matrix columns/rows */ + bool is_reshaped{ false }; /**< True if the input tensor has been reshaped */ + int32_t scalar{ 0 }; /**< Scalar value to multiply each reduced column/row by */ + bool mul_by_scalar{ false }; /**< True if each column/row reduction has to be multiplied by a scalar value */ +}; } // namespace arm_compute #endif /* ARM_COMPUTE_CORE_KERNEL_DESCRIPTORS_H */ |