From 448cb45e2cb86f32a739c925a1ac8c688cf573bf Mon Sep 17 00:00:00 2001 From: Suhail Munshi Date: Fri, 23 Apr 2021 16:23:25 +0100 Subject: Adding S32 support to CLPixelWiseMultiplication Partially resolves : COMPMID-3793 Signed-off-by: Suhail Munshi Change-Id: Id82e00c784f0a039017fd896f11671bdda2dd4ab Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5530 Comments-Addressed: Arm Jenkins Reviewed-by: Michalis Spyrou Tested-by: Arm Jenkins --- src/core/gpu/cl/kernels/ClMulKernel.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'src/core/gpu/cl/kernels/ClMulKernel.cpp') diff --git a/src/core/gpu/cl/kernels/ClMulKernel.cpp b/src/core/gpu/cl/kernels/ClMulKernel.cpp index 837324ede2..b8081bbacf 100644 --- a/src/core/gpu/cl/kernels/ClMulKernel.cpp +++ b/src/core/gpu/cl/kernels/ClMulKernel.cpp @@ -53,12 +53,12 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, - DataType::S16, DataType::QSYMM16, DataType::F16, + DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, - DataType::S16, DataType::QSYMM16, DataType::F16, + DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative."); ARM_COMPUTE_RETURN_ERROR_ON(act_info.enabled() && !is_data_type_float(dst->data_type())); @@ -83,8 +83,8 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons "Dst can only be QASYMM8_SIGNED if both src are QASYMM8_SIGNED"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::QSYMM16 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16), "Dst can only be QSYMM16 if both src are QSYMM16"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::S32 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16), - "Dst can only be S32 if both src are QSYMM16"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((src1->data_type() == DataType::S32 || src2->data_type() == DataType::S32) && (dst->data_type() != DataType::S32), + "Dst must be S32 if source tensors are S32"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst"); } @@ -127,7 +127,12 @@ void ClMulKernel::configure(const CLCompileContext &compile_context, ITensorInfo } else { - if(src1->element_size() == 2 || src2->element_size() == 2) + if(src1->element_size() == 4 || src2->element_size() == 4) + { + // use 64 bit accumulator for 32-bit input + acc_type = "long"; + } + else if(src1->element_size() == 2 || src2->element_size() == 2) { // Use 32-bit accumulator for 16-bit input acc_type = "int"; -- cgit v1.2.1