diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2020-03-30 14:10:20 +0100 |
---|---|---|
committer | Michele Di Giorgio <michele.digiorgio@arm.com> | 2020-04-01 12:02:02 +0000 |
commit | 9428a182911802cf6e6df6eb751a7c7eb43602f9 (patch) | |
tree | 78247c5657c92fc692a68b4df0d7d34b66bea408 | |
parent | afc630fee1c019bfbc191c37d9d7fdf805b0b1d7 (diff) | |
download | ComputeLibrary-9428a182911802cf6e6df6eb751a7c7eb43602f9.tar.gz |
COMPMID-3237: Add support for QSYMM16 into S32 NEPixelwiseMultiplicationKernel
Change-Id: I8dc3348db37b041f442639ac0d072740ca639878
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2960
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
10 files changed, 352 insertions, 226 deletions
diff --git a/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h b/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h index 9b71ac81cf..1a9dd6be2e 100644 --- a/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h +++ b/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2019 ARM Limited. + * Copyright (c) 2016-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -58,7 +58,15 @@ public: * * @param[in] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32 * @param[in] input2 An input tensor. Data types supported: U8, QASYMM8 (only if @p input1 is QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if @p input1 is QSYMM16), F16 (only if @p input1 is F16), F32 (only if @p input1 is F32). - * @param[out] output Output tensor. Data types supported: U8 (Only if both inputs are U8), QASYMM8 (only if both inputs are QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if both inputs are QSYMM16), F16 (only if @p input1 is F16), F32 (only if both inputs are F32). + * @param[out] output Output tensor. Data types supported: + * - U8, only if both inputs are U8. + * - QASYMM8, only if both inputs are QASYMM8. + * - QASYMM8_SIGNED, only if @p input1 is QASYMM8_SIGNED. + * - S16. + * - QSYMM16, only if both inputs are QSYMM16. + * - S32, only if both inputs are QSYMM16. + * - F16, only if @p input1 is F16. + * - F32, only if both inputs are F32. * @param[in] scale Scale to apply after multiplication. * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. * @param[in] overflow_policy Overflow policy. ConvertPolicy cannot be WRAP if datatype is QASYMM8, QASYMM8_SIGNED or QSYMM16. @@ -72,7 +80,15 @@ public: * * @param[in] input1 An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32 * @param[in] input2 An input tensor info. Data types supported: U8, QASYMM8 (only if @p input1 is QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if @p input1 is QSYMM16), F16 (only if @p input1 is F16), F32 (only if @p input1 is F32). - * @param[in] output Output tensor info. Data types supported: U8 (Only if both inputs are U8), QASYMM8 (only if both inputs are QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED) , S16, QSYMM16 (only if both inputs are QSYMM16), F16 (only if @p input1 is F16), F32 (only if both inputs are F32). + * @param[in] output Output tensor info. Data types supported: + * - U8, only if both inputs are U8. + * - QASYMM8, only if both inputs are QASYMM8. + * - QASYMM8_SIGNED, only if @p input1 is QASYMM8_SIGNED. + * - S16. + * - QSYMM16, only if both inputs are QSYMM16. + * - S32, only if both inputs are QSYMM16. + * - F16, only if @p input1 is F16. + * - F32, only if both inputs are F32. * @param[in] scale Scale to apply after multiplication. * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. * @param[in] overflow_policy Overflow policy. ConvertPolicy cannot be WRAP if datatype is QASYMM8, QASYMM8_SIGNED or QSYMM16. diff --git a/arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h b/arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h index 25f409871b..ede4327bfb 100644 --- a/arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h +++ b/arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2019 ARM Limited. + * Copyright (c) 2016-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -44,7 +44,15 @@ public: * This input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0. * @param[in, out] input2 An input tensor. Data types supported: U8, QASYMM8 (only if @p input1 is QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if @p input1 is QSYMM16), F16 (only if @p input1 is F16), F32 (only if @p input1 is F32). * This input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0. - * @param[out] output Output tensor. Data types supported: U8 (Only if both inputs are U8), QASYMM8 (only if both inputs are QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if both inputs are QSYMM16), F16 (only if @p input1 is F16), F32 (only if both inputs are F32). + * @param[out] output Output tensor. Data types supported: + * - U8, only if both inputs are U8. + * - QASYMM8, only if both inputs are QASYMM8. + * - QASYMM8_SIGNED, only if @p input1 is QASYMM8_SIGNED. + * - S16. + * - QSYMM16, only if both inputs are QSYMM16. + * - S32, only if both inputs are QSYMM16. + * - F16, only if @p input1 is F16. + * - F32, only if both inputs are F32. * @param[in] scale Scale to apply after multiplication. * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. * @param[in] overflow_policy Overflow policy. ConvertPolicy cannot be WRAP if datatype is QASYMM8, QASYMM8_SIGNED or QSYMM16. @@ -58,7 +66,15 @@ public: * * @param[in] input1 An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32 * @param[in] input2 An input tensor info. Data types supported: U8, QASYMM8 (only if @p input1 is QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if both inputs are QSYMM16), F16 (only if @p input1 is F16), F32 (only if @p input1 is F32). - * @param[in] output Output tensor info. Data types supported: U8 (Only if both inputs are U8), QASYMM8 (only if both inputs are QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if both inputs are QSYMM16), F16 (only if @p input1 is F16), F32 (only if both inputs are F32). + * @param[in] output Output tensor info. Data types supported: + * - U8, only if both inputs are U8. + * - QASYMM8, only if both inputs are QASYMM8. + * - QASYMM8_SIGNED, only if @p input1 is QASYMM8_SIGNED. + * - S16. + * - QSYMM16, only if both inputs are QSYMM16. + * - S32, only if both inputs are QSYMM16. + * - F16, only if @p input1 is F16. + * - F32, only if both inputs are F32. * @param[in] scale Scale to apply after multiplication. * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. * @param[in] overflow_policy Overflow policy. ConvertPolicy cannot be WRAP if datatype is QASYMM8, QASYMM8_SIGNED or QSYMM16. diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index a87588dbb3..ca59e66293 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2019 ARM Limited. + * Copyright (c) 2016-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,23 +24,12 @@ #include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h" #include "arm_compute/core/CPP/Validate.h" -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/IAccessWindow.h" -#include "arm_compute/core/ITensor.h" #include "arm_compute/core/NEON/NEAsymm.h" -#include "arm_compute/core/NEON/NEFixedPoint.h" #include "arm_compute/core/NEON/NESymm.h" #include "arm_compute/core/NEON/wrapper/wrapper.h" #include "arm_compute/core/TensorInfo.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/Validate.h" #include <arm_neon.h> -#include <climits> -#include <cmath> -#include <cstdint> -#include <cstdlib> #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #include <arm_fp16.h> // needed for float16_t @@ -48,8 +37,6 @@ namespace arm_compute { -class Coordinates; - namespace { const float scale255_constant = 1.f / 255.f; @@ -66,10 +53,9 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, + DataType::S16, DataType::QSYMM16, + DataType::S32, DataType::F16, DataType::F32); if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type())) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2); @@ -86,6 +72,12 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16), + "Output can only be S32 if both inputs are QSYMM16"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output"); } if(std::abs(scale - scale255_constant) < 0.00001f) @@ -266,8 +258,20 @@ void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, c const auto input2 = static_cast<const qsymm16_t *__restrict>(input2_ptr); const auto output = static_cast<qsymm16_t *__restrict>(output_ptr); - const qsymm16x8x2_t input1_q = vld2q_s16(input1); - const qsymm16x8x2_t input2_q = vld2q_s16(input2); + const qsymm16x8x2_t input1_q = + { + { + vld1q_s16(input1), + vld1q_s16(input1 + 8), + } + }; + const qsymm16x8x2_t input2_q = + { + { + vld1q_s16(input2), + vld1q_s16(input2 + 8), + } + }; // Dequantitize inputs const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info); @@ -284,7 +288,65 @@ void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, c }; const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info); - vst2q_s16(output, result); + vst1q_s16(output, result.val[0]); + vst1q_s16(output + 8, result.val[1]); +} + +void mul_QSYMM16_QSYMM16_S32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int scale) +{ + ARM_COMPUTE_UNUSED(scale); + const auto input1 = static_cast<const qsymm16_t *__restrict>(input1_ptr); + const auto input2 = static_cast<const qsymm16_t *__restrict>(input2_ptr); + const auto output = static_cast<int32_t *__restrict>(output_ptr); + + const qsymm16x8x2_t input1_q = + { + { + vld1q_s16(input1), + vld1q_s16(input1 + 8), + } + }; + const qsymm16x8x2_t input2_q = + { + { + vld1q_s16(input2), + vld1q_s16(input2 + 8), + } + }; + + const int32x4x4_t in1_s32 = + { + { + vmovl_s16(vget_low_s16(input1_q.val[0])), + vmovl_s16(vget_high_s16(input1_q.val[0])), + vmovl_s16(vget_low_s16(input1_q.val[1])), + vmovl_s16(vget_high_s16(input1_q.val[1])), + } + }; + const int32x4x4_t in2_s32 = + { + { + vmovl_s16(vget_low_s16(input2_q.val[0])), + vmovl_s16(vget_high_s16(input2_q.val[0])), + vmovl_s16(vget_low_s16(input2_q.val[1])), + vmovl_s16(vget_high_s16(input2_q.val[1])), + } + }; + + const int32x4x4_t result = + { + { + vmulq_s32(in1_s32.val[0], in2_s32.val[0]), + vmulq_s32(in1_s32.val[1], in2_s32.val[1]), + vmulq_s32(in1_s32.val[2], in2_s32.val[2]), + vmulq_s32(in1_s32.val[3], in2_s32.val[3]), + } + }; + + vst1q_s32(output, result.val[0]); + vst1q_s32(output + 4, result.val[1]); + vst1q_s32(output + 8, result.val[2]); + vst1q_s32(output + 12, result.val[3]); } template <bool is_scale255, bool is_sat> @@ -412,11 +474,24 @@ void mul_S16_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict const auto input2 = static_cast<const int16_t *__restrict>(input2_ptr); const auto output = static_cast<int16_t *__restrict>(output_ptr); - const int16x8x2_t ta1 = vld2q_s16(input1); - const int16x8x2_t ta2 = vld2q_s16(input2); + const int16x8x2_t ta1 = + { + { + vld1q_s16(input1), + vld1q_s16(input1 + 8), + } + }; + const int16x8x2_t ta2 = + { + { + vld1q_s16(input2), + vld1q_s16(input2 + 8), + } + }; const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n); - vst2q_s16(output, result); + vst1q_s16(output, result.val[0]); + vst1q_s16(output + 8, result.val[1]); } void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale) @@ -472,11 +547,23 @@ void c_mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restri void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale) { #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - const auto input1 = static_cast<const float16_t *__restrict>(input1_ptr); - const auto input2 = static_cast<const float16_t *__restrict>(input2_ptr); - const auto output = static_cast<float16_t *__restrict>(output_ptr); - const float16x8x2_t ta1 = vld2q_f16(input1); - const float16x8x2_t ta2 = vld2q_f16(input2); + const auto input1 = static_cast<const float16_t *__restrict>(input1_ptr); + const auto input2 = static_cast<const float16_t *__restrict>(input2_ptr); + const auto output = static_cast<float16_t *__restrict>(output_ptr); + const float16x8x2_t ta1 = + { + { + vld1q_f16(input1), + vld1q_f16(input1 + 8), + } + }; + const float16x8x2_t ta2 = + { + { + vld1q_f16(input2), + vld1q_f16(input2 + 8), + } + }; const float16x8_t scale_vec = vdupq_n_f16(scale); const float16x8x2_t result = { @@ -485,7 +572,8 @@ void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec), } }; - vst2q_f16(output, result); + vst1q_f16(output, result.val[0]); + vst1q_f16(output + 8, result.val[1]); #else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ ARM_COMPUTE_UNUSED(input1_ptr); ARM_COMPUTE_UNUSED(input2_ptr); @@ -550,8 +638,20 @@ void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict const auto input2 = static_cast<const uint8_t *__restrict>(input2_ptr); const auto output = static_cast<int16_t *__restrict>(output_ptr); - const int16x8x2_t ta1 = vld2q_s16(input1); - const uint8x8x2_t ta2u = vld2_u8(input2); + const int16x8x2_t ta1 = + { + { + vld1q_s16(input1), + vld1q_s16(input1 + 8), + } + }; + const uint8x8x2_t ta2u = + { + { + vld1_u8(input2), + vld1_u8(input2 + 8), + } + }; const int16x8x2_t ta2 = { { @@ -562,7 +662,8 @@ void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n); - vst2q_s16(output, result); + vst1q_s16(output, result.val[0]); + vst1q_s16(output + 8, result.val[1]); } template <bool is_scale255, bool is_sat> @@ -629,10 +730,14 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe { _func_quantized = &mul_saturate_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED_n; } - else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16) + else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16) { _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16_n; } + else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32) + { + _func_int = &mul_QSYMM16_QSYMM16_S32_n; + } else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output) { if(is_scale_255) @@ -750,7 +855,7 @@ void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo Iterator input2(_input2, slice_input2); Iterator output(_output, slice); - if(is_data_type_quantized(_input1->info()->data_type())) + if((_run_optimized_qasymm8) || (_func_quantized != nullptr)) { if(_run_optimized_qasymm8) { diff --git a/tests/validation/CL/PixelWiseMultiplication.cpp b/tests/validation/CL/PixelWiseMultiplication.cpp index 3b55e25f37..ff9101a997 100644 --- a/tests/validation/CL/PixelWiseMultiplication.cpp +++ b/tests/validation/CL/PixelWiseMultiplication.cpp @@ -127,14 +127,17 @@ using CLPixelWiseMultiplicationQuantizedFixture = PixelWiseMultiplicationValidat TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", DataType::QASYMM8)), - framework::dataset::make("Scale", { 1.f, 2.f })), - framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), - framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)), - framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })), - framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), - framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) }))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QASYMM8)), + framework::dataset::make("DataTypeIn2", DataType::QASYMM8)), + framework::dataset::make("DataTypeOut", DataType::QASYMM8)), + framework::dataset::make("Scale", { 1.f, 2.f })), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), + framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)), + framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })), + framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), + framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) }))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_qasymm8); @@ -142,14 +145,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<uint8 TEST_SUITE_END() // QASYMM8 TEST_SUITE(QASYMM8_SIGNED) -FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), - framework::dataset::make("Scale", { 1.f, 2.f })), - framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), - framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)), - framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })), - framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), - framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) }))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QASYMM8_SIGNED)), + framework::dataset::make("DataTypeIn2", DataType::QASYMM8_SIGNED)), + framework::dataset::make("DataTypeOut", DataType::QASYMM8_SIGNED)), + framework::dataset::make("Scale", { 1.f, 2.f })), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), + framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)), + framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })), + framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), + framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) }))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_qasymm8); @@ -157,26 +163,32 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int8_ TEST_SUITE_END() // QASYMM8_SIGNED TEST_SUITE(QSYMM16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", DataType::QSYMM16)), - framework::dataset::make("Scale", { 1.f, 2.f })), - framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), - framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)), - framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })), - framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })), - framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) }))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QSYMM16)), + framework::dataset::make("DataTypeIn2", DataType::QSYMM16)), + framework::dataset::make("DataTypeOut", DataType::QSYMM16)), + framework::dataset::make("Scale", { 1.f, 2.f })), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), + framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)), + framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })), + framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })), + framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) }))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_qsymm16); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeShapes(), - framework::dataset::make("DataType", DataType::QSYMM16)), - framework::dataset::make("Scale", { 1.f, 2.f })), - framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), - framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)), - framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })), - framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })), - framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) }))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataTypeIn1", DataType::QSYMM16)), + framework::dataset::make("DataTypeIn2", DataType::QSYMM16)), + framework::dataset::make("DataTypeOut", DataType::QSYMM16)), + framework::dataset::make("Scale", { 1.f, 2.f })), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), + framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)), + framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })), + framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })), + framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) }))) { // Validate output validate(CLAccessor(_target), _reference, tolerance_qsymm16); diff --git a/tests/validation/NEON/PixelWiseMultiplication.cpp b/tests/validation/NEON/PixelWiseMultiplication.cpp index fd54e42083..6a75b00b9b 100644 --- a/tests/validation/NEON/PixelWiseMultiplication.cpp +++ b/tests/validation/NEON/PixelWiseMultiplication.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -70,20 +70,6 @@ const auto PixelWiseMultiplicationPolicySTZDataset = combine( // *INDENT-OFF* // clang-format off -#define PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(DT1, DT2, SCALE, RP) \ - DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, \ - combine(combine(combine(combine(combine( \ - concat(datasets::SmallShapes(), datasets::LargeShapes()), \ - framework::dataset::make("DataType1", DataType::DT1)), \ - framework::dataset::make("DataType2", DataType::DT2)), \ - framework::dataset::make("Scale", std::move(SCALE))), \ - datasets::ConvertPolicies()), \ - framework::dataset::make("RoundingPolicy", RoundingPolicy::RP)), \ - shape, dt1, dt2, scale, convert_policy, rounding_policy) \ - { \ - validate_configuration(shape, dt1, dt2, scale, convert_policy, rounding_policy); \ - } - #define PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, SHAPES, DT1, DT2, SCALE, RP, VALIDATE) \ FIXTURE_DATA_TEST_CASE(TEST_NAME, NEPixelWiseMultiplication##FIXTURE, framework::DatasetMode::MODE, \ combine(combine(combine(combine(combine( \ @@ -99,38 +85,12 @@ const auto PixelWiseMultiplicationPolicySTZDataset = combine( // *INDENT-ON* // clang-format on - -void validate_configuration(TensorShape shape, DataType dt1, DataType dt2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) -{ - Tensor src1 = create_tensor<Tensor>(shape, dt1); - Tensor src2 = create_tensor<Tensor>(shape, dt2); - Tensor dst = create_tensor<Tensor>(shape, dt2); - - ARM_COMPUTE_EXPECT(src1.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(src2.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); - - // Create and configure function - NEPixelWiseMultiplication multiply; - multiply.configure(&src1, &src2, &dst, scale, convert_policy, rounding_policy); - - // Validate valid region - const ValidRegion valid_region = shape_to_valid_region(shape); - validate(src1.info()->valid_region(), valid_region); - validate(src2.info()->valid_region(), valid_region); - validate(dst.info()->valid_region(), valid_region); - - // Validate padding - const PaddingSize padding = PaddingCalculator(shape.x(), 16).required_padding(); - validate(src1.info()->padding(), padding); - validate(src2.info()->padding(), padding); - validate(dst.info()->padding(), padding); -} } // namespace using NEPixelWiseMultiplicationQASYMM8Fixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, uint8_t, uint8_t>; using NEPixelWiseMultiplicationQASYMM8SignedFixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, int8_t, int8_t>; using NEPixelWiseMultiplicationQSYMM16Fixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, int16_t, int16_t>; +using NEPixelWiseMultiplicationQSYMM16ToS32Fixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, int16_t, int16_t, int32_t>; template <typename T> using NEPixelWiseMultiplicationToU8Fixture = PixelWiseMultiplicationValidationFixture<Tensor, Accessor, NEPixelWiseMultiplication, T, uint8_t>; template <typename T> @@ -231,8 +191,10 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip( TEST_SUITE(Quantized) TEST_SUITE(QASYMM8_SIGNED) TEST_SUITE(Scale255) -FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8SignedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), +FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8SignedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QASYMM8_SIGNED)), + framework::dataset::make("DataTypeIn2", DataType::QASYMM8_SIGNED)), + framework::dataset::make("DataTypeOut", DataType::QASYMM8_SIGNED)), framework::dataset::make("Scale", { scale_unity })), PixelWiseMultiplicationPolicySTZDataset), PixelWiseMultiplicationQASYMM8QuantDataset)) @@ -245,8 +207,10 @@ TEST_SUITE_END() // QASYMM8 TEST_SUITE(QASYMM8) TEST_SUITE(Scale255) -FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", DataType::QASYMM8)), +FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QASYMM8)), + framework::dataset::make("DataTypeIn2", DataType::QASYMM8)), + framework::dataset::make("DataTypeOut", DataType::QASYMM8)), framework::dataset::make("Scale", { scale_255 })), PixelWiseMultiplicationPolicySTNUDataset), PixelWiseMultiplicationQASYMM8QuantDataset)) @@ -254,8 +218,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framew // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), - framework::dataset::make("DataType", DataType::QASYMM8)), +FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataTypeIn1", DataType::QASYMM8)), + framework::dataset::make("DataTypeIn2", DataType::QASYMM8)), + framework::dataset::make("DataTypeOut", DataType::QASYMM8)), framework::dataset::make("Scale", { scale_255 })), PixelWiseMultiplicationPolicySTNUDataset), PixelWiseMultiplicationQASYMM8QuantDataset)) @@ -265,8 +231,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framew } TEST_SUITE_END() // Scale255 TEST_SUITE(ScaleUnity) -FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", DataType::QASYMM8)), +FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QASYMM8)), + framework::dataset::make("DataTypeIn2", DataType::QASYMM8)), + framework::dataset::make("DataTypeOut", DataType::QASYMM8)), framework::dataset::make("Scale", { scale_unity })), PixelWiseMultiplicationPolicySTZDataset), PixelWiseMultiplicationQASYMM8QuantDataset)) @@ -274,8 +242,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framew // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), - framework::dataset::make("DataType", DataType::QASYMM8)), +FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataTypeIn1", DataType::QASYMM8)), + framework::dataset::make("DataTypeIn2", DataType::QASYMM8)), + framework::dataset::make("DataTypeOut", DataType::QASYMM8)), framework::dataset::make("Scale", { scale_unity })), PixelWiseMultiplicationPolicySTZDataset), PixelWiseMultiplicationQASYMM8QuantDataset)) @@ -285,8 +255,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framew } TEST_SUITE_END() // ScaleUnity TEST_SUITE(ScaleOther) -FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", DataType::QASYMM8)), +FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QASYMM8)), + framework::dataset::make("DataTypeIn2", DataType::QASYMM8)), + framework::dataset::make("DataTypeOut", DataType::QASYMM8)), framework::dataset::make("Scale", { scale_other })), PixelWiseMultiplicationPolicySTZDataset), PixelWiseMultiplicationQASYMM8QuantDataset)) @@ -294,8 +266,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framew // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), - framework::dataset::make("DataType", DataType::QASYMM8)), +FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataTypeIn1", DataType::QASYMM8)), + framework::dataset::make("DataTypeIn2", DataType::QASYMM8)), + framework::dataset::make("DataTypeOut", DataType::QASYMM8)), framework::dataset::make("Scale", { scale_other })), PixelWiseMultiplicationPolicySTZDataset), PixelWiseMultiplicationQASYMM8QuantDataset)) @@ -307,8 +281,10 @@ TEST_SUITE_END() // ScaleOther TEST_SUITE_END() // QASYMM8 TEST_SUITE(QSYMM16) TEST_SUITE(Scale255) -FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", DataType::QSYMM16)), +FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QSYMM16)), + framework::dataset::make("DataTypeIn2", DataType::QSYMM16)), + framework::dataset::make("DataTypeOut", DataType::QSYMM16)), framework::dataset::make("Scale", { scale_255 })), PixelWiseMultiplicationPolicySTNUDataset), PixelWiseMultiplicationQSYMM16QuantDataset)) @@ -316,8 +292,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framew // Validate output validate(Accessor(_target), _reference, tolerance_qsymm16); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), - framework::dataset::make("DataType", DataType::QSYMM16)), +FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataTypeIn1", DataType::QSYMM16)), + framework::dataset::make("DataTypeIn2", DataType::QSYMM16)), + framework::dataset::make("DataTypeOut", DataType::QSYMM16)), framework::dataset::make("Scale", { scale_255 })), PixelWiseMultiplicationPolicySTNUDataset), PixelWiseMultiplicationQSYMM16QuantDataset)) @@ -327,8 +305,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framew } TEST_SUITE_END() // Scale255 TEST_SUITE(ScaleUnity) -FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", DataType::QSYMM16)), +FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QSYMM16)), + framework::dataset::make("DataTypeIn2", DataType::QSYMM16)), + framework::dataset::make("DataTypeOut", DataType::QSYMM16)), framework::dataset::make("Scale", { scale_unity })), PixelWiseMultiplicationPolicySTZDataset), PixelWiseMultiplicationQSYMM16QuantDataset)) @@ -336,8 +316,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framew // Validate output validate(Accessor(_target), _reference, tolerance_qsymm16); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), - framework::dataset::make("DataType", DataType::QSYMM16)), +FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataTypeIn1", DataType::QSYMM16)), + framework::dataset::make("DataTypeIn2", DataType::QSYMM16)), + framework::dataset::make("DataTypeOut", DataType::QSYMM16)), framework::dataset::make("Scale", { scale_unity })), PixelWiseMultiplicationPolicySTZDataset), PixelWiseMultiplicationQSYMM16QuantDataset)) @@ -347,8 +329,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framew } TEST_SUITE_END() // ScaleUnity TEST_SUITE(ScaleOther) -FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(), - framework::dataset::make("DataType", DataType::QSYMM16)), +FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QSYMM16)), + framework::dataset::make("DataTypeIn2", DataType::QSYMM16)), + framework::dataset::make("DataTypeOut", DataType::QSYMM16)), framework::dataset::make("Scale", { scale_other })), PixelWiseMultiplicationPolicySTZDataset), PixelWiseMultiplicationQSYMM16QuantDataset)) @@ -356,8 +340,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framew // Validate output validate(Accessor(_target), _reference, tolerance_qsymm16); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), - framework::dataset::make("DataType", DataType::QSYMM16)), +FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(), + framework::dataset::make("DataTypeIn1", DataType::QSYMM16)), + framework::dataset::make("DataTypeIn2", DataType::QSYMM16)), + framework::dataset::make("DataTypeOut", DataType::QSYMM16)), framework::dataset::make("Scale", { scale_other })), PixelWiseMultiplicationPolicySTZDataset), PixelWiseMultiplicationQSYMM16QuantDataset)) @@ -367,24 +353,34 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framew } TEST_SUITE_END() // ScaleOther TEST_SUITE_END() // QSYMM16 +TEST_SUITE(QSYMM16toS32) +FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16ToS32Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), + framework::dataset::make("DataTypeIn1", DataType::QSYMM16)), + framework::dataset::make("DataTypeIn2", DataType::QSYMM16)), + framework::dataset::make("DataTypeOut", DataType::S32)), + framework::dataset::make("Scale", { scale_unity })), + PixelWiseMultiplicationPolicySTZDataset), + PixelWiseMultiplicationQSYMM16QuantDataset)) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() // QSYMM16toS32 TEST_SUITE_END() // Quantized TEST_SUITE(U8toU8) TEST_SUITE(Scale255) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, U8, scale_255, TO_NEAREST_UP) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToU8Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, U8, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(uint8_t, 1)) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToU8Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, U8, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(uint8_t, 1)) TEST_SUITE_END() // Scale255 TEST_SUITE(ScaleUnity) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, U8, scale_unity, TO_ZERO) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToU8Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, U8, scale_unity, TO_ZERO, DEFAULT_VALIDATE) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToU8Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, U8, scale_unity, TO_ZERO, DEFAULT_VALIDATE) TEST_SUITE_END() // ScaleUnity TEST_SUITE(ScaleOther) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, U8, scale_other, TO_ZERO) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToU8Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, U8, scale_other, TO_ZERO, DEFAULT_VALIDATE) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToU8Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, U8, scale_other, TO_ZERO, DEFAULT_VALIDATE) TEST_SUITE_END() // ScaleOther @@ -394,19 +390,16 @@ TEST_SUITE_END() // U8toU8 TEST_SUITE(U8toS16) TEST_SUITE(Scale255) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, S16, scale_255, TO_NEAREST_UP) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2)) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2)) TEST_SUITE_END() // Scale255 TEST_SUITE(ScaleUnity) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, S16, scale_unity, TO_ZERO) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE) TEST_SUITE_END() // ScaleUnity TEST_SUITE(ScaleOther) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, S16, scale_other, TO_ZERO) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE) TEST_SUITE_END() // ScaleOther @@ -416,19 +409,16 @@ TEST_SUITE_END() // U8toS16 TEST_SUITE(S16toS16) TEST_SUITE(Scale255) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(S16, S16, scale_255, TO_NEAREST_UP) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<int16_t>, PRECOMMIT, SmallShapes(), S16, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2)) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<int16_t>, NIGHTLY, LargeShapes(), S16, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2)) TEST_SUITE_END() // Scale255 TEST_SUITE(ScaleUnity) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(S16, S16, scale_unity, TO_ZERO) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<int16_t>, PRECOMMIT, SmallShapes(), S16, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<int16_t>, NIGHTLY, LargeShapes(), S16, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE) TEST_SUITE_END() // ScaleUnity TEST_SUITE(ScaleOther) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(S16, S16, scale_other, TO_ZERO) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<int16_t>, PRECOMMIT, SmallShapes(), S16, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<int16_t>, NIGHTLY, LargeShapes(), S16, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE) TEST_SUITE_END() // ScaleOther @@ -448,19 +438,16 @@ TEST_SUITE_END() // F16toF16 TEST_SUITE(F32toF32) TEST_SUITE(Scale255) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(F32, F32, scale_255, TO_NEAREST_UP) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF32Fixture<float>, PRECOMMIT, SmallShapes(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f)) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToF32Fixture<float>, NIGHTLY, LargeShapes(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f)) TEST_SUITE_END() // Scale255 TEST_SUITE(ScaleUnity) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(F32, F32, scale_unity, TO_ZERO) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF32Fixture<float>, PRECOMMIT, SmallShapes(), F32, F32, scale_unity, TO_ZERO, DEFAULT_VALIDATE) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToF32Fixture<float>, NIGHTLY, LargeShapes(), F32, F32, scale_unity, TO_ZERO, DEFAULT_VALIDATE) TEST_SUITE_END() // ScaleUnity TEST_SUITE(ScaleOther) -PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(F32, F32, scale_other, TO_ZERO) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF32Fixture<float>, PRECOMMIT, SmallShapes(), F32, F32, scale_other, TO_ZERO, DEFAULT_VALIDATE) PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToF32Fixture<float>, NIGHTLY, LargeShapes(), F32, F32, scale_other, TO_ZERO, DEFAULT_VALIDATE) TEST_SUITE_END() // ScaleOther diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h index 9260686d56..858ee07d3e 100644 --- a/tests/validation/fixtures/LSTMLayerFixture.h +++ b/tests/validation/fixtures/LSTMLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -407,7 +407,7 @@ protected: if(peephole_opt) { - SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO); + SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, data_type); forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE); } @@ -416,7 +416,7 @@ protected: SimpleTensor<T> forget_layer_norm_w{ cell_bias_shape, data_type }; fill(forget_layer_norm_w, 23); forget_gate = reference::mean_std_normalization_layer(forget_gate); - forget_gate = reference::pixel_wise_multiplication(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + forget_gate = reference::pixel_wise_multiplication<T, T, T>(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); fill(forget_gate_bias, 7); forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, forget_gate_bias, data_type, ConvertPolicy::SATURATE); } @@ -438,7 +438,7 @@ protected: input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE); if(peephole_opt) { - SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE); } if(use_layer_norm) @@ -446,7 +446,7 @@ protected: SimpleTensor<T> input_layer_norm_w{ cell_bias_shape, data_type }; fill(input_layer_norm_w, 22); input_gate = reference::mean_std_normalization_layer(input_gate); - input_gate = reference::pixel_wise_multiplication(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + input_gate = reference::pixel_wise_multiplication<T, T, T>(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); fill(input_gate_bias, 17); input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, input_gate_bias, data_type, ConvertPolicy::SATURATE); } @@ -457,19 +457,19 @@ protected: SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape); transposed_weights = reference::transpose(recurrent_to_cell_w); gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f); - SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE); if(use_layer_norm) { SimpleTensor<T> cell_layer_norm_w{ cell_bias_shape, data_type }; fill(cell_layer_norm_w, 24); cell_state_out = reference::mean_std_normalization_layer(cell_state_out); - cell_state_out = reference::pixel_wise_multiplication(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); fill(cell_bias, 8); cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE); } cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); - cell_state_out = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE); if(cell_threshold != 0.f) { @@ -483,7 +483,7 @@ protected: output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE); if(peephole_opt) { - pixelwise_mul = reference::pixel_wise_multiplication(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE); } if(use_layer_norm) @@ -491,7 +491,7 @@ protected: SimpleTensor<T> output_layer_norm_w{ cell_bias_shape, data_type }; fill(output_layer_norm_w, 25); output = reference::mean_std_normalization_layer(output); - output = reference::pixel_wise_multiplication(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + output = reference::pixel_wise_multiplication<T, T, T>(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); fill(output_gate_bias, 9); output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, output_gate_bias, data_type, ConvertPolicy::SATURATE); } @@ -499,7 +499,7 @@ protected: // Compute output state SimpleTensor<T> cell_state_activation = reference::activation_layer(cell_state_out, info); - output_state_out = reference::pixel_wise_multiplication(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + output_state_out = reference::pixel_wise_multiplication<T, T, T>(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); if(projection_opt) { diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h index efdf5d078e..37359f421b 100644 --- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h +++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -39,7 +39,7 @@ namespace test { namespace validation { -template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2, typename T3 = T2> class PixelWiseMultiplicationGenericValidationFixture : public framework::Fixture { public: @@ -48,6 +48,7 @@ public: const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, @@ -55,8 +56,8 @@ public: QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { - _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out); - _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out); + _target = compute_target(shape0, shape1, dt_in1, dt_in2, dt_out, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out); + _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, dt_out, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out); } protected: @@ -66,14 +67,14 @@ protected: library->fill_tensor_uniform(tensor, seed_offset); } - TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { // Create tensors TensorType src1 = create_tensor<TensorType>(shape0, dt_in1, 1, qinfo0); TensorType src2 = create_tensor<TensorType>(shape1, dt_in2, 1, qinfo1); - TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_in2, 1, qinfo_out); + TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_out, 1, qinfo_out); // Create and configure function FunctionType multiply; @@ -102,7 +103,7 @@ protected: return dst; } - SimpleTensor<T2> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + SimpleTensor<T3> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { @@ -114,24 +115,11 @@ protected: fill(src1, 0); fill(src2, 1); - return reference::pixel_wise_multiplication<T1, T2>(src1, src2, scale, convert_policy, rounding_policy, qinfo_out); + return reference::pixel_wise_multiplication<T1, T2, T3>(src1, src2, scale, convert_policy, rounding_policy, dt_out, qinfo_out); } TensorType _target{}; - SimpleTensor<T2> _reference{}; -}; - -template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> -class PixelWiseMultiplicationQuatizedValidationFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2> -{ -public: - template <typename...> - void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, - QuantizationInfo in1_qua_info, QuantizationInfo in2_qua_info, QuantizationInfo out_qua_info) - { - PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy, - in1_qua_info, in2_qua_info, out_qua_info); - } + SimpleTensor<T3> _reference{}; }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> @@ -141,7 +129,7 @@ public: template <typename...> void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { - PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy, + PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo()); } }; @@ -153,21 +141,21 @@ public: template <typename...> void setup(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { - PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, + PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape0, shape1, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo()); } }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> -class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2, typename T3 = T2> +class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3> { public: template <typename...> - void setup(const TensorShape &shape, DataType dt, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, + void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { - PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt, dt, scale, convert_policy, rounding_policy, - qinfo0, qinfo1, qinfo_out); + PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>::setup(shape, shape, dt_in1, dt_in2, dt_out, scale, convert_policy, + rounding_policy, qinfo0, qinfo1, qinfo_out); } }; } // namespace validation diff --git a/tests/validation/reference/PixelWiseMultiplication.cpp b/tests/validation/reference/PixelWiseMultiplication.cpp index 2b4c849c39..3e21fca72a 100644 --- a/tests/validation/reference/PixelWiseMultiplication.cpp +++ b/tests/validation/reference/PixelWiseMultiplication.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -52,16 +52,16 @@ namespace * @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even. */ -template <typename T1, typename T2> -T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) +template <typename T1, typename T2, typename T3> +T3 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { - using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type; + using intermediate_type = typename common_promoted_signed_type<T1, T2, T3>::intermediate_type; const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale); - if(is_floating_point<T2>::value) + if(is_floating_point<T3>::value) { - const auto result = static_cast<T2>(val); + const auto result = static_cast<T3>(val); return result; } @@ -83,7 +83,7 @@ T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, ARM_COMPUTE_ERROR("Unsupported rounding policy"); } - const auto result = static_cast<T2>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : rounded_val); + const auto result = static_cast<T3>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(rounded_val) : rounded_val); return result; } @@ -92,8 +92,8 @@ T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, template <size_t dim> struct BroadcastUnroll { - template <typename T1, typename T2> - static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst, + template <typename T1, typename T2, typename T3> + static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) { @@ -117,23 +117,23 @@ struct BroadcastUnroll template <> struct BroadcastUnroll<0> { - template <typename T1, typename T2> - static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst, + template <typename T1, typename T2, typename T3> + static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) { - dst[coord2index(dst.shape(), id_dst)] = mul(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy); + dst[coord2index(dst.shape(), id_dst)] = mul<T1, T2, T3>(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy); } }; } // namespace -template <typename T1, typename T2> -SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, - const QuantizationInfo &qout) +template <typename T1, typename T2, typename T3> +SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, + DataType dt_out, const QuantizationInfo &qout) { ARM_COMPUTE_UNUSED(qout); - SimpleTensor<T2> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type()); + SimpleTensor<T3> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out); if(scale < 0) { @@ -151,15 +151,15 @@ SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const S template <> SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, - const QuantizationInfo &qout) + DataType dt_out, const QuantizationInfo &qout) { - SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout); + SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout); if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8) { SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1); SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2); - SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout); + SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout); dst = convert_to_asymmetric<uint8_t>(dst_tmp, qout); } else @@ -179,15 +179,15 @@ SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src template <> SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, - const QuantizationInfo &qout) + DataType dt_out, const QuantizationInfo &qout) { - SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout); + SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout); if(src1.data_type() == DataType::QASYMM8_SIGNED && src2.data_type() == DataType::QASYMM8_SIGNED) { SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1); SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2); - SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout); + SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout); dst = convert_to_asymmetric<int8_t>(dst_tmp, qout); } else @@ -207,15 +207,15 @@ SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1, template <> SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, - const QuantizationInfo &qout) + DataType dt_out, const QuantizationInfo &qout) { - SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout); + SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout); if(src1.data_type() == DataType::QSYMM16 && src2.data_type() == DataType::QSYMM16) { SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1); SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2); - SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout); + SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout); dst = convert_to_symmetric<int16_t>(dst_tmp, qout); } else @@ -234,9 +234,10 @@ SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src } // *INDENT-OFF* // clang-format off -template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout); -template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout); -template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout); +template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); +template SimpleTensor<int32_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); +template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); +template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); // clang-format on // *INDENT-ON* } // namespace reference diff --git a/tests/validation/reference/PixelWiseMultiplication.h b/tests/validation/reference/PixelWiseMultiplication.h index f5b8e777fb..f8afa0384b 100644 --- a/tests/validation/reference/PixelWiseMultiplication.h +++ b/tests/validation/reference/PixelWiseMultiplication.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -34,9 +34,10 @@ namespace validation { namespace reference { -template <typename T1, typename T2> -SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, - ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout = QuantizationInfo()); +template <typename T1, typename T2, typename T3> +SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, + ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, + const QuantizationInfo &qout = QuantizationInfo()); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/QLSTMLayerNormalization.cpp b/tests/validation/reference/QLSTMLayerNormalization.cpp index 90d59b93ad..0e24de6584 100644 --- a/tests/validation/reference/QLSTMLayerNormalization.cpp +++ b/tests/validation/reference/QLSTMLayerNormalization.cpp @@ -41,7 +41,7 @@ namespace reference SimpleTensor<float> qlstm_layer_normalization_float_compute(SimpleTensor<float> src, SimpleTensor<float> weight, SimpleTensor<float> bias) { SimpleTensor<float> output = mean_std_normalization_layer(src); - output = pixel_wise_multiplication(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO); + output = pixel_wise_multiplication<float, float, float>(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, DataType::F32); return arithmetic_operation(ArithmeticOperation::ADD, output, bias, DataType::F32, ConvertPolicy::SATURATE); } |