aboutsummaryrefslogtreecommitdiff
path: root/src/core/gpu/cl/kernels/ClMulKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/gpu/cl/kernels/ClMulKernel.cpp')
-rw-r--r--src/core/gpu/cl/kernels/ClMulKernel.cpp15
1 files changed, 10 insertions, 5 deletions
diff --git a/src/core/gpu/cl/kernels/ClMulKernel.cpp b/src/core/gpu/cl/kernels/ClMulKernel.cpp
index 837324ede2..b8081bbacf 100644
--- a/src/core/gpu/cl/kernels/ClMulKernel.cpp
+++ b/src/core/gpu/cl/kernels/ClMulKernel.cpp
@@ -53,12 +53,12 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1,
1,
DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
- DataType::S16, DataType::QSYMM16, DataType::F16,
+ DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32,
DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2,
1,
DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
- DataType::S16, DataType::QSYMM16, DataType::F16,
+ DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32,
DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative.");
ARM_COMPUTE_RETURN_ERROR_ON(act_info.enabled() && !is_data_type_float(dst->data_type()));
@@ -83,8 +83,8 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons
"Dst can only be QASYMM8_SIGNED if both src are QASYMM8_SIGNED");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::QSYMM16 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16),
"Dst can only be QSYMM16 if both src are QSYMM16");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::S32 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16),
- "Dst can only be S32 if both src are QSYMM16");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((src1->data_type() == DataType::S32 || src2->data_type() == DataType::S32) && (dst->data_type() != DataType::S32),
+ "Dst must be S32 if source tensors are S32");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst");
}
@@ -127,7 +127,12 @@ void ClMulKernel::configure(const CLCompileContext &compile_context, ITensorInfo
}
else
{
- if(src1->element_size() == 2 || src2->element_size() == 2)
+ if(src1->element_size() == 4 || src2->element_size() == 4)
+ {
+ // Use 64-bit accumulator for 32-bit input
+ acc_type = "long";
+ }
+ else if(src1->element_size() == 2 || src2->element_size() == 2)
{
// Use 32-bit accumulator for 16-bit input
acc_type = "int";