diff options
-rw-r--r-- | arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h | 15 | ||||
-rw-r--r-- | src/core/CL/cl_kernels/pixelwise_mul_int.cl | 1 | ||||
-rw-r--r-- | src/core/gpu/cl/kernels/ClMulKernel.cpp | 15 | ||||
-rw-r--r-- | src/core/gpu/cl/kernels/ClMulKernel.h | 7 | ||||
-rw-r--r-- | tests/validation/CL/PixelWiseMultiplication.cpp | 17 | ||||
-rw-r--r-- | tests/validation/fixtures/PixelWiseMultiplicationFixture.h | 12 |
6 files changed, 51 insertions, 16 deletions
diff --git a/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h b/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h index 029b924512..d352c6e282 100644 --- a/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h +++ b/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h @@ -68,13 +68,14 @@ public: * |S16 |U8 |S16 | * |S16 |S16 |S16 | * |F16 |F16 |F16 | - * |F32 |S32 |F32 | + * |F32 |F32 |F32 | + * |S32 |S32 |S32 | * - * @param[in, out] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32. + * @param[in, out] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32 * The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0. - * @param[in, out] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32. + * @param[in, out] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32 * The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0. - * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32. + * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32 * @param[in] scale Scale to apply after multiplication. * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. * @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate @@ -86,11 +87,11 @@ public: /** Initialise the kernel's inputs, output and convertion policy. * * @param[in] compile_context The compile context to be used. - * @param[in, out] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32. + * @param[in, out] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32 * The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0. - * @param[in, out] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32. + * @param[in, out] input2 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32 * The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0. - * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32. + * @param[out] output The output tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32 * @param[in] scale Scale to apply after multiplication. * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. * @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate diff --git a/src/core/CL/cl_kernels/pixelwise_mul_int.cl b/src/core/CL/cl_kernels/pixelwise_mul_int.cl index 32c46def77..ac5cabcb8c 100644 --- a/src/core/CL/cl_kernels/pixelwise_mul_int.cl +++ b/src/core/CL/cl_kernels/pixelwise_mul_int.cl @@ -90,7 +90,6 @@ __kernel void pixelwise_mul_int( // Load data VEC_ACC_TYPE in1_data = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN1, VEC_SIZE_OUT))VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE_IN1 *)in1_addr), VEC_ACC_TYPE); VEC_ACC_TYPE in2_data = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN2, VEC_SIZE_OUT))VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE_IN2 *)in2_addr), VEC_ACC_TYPE); - // Perform multiplication and store result VEC_OUT_TYPE out_data0 = MUL_OP(in1_data, in2_data, scale, DATA_TYPE_OUT, VEC_SIZE_OUT); STORE_VECTOR_SELECT(out_data, DATA_TYPE_OUT, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0); diff --git a/src/core/gpu/cl/kernels/ClMulKernel.cpp b/src/core/gpu/cl/kernels/ClMulKernel.cpp index 837324ede2..b8081bbacf 100644 --- a/src/core/gpu/cl/kernels/ClMulKernel.cpp +++ b/src/core/gpu/cl/kernels/ClMulKernel.cpp @@ -53,12 +53,12 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, - DataType::S16, DataType::QSYMM16, DataType::F16, + DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, - DataType::S16, DataType::QSYMM16, DataType::F16, + DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative."); ARM_COMPUTE_RETURN_ERROR_ON(act_info.enabled() && !is_data_type_float(dst->data_type())); @@ -83,8 +83,8 @@ Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, cons "Dst can only be QASYMM8_SIGNED if both src are QASYMM8_SIGNED"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::QSYMM16 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16), "Dst can only be QSYMM16 if both src are QSYMM16"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::S32 && (src1->data_type() != DataType::QSYMM16 || src2->data_type() != DataType::QSYMM16), - "Dst can only be S32 if both src are QSYMM16"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((src1->data_type() == DataType::S32 || src2->data_type() == DataType::S32) && (dst->data_type() != DataType::S32), + "Dst must be S32 if source tensors are S32"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst"); } @@ -127,7 +127,12 @@ void ClMulKernel::configure(const CLCompileContext &compile_context, ITensorInfo } else { - if(src1->element_size() == 2 || src2->element_size() == 2) + if(src1->element_size() == 4 || src2->element_size() == 4) + { + // use 64 bit accumulator for 32-bit input + acc_type = "long"; + } + else if(src1->element_size() == 2 || src2->element_size() == 2) { // Use 32-bit accumulator for 16-bit input acc_type = "int"; diff --git a/src/core/gpu/cl/kernels/ClMulKernel.h b/src/core/gpu/cl/kernels/ClMulKernel.h index e2e54a836e..44162f3db3 100644 --- a/src/core/gpu/cl/kernels/ClMulKernel.h +++ b/src/core/gpu/cl/kernels/ClMulKernel.h @@ -50,6 +50,7 @@ public: * - (U8,S16) -> S16 * - (S16,U8) -> S16 * - (S16,S16) -> S16 + * - (S32,S32) -> S32 * - (F16,F16) -> F16 * - (F32,F32) -> F32 * - (QASYMM8,QASYMM8) -> QASYMM8 @@ -58,9 +59,9 @@ public: * - (QSYMM16,QSYMM16) -> S32 * * @param[in] compile_context The compile context to be used. - * @param[in] src1 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32. - * @param[in] src2 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32. - * @param[out] dst The dst tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32. + * @param[in] src1 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32 + * @param[in] src2 An src tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32 + * @param[out] dst The dst tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32/S32 * @param[in] scale Scale to apply after multiplication. * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. * @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate diff --git a/tests/validation/CL/PixelWiseMultiplication.cpp b/tests/validation/CL/PixelWiseMultiplication.cpp index 8a88bc919c..9e0a6243d7 100644 --- a/tests/validation/CL/PixelWiseMultiplication.cpp +++ b/tests/validation/CL/PixelWiseMultiplication.cpp @@ -82,6 +82,8 @@ template <typename T> using CLPixelWiseMultiplicationToF32Fixture = PixelWiseMultiplicationValidationFloatFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, float>; template <typename T> using CLPixelWiseMultiplicationToF32BroadcastFixture = PixelWiseMultiplicationBroadcastValidationFloatFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, float>; +template <typename T> +using CLPixelWiseMultiplicationIntegerFixture = PixelWiseMultiplicationValidationIntegerFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, int>; TEST_SUITE(CL) TEST_SUITE(PixelWiseMultiplication) @@ -116,6 +118,21 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( } // clang-format on // *INDENT-ON* +TEST_SUITE(INT32) +FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationIntegerFixture<int>, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine(combine( + datasets::SmallShapes(), + framework::dataset::make("DataType1", DataType::S32)), + framework::dataset::make("DataType2", DataType::S32)), + framework::dataset::make("Scale", {1.f})), + datasets::ConvertPolicies()), + framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_UP)), + EmptyActivationFunctionsDataset), + InPlaceDataSet)) +{ + validate(CLAccessor(_target), _reference); +} +TEST_SUITE_END() TEST_SUITE(F16toF16) TEST_SUITE(Scale255) diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h index c307421876..7c643bd726 100644 --- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h +++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h @@ -177,6 +177,18 @@ public: }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> +class PixelWiseMultiplicationValidationIntegerFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2> +{ +public: + template <typename...> + void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, ActivationLayerInfo act_info, bool is_inplace) + { + PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy, + QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> class PixelWiseMultiplicationBroadcastValidationFloatFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2> { public: |