aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSuhail Munshi <MohammedSuhail.Munshi@arm.com>2021-04-23 16:23:25 +0100
committerSheri Zhang <sheri.zhang@arm.com>2021-05-05 14:21:05 +0000
commit448cb45e2cb86f32a739c925a1ac8c688cf573bf (patch)
tree121c4d98be4eb27a8652eefe040f835d17d8955e
parent7e9f34d219ae5dd3ddd5d26475f42aa02bcf010f (diff)
downloadComputeLibrary-448cb45e2cb86f32a739c925a1ac8c688cf573bf.tar.gz
Adding S32 support to CLPixelWiseMultiplication
Partially resolves : COMPMID-3793 Signed-off-by: Suhail Munshi <MohammedSuhail.Munshi@arm.com> Change-Id: Id82e00c784f0a039017fd896f11671bdda2dd4ab Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5530 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h15
-rw-r--r--src/core/CL/cl_kernels/pixelwise_mul_int.cl1
-rw-r--r--src/core/gpu/cl/kernels/ClMulKernel.cpp15
-rw-r--r--src/core/gpu/cl/kernels/ClMulKernel.h7
-rw-r--r--tests/validation/CL/PixelWiseMultiplication.cpp17
-rw-r--r--tests/validation/fixtures/PixelWiseMultiplicationFixture.h12
6 files changed, 51 insertions, 16 deletions
diff --git a/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h b/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h
index 029b924512..d352c6e282 100644
--- a/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h
+++ b/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h
@@ -68,13 +68,14 @@ public:
* |S16 |U8 |S16 |
* |S16 |S16 |S16 |
* |F16 |F16 |F16 |
- * |F32 |S32 |F32 |
+ * |F32 |F32 |F32 |
+ * |S32 |S32 |S32 |
*
- * @param[in, out] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+ * @param[in, out] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32
* The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
- * @param[in, out] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+ * @param[in, out] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32
* The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
- * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+ * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32
* @param[in] scale Scale to apply after multiplication.
* Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
* @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate
@@ -86,11 +87,11 @@ public:
/** Initialise the kernel's inputs, output and convertion policy.
*
* @param[in] compile_context The compile context to be used.
- * @param[in, out] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+ * @param[in, out] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32
* The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
- * @param[in, out] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+ * @param[in, out] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32
* The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
- * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+ * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32
* @param[in] scale Scale to apply after multiplication.
* Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
* @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate
diff --git a/src/core/CL/cl_kernels/pixelwise_mul_int.cl b/src/core/CL/cl_kernels/pixelwise_mul_int.cl
index 32c46def77..ac5cabcb8c 100644
--- a/src/core/CL/cl_kernels/pixelwise_mul_int.cl
+++ b/src/core/CL/cl_kernels/pixelwise_mul_int.cl
@@ -90,7 +90,6 @@ __kernel void pixelwise_mul_int(
// Load data
VEC_ACC_TYPE in1_data = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN1, VEC_SIZE_OUT))VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE_IN1 *)in1_addr), VEC_ACC_TYPE);
VEC_ACC_TYPE in2_data = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN2, VEC_SIZE_OUT))VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE_IN2 *)in2_addr), VEC_ACC_TYPE);
-
// Perform multiplication and store result
VEC_OUT_TYPE out_data0 = MUL_OP(in1_data, in2_data, scale, DATA_TYPE_OUT, VEC_SIZE_OUT);
STORE_VECTOR_SELECT(out_data, DATA_TYPE_OUT, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0);
diff --git a/src/core/gpu/cl/kernels/ClMulKernel.cpp b/src/core/gpu/cl/kernels/ClMulKernel.cpp
index 837324ede2..b8081bbacf 100644
--- a/src/core/gpu/cl/kernels/ClMulKernel.cpp
+++ b/src/core/gpu/cl/kernels/ClMulKernel.cpp
@@ -53,12 +53,12 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1,
1,
DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
- DataType::S16, DataType::QSYMM16, DataType::F16,
+ DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32,
DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2,
1,
DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
- DataType::S16, DataType::QSYMM16, DataType::F16,
+ DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32,
DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative.");
ARM_COMPUTE_RETURN_ERROR_ON(act_info.enabled() && !is_data_type_float(dst->data_type()));
@@ -83,8 +83,8 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons
"Dst can only be QASYMM8_SIGNED if both src are QASYMM8_SIGNED");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::QSYMM16 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16),
"Dst can only be QSYMM16 if both src are QSYMM16");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::S32 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16),
- "Dst can only be S32 if both src are QSYMM16");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((src1->data_type() == DataType::S32 || src2->data_type() == DataType::S32) && (dst->data_type() != DataType::S32),
+ "Dst must be S32 if source tensors are S32");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst");
}
@@ -127,7 +127,12 @@ void ClMulKernel::configure(const CLCompileContext &compile_context, ITensorInfo
}
else
{
- if(src1->element_size() == 2 || src2->element_size() == 2)
+ if(src1->element_size() == 4 || src2->element_size() == 4)
+ {
+ // use 64 bit accumulator for 32-bit input
+ acc_type = "long";
+ }
+ else if(src1->element_size() == 2 || src2->element_size() == 2)
{
// Use 32-bit accumulator for 16-bit input
acc_type = "int";
diff --git a/src/core/gpu/cl/kernels/ClMulKernel.h b/src/core/gpu/cl/kernels/ClMulKernel.h
index e2e54a836e..44162f3db3 100644
--- a/src/core/gpu/cl/kernels/ClMulKernel.h
+++ b/src/core/gpu/cl/kernels/ClMulKernel.h
@@ -50,6 +50,7 @@ public:
* - (U8,S16) -> S16
* - (S16,U8) -> S16
* - (S16,S16) -> S16
+ * - (S32,S32) -> S32
* - (F16,F16) -> F16
* - (F32,F32) -> F32
* - (QASYMM8,QASYMM8) -> QASYMM8
@@ -58,9 +59,9 @@ public:
* - (QSYMM16,QSYMM16) -> S32
*
* @param[in] compile_context The compile context to be used.
- * @param[in] src1 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
- * @param[in] src2 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
- * @param[out] dst The dst tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32.
+ * @param[in] src1 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32
+ * @param[in] src2 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32
+ * @param[out] dst The dst tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32
* @param[in] scale Scale to apply after multiplication.
* Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
* @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate
diff --git a/tests/validation/CL/PixelWiseMultiplication.cpp b/tests/validation/CL/PixelWiseMultiplication.cpp
index 8a88bc919c..9e0a6243d7 100644
--- a/tests/validation/CL/PixelWiseMultiplication.cpp
+++ b/tests/validation/CL/PixelWiseMultiplication.cpp
@@ -82,6 +82,8 @@ template <typename T>
using CLPixelWiseMultiplicationToF32Fixture = PixelWiseMultiplicationValidationFloatFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, float>;
template <typename T>
using CLPixelWiseMultiplicationToF32BroadcastFixture = PixelWiseMultiplicationBroadcastValidationFloatFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, float>;
+template <typename T>
+using CLPixelWiseMultiplicationIntegerFixture = PixelWiseMultiplicationValidationIntegerFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, int>;
TEST_SUITE(CL)
TEST_SUITE(PixelWiseMultiplication)
@@ -116,6 +118,21 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
}
// clang-format on
// *INDENT-ON*
+TEST_SUITE(INT32)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationIntegerFixture<int>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(combine(
+ datasets::SmallShapes(),
+ framework::dataset::make("DataType1", DataType::S32)),
+ framework::dataset::make("DataType2", DataType::S32)),
+ framework::dataset::make("Scale", {1.f})),
+ datasets::ConvertPolicies()),
+ framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_UP)),
+ EmptyActivationFunctionsDataset),
+ InPlaceDataSet))
+{
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END()
TEST_SUITE(F16toF16)
TEST_SUITE(Scale255)
diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
index c307421876..7c643bd726 100644
--- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
+++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
@@ -177,6 +177,18 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
+class PixelWiseMultiplicationValidationIntegerFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, ActivationLayerInfo act_info, bool is_inplace)
+ {
+ PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy,
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
class PixelWiseMultiplicationBroadcastValidationFloatFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
{
public: