about | summary | refs | log | tree | commit | diff
path: root/src/core/gpu
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/gpu')
-rw-r--r-- src/core/gpu/cl/kernels/ClMulKernel.cpp 15
-rw-r--r-- src/core/gpu/cl/kernels/ClMulKernel.h 7
2 files changed, 14 insertions, 8 deletions
diff --git a/src/core/gpu/cl/kernels/ClMulKernel.cpp b/src/core/gpu/cl/kernels/ClMulKernel.cpp
index 837324ede2..b8081bbacf 100644
--- a/src/core/gpu/cl/kernels/ClMulKernel.cpp
+++ b/src/core/gpu/cl/kernels/ClMulKernel.cpp
@@ -53,12 +53,12 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1,
1,
DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
- DataType::S16, DataType::QSYMM16, DataType::F16,
+ DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32,
DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2,
1,
DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
- DataType::S16, DataType::QSYMM16, DataType::F16,
+ DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32,
DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative.");
ARM_COMPUTE_RETURN_ERROR_ON(act_info.enabled() && !is_data_type_float(dst->data_type()));
@@ -83,8 +83,8 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons
"Dst can only be QASYMM8_SIGNED if both src are QASYMM8_SIGNED");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::QSYMM16 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16),
"Dst can only be QSYMM16 if both src are QSYMM16");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::S32 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16),
- "Dst can only be S32 if both src are QSYMM16");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((src1->data_type() == DataType::S32 || src2->data_type() == DataType::S32) && (dst->data_type() != DataType::S32),
+ "Dst must be S32 if source tensors are S32");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst");
}
@@ -127,7 +127,12 @@ void ClMulKernel::configure(const CLCompileContext &compile_context, ITensorInfo
}
else
{
- if(src1->element_size() == 2 || src2->element_size() == 2)
+ if(src1->element_size() == 4 || src2->element_size() == 4)
+ {
+ // Use 64-bit accumulator for 32-bit input
+ acc_type = "long";
+ }
+ else if(src1->element_size() == 2 || src2->element_size() == 2)
{
// Use 32-bit accumulator for 16-bit input
acc_type = "int";
diff --git a/src/core/gpu/cl/kernels/ClMulKernel.h b/src/core/gpu/cl/kernels/ClMulKernel.h
index e2e54a836e..44162f3db3 100644
--- a/src/core/gpu/cl/kernels/ClMulKernel.h
+++ b/src/core/gpu/cl/kernels/ClMulKernel.h
@@ -50,6 +50,7 @@ public:
* - (U8,S16) -> S16
* - (S16,U8) -> S16
* - (S16,S16) -> S16
+ * - (S32,S32) -> S32
* - (F16,F16) -> F16
* - (F32,F32) -> F32
* - (QASYMM8,QASYMM8) -> QASYMM8
@@ -58,9 +59,9 @@ public:
* - (QSYMM16,QSYMM16) -> S32
*
* @param[in] compile_context The compile context to be used.
- * @param[in] src1 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
- * @param[in] src2 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
- * @param[out] dst The dst tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+ * @param[in] src1 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32.
+ * @param[in] src2 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32.
+ * @param[out] dst The dst tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32.
* @param[in] scale Scale to apply after multiplication.
* Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
* @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate