diff options
Diffstat (limited to 'src/core/gpu/cl/kernels/ClMulKernel.cpp')
-rw-r--r-- | src/core/gpu/cl/kernels/ClMulKernel.cpp | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/src/core/gpu/cl/kernels/ClMulKernel.cpp b/src/core/gpu/cl/kernels/ClMulKernel.cpp index 837324ede2..b8081bbacf 100644 --- a/src/core/gpu/cl/kernels/ClMulKernel.cpp +++ b/src/core/gpu/cl/kernels/ClMulKernel.cpp @@ -53,12 +53,12 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, - DataType::S16, DataType::QSYMM16, DataType::F16, + DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, - DataType::S16, DataType::QSYMM16, DataType::F16, + DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative."); ARM_COMPUTE_RETURN_ERROR_ON(act_info.enabled() && !is_data_type_float(dst->data_type())); @@ -83,8 +83,8 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons "Dst can only be QASYMM8_SIGNED if both src are QASYMM8_SIGNED"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::QSYMM16 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16), "Dst can only be QSYMM16 if both src are QSYMM16"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::S32 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16), - "Dst can only be S32 if both src are QSYMM16"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((src1->data_type() == DataType::S32 || src2->data_type() == DataType::S32) && (dst->data_type() != DataType::S32), + "Dst must be S32 if source tensors are S32"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst"); } @@ -127,7 +127,12 @@ void ClMulKernel::configure(const CLCompileContext &compile_context, ITensorInfo } else { - if(src1->element_size() == 2 || src2->element_size() == 2) + if(src1->element_size() == 4 || src2->element_size() == 4) + { + // use 64 bit accumulator for 32-bit input + acc_type = "long"; + } + else if(src1->element_size() == 2 || src2->element_size() == 2) { // Use 32-bit accumulator for 16-bit input acc_type = "int"; |