aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2020-03-30 14:10:20 +0100
committerMichele Di Giorgio <michele.digiorgio@arm.com>2020-04-01 12:02:02 +0000
commit9428a182911802cf6e6df6eb751a7c7eb43602f9 (patch)
tree78247c5657c92fc692a68b4df0d7d34b66bea408
parentafc630fee1c019bfbc191c37d9d7fdf805b0b1d7 (diff)
downloadComputeLibrary-9428a182911802cf6e6df6eb751a7c7eb43602f9.tar.gz
COMPMID-3237: Add support for QSYMM16 into S32 NEPixelwiseMultiplicationKernel
Change-Id: I8dc3348db37b041f442639ac0d072740ca639878 Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2960 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h22
-rw-r--r--arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h22
-rw-r--r--src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp175
-rw-r--r--tests/validation/CL/PixelWiseMultiplication.cpp76
-rw-r--r--tests/validation/NEON/PixelWiseMultiplication.cpp147
-rw-r--r--tests/validation/fixtures/LSTMLayerFixture.h22
-rw-r--r--tests/validation/fixtures/PixelWiseMultiplicationFixture.h46
-rw-r--r--tests/validation/reference/PixelWiseMultiplication.cpp57
-rw-r--r--tests/validation/reference/PixelWiseMultiplication.h9
-rw-r--r--tests/validation/reference/QLSTMLayerNormalization.cpp2
10 files changed, 352 insertions, 226 deletions
diff --git a/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h b/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
index 9b71ac81cf..1a9dd6be2e 100644
--- a/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
+++ b/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2019 ARM Limited.
+ * Copyright (c) 2016-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,7 +58,15 @@ public:
*
* @param[in] input1 An input tensor. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32
* @param[in] input2 An input tensor. Data types supported: U8, QASYMM8 (only if @p input1 is QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if @p input1 is QSYMM16), F16 (only if @p input1 is F16), F32 (only if @p input1 is F32).
- * @param[out] output Output tensor. Data types supported: U8 (Only if both inputs are U8), QASYMM8 (only if both inputs are QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if both inputs are QSYMM16), F16 (only if @p input1 is F16), F32 (only if both inputs are F32).
+ * @param[out] output Output tensor. Data types supported:
+ * - U8, only if both inputs are U8.
+ * - QASYMM8, only if both inputs are QASYMM8.
+ * - QASYMM8_SIGNED, only if @p input1 is QASYMM8_SIGNED.
+ * - S16.
+ * - QSYMM16, only if both inputs are QSYMM16.
+ * - S32, only if both inputs are QSYMM16.
+ * - F16, only if @p input1 is F16.
+ * - F32, only if both inputs are F32.
* @param[in] scale Scale to apply after multiplication.
* Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
* @param[in] overflow_policy Overflow policy. ConvertPolicy cannot be WRAP if datatype is QASYMM8, QASYMM8_SIGNED or QSYMM16.
@@ -72,7 +80,15 @@ public:
*
* @param[in] input1 An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/QSYMM16/S16/F16/F32
* @param[in] input2 An input tensor info. Data types supported: U8, QASYMM8 (only if @p input1 is QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if @p input1 is QSYMM16), F16 (only if @p input1 is F16), F32 (only if @p input1 is F32).
- * @param[in] output Output tensor info. Data types supported: U8 (Only if both inputs are U8), QASYMM8 (only if both inputs are QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED) , S16, QSYMM16 (only if both inputs are QSYMM16), F16 (only if @p input1 is F16), F32 (only if both inputs are F32).
+ * @param[in] output Output tensor info. Data types supported:
+ * - U8, only if both inputs are U8.
+ * - QASYMM8, only if both inputs are QASYMM8.
+ * - QASYMM8_SIGNED, only if @p input1 is QASYMM8_SIGNED.
+ * - S16.
+ * - QSYMM16, only if both inputs are QSYMM16.
+ * - S32, only if both inputs are QSYMM16.
+ * - F16, only if @p input1 is F16.
+ * - F32, only if both inputs are F32.
* @param[in] scale Scale to apply after multiplication.
* Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
* @param[in] overflow_policy Overflow policy. ConvertPolicy cannot be WRAP if datatype is QASYMM8, QASYMM8_SIGNED or QSYMM16.
diff --git a/arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h b/arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h
index 25f409871b..ede4327bfb 100644
--- a/arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h
+++ b/arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2019 ARM Limited.
+ * Copyright (c) 2016-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,7 +44,15 @@ public:
* This input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
* @param[in, out] input2 An input tensor. Data types supported: U8, QASYMM8 (only if @p input1 is QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if @p input1 is QSYMM16), F16 (only if @p input1 is F16), F32 (only if @p input1 is F32).
* This input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
- * @param[out] output Output tensor. Data types supported: U8 (Only if both inputs are U8), QASYMM8 (only if both inputs are QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if both inputs are QSYMM16), F16 (only if @p input1 is F16), F32 (only if both inputs are F32).
+ * @param[out] output Output tensor. Data types supported:
+ * - U8, only if both inputs are U8.
+ * - QASYMM8, only if both inputs are QASYMM8.
+ * - QASYMM8_SIGNED, only if @p input1 is QASYMM8_SIGNED.
+ * - S16.
+ * - QSYMM16, only if both inputs are QSYMM16.
+ * - S32, only if both inputs are QSYMM16.
+ * - F16, only if @p input1 is F16.
+ * - F32, only if both inputs are F32.
* @param[in] scale Scale to apply after multiplication.
* Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
* @param[in] overflow_policy Overflow policy. ConvertPolicy cannot be WRAP if datatype is QASYMM8, QASYMM8_SIGNED or QSYMM16.
@@ -58,7 +66,15 @@ public:
*
* @param[in] input1 An input tensor info. Data types supported: U8/QASYMM8/QASYMM8_SIGNED/S16/QSYMM16/F16/F32
* @param[in] input2 An input tensor info. Data types supported: U8, QASYMM8 (only if @p input1 is QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if both inputs are QSYMM16), F16 (only if @p input1 is F16), F32 (only if @p input1 is F32).
- * @param[in] output Output tensor info. Data types supported: U8 (Only if both inputs are U8), QASYMM8 (only if both inputs are QASYMM8), QASYMM8_SIGNED (only if @p input1 is QASYMM8_SIGNED), S16, QSYMM16 (only if both inputs are QSYMM16), F16 (only if @p input1 is F16), F32 (only if both inputs are F32).
+ * @param[in] output Output tensor info. Data types supported:
+ * - U8, only if both inputs are U8.
+ * - QASYMM8, only if both inputs are QASYMM8.
+ * - QASYMM8_SIGNED, only if @p input1 is QASYMM8_SIGNED.
+ * - S16.
+ * - QSYMM16, only if both inputs are QSYMM16.
+ * - S32, only if both inputs are QSYMM16.
+ * - F16, only if @p input1 is F16.
+ * - F32, only if both inputs are F32.
* @param[in] scale Scale to apply after multiplication.
* Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
* @param[in] overflow_policy Overflow policy. ConvertPolicy cannot be WRAP if datatype is QASYMM8, QASYMM8_SIGNED or QSYMM16.
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index a87588dbb3..ca59e66293 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2019 ARM Limited.
+ * Copyright (c) 2016-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,23 +24,12 @@
#include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
#include "arm_compute/core/CPP/Validate.h"
-#include "arm_compute/core/Error.h"
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/IAccessWindow.h"
-#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEAsymm.h"
-#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/NESymm.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
#include <arm_neon.h>
-#include <climits>
-#include <cmath>
-#include <cstdint>
-#include <cstdlib>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_fp16.h> // needed for float16_t
@@ -48,8 +37,6 @@
namespace arm_compute
{
-class Coordinates;
-
namespace
{
const float scale255_constant = 1.f / 255.f;
@@ -66,10 +53,9 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
- "Output can only be U8 if both inputs are U8");
-
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
+ DataType::S16, DataType::QSYMM16,
+ DataType::S32, DataType::F16, DataType::F32);
if(is_data_type_quantized(input1->data_type()) || is_data_type_quantized(input2->data_type()))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
@@ -86,6 +72,12 @@ inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *i
const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
+ "Output can only be U8 if both inputs are U8");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && (input1->data_type() != DataType::QSYMM16 || input2->data_type() != DataType::QSYMM16),
+ "Output can only be S32 if both inputs are QSYMM16");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 output");
}
if(std::abs(scale - scale255_constant) < 0.00001f)
@@ -266,8 +258,20 @@ void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, c
const auto input2 = static_cast<const qsymm16_t *__restrict>(input2_ptr);
const auto output = static_cast<qsymm16_t *__restrict>(output_ptr);
- const qsymm16x8x2_t input1_q = vld2q_s16(input1);
- const qsymm16x8x2_t input2_q = vld2q_s16(input2);
+ const qsymm16x8x2_t input1_q =
+ {
+ {
+ vld1q_s16(input1),
+ vld1q_s16(input1 + 8),
+ }
+ };
+ const qsymm16x8x2_t input2_q =
+ {
+ {
+ vld1q_s16(input2),
+ vld1q_s16(input2 + 8),
+ }
+ };
// Dequantitize inputs
const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
@@ -284,7 +288,65 @@ void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, c
};
const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
- vst2q_s16(output, result);
+ vst1q_s16(output, result.val[0]);
+ vst1q_s16(output + 8, result.val[1]);
+}
+
+void mul_QSYMM16_QSYMM16_S32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int scale)
+{
+ ARM_COMPUTE_UNUSED(scale);
+ const auto input1 = static_cast<const qsymm16_t *__restrict>(input1_ptr);
+ const auto input2 = static_cast<const qsymm16_t *__restrict>(input2_ptr);
+ const auto output = static_cast<int32_t *__restrict>(output_ptr);
+
+ const qsymm16x8x2_t input1_q =
+ {
+ {
+ vld1q_s16(input1),
+ vld1q_s16(input1 + 8),
+ }
+ };
+ const qsymm16x8x2_t input2_q =
+ {
+ {
+ vld1q_s16(input2),
+ vld1q_s16(input2 + 8),
+ }
+ };
+
+ const int32x4x4_t in1_s32 =
+ {
+ {
+ vmovl_s16(vget_low_s16(input1_q.val[0])),
+ vmovl_s16(vget_high_s16(input1_q.val[0])),
+ vmovl_s16(vget_low_s16(input1_q.val[1])),
+ vmovl_s16(vget_high_s16(input1_q.val[1])),
+ }
+ };
+ const int32x4x4_t in2_s32 =
+ {
+ {
+ vmovl_s16(vget_low_s16(input2_q.val[0])),
+ vmovl_s16(vget_high_s16(input2_q.val[0])),
+ vmovl_s16(vget_low_s16(input2_q.val[1])),
+ vmovl_s16(vget_high_s16(input2_q.val[1])),
+ }
+ };
+
+ const int32x4x4_t result =
+ {
+ {
+ vmulq_s32(in1_s32.val[0], in2_s32.val[0]),
+ vmulq_s32(in1_s32.val[1], in2_s32.val[1]),
+ vmulq_s32(in1_s32.val[2], in2_s32.val[2]),
+ vmulq_s32(in1_s32.val[3], in2_s32.val[3]),
+ }
+ };
+
+ vst1q_s32(output, result.val[0]);
+ vst1q_s32(output + 4, result.val[1]);
+ vst1q_s32(output + 8, result.val[2]);
+ vst1q_s32(output + 12, result.val[3]);
}
template <bool is_scale255, bool is_sat>
@@ -412,11 +474,24 @@ void mul_S16_S16_S16_n(const void *__restrict input1_ptr, const void *__restrict
const auto input2 = static_cast<const int16_t *__restrict>(input2_ptr);
const auto output = static_cast<int16_t *__restrict>(output_ptr);
- const int16x8x2_t ta1 = vld2q_s16(input1);
- const int16x8x2_t ta2 = vld2q_s16(input2);
+ const int16x8x2_t ta1 =
+ {
+ {
+ vld1q_s16(input1),
+ vld1q_s16(input1 + 8),
+ }
+ };
+ const int16x8x2_t ta2 =
+ {
+ {
+ vld1q_s16(input2),
+ vld1q_s16(input2 + 8),
+ }
+ };
const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
- vst2q_s16(output, result);
+ vst1q_s16(output, result.val[0]);
+ vst1q_s16(output + 8, result.val[1]);
}
void mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
@@ -472,11 +547,23 @@ void c_mul_F32_F32_F32_n(const void *__restrict input1_ptr, const void *__restri
void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- const auto input1 = static_cast<const float16_t *__restrict>(input1_ptr);
- const auto input2 = static_cast<const float16_t *__restrict>(input2_ptr);
- const auto output = static_cast<float16_t *__restrict>(output_ptr);
- const float16x8x2_t ta1 = vld2q_f16(input1);
- const float16x8x2_t ta2 = vld2q_f16(input2);
+ const auto input1 = static_cast<const float16_t *__restrict>(input1_ptr);
+ const auto input2 = static_cast<const float16_t *__restrict>(input2_ptr);
+ const auto output = static_cast<float16_t *__restrict>(output_ptr);
+ const float16x8x2_t ta1 =
+ {
+ {
+ vld1q_f16(input1),
+ vld1q_f16(input1 + 8),
+ }
+ };
+ const float16x8x2_t ta2 =
+ {
+ {
+ vld1q_f16(input2),
+ vld1q_f16(input2 + 8),
+ }
+ };
const float16x8_t scale_vec = vdupq_n_f16(scale);
const float16x8x2_t result =
{
@@ -485,7 +572,8 @@ void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict
vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
}
};
- vst2q_f16(output, result);
+ vst1q_f16(output, result.val[0]);
+ vst1q_f16(output + 8, result.val[1]);
#else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
ARM_COMPUTE_UNUSED(input1_ptr);
ARM_COMPUTE_UNUSED(input2_ptr);
@@ -550,8 +638,20 @@ void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict
const auto input2 = static_cast<const uint8_t *__restrict>(input2_ptr);
const auto output = static_cast<int16_t *__restrict>(output_ptr);
- const int16x8x2_t ta1 = vld2q_s16(input1);
- const uint8x8x2_t ta2u = vld2_u8(input2);
+ const int16x8x2_t ta1 =
+ {
+ {
+ vld1q_s16(input1),
+ vld1q_s16(input1 + 8),
+ }
+ };
+ const uint8x8x2_t ta2u =
+ {
+ {
+ vld1_u8(input2),
+ vld1_u8(input2 + 8),
+ }
+ };
const int16x8x2_t ta2 =
{
{
@@ -562,7 +662,8 @@ void mul_S16_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict
const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
- vst2q_s16(output, result);
+ vst1q_s16(output, result.val[0]);
+ vst1q_s16(output + 8, result.val[1]);
}
template <bool is_scale255, bool is_sat>
@@ -629,10 +730,14 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe
{
_func_quantized = &mul_saturate_QASYMM8_SIGNED_QASYMM8_SIGNED_QASYMM8_SIGNED_n;
}
- else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16)
+ else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16)
{
_func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16_n;
}
+ else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32)
+ {
+ _func_int = &mul_QSYMM16_QSYMM16_S32_n;
+ }
else if(DataType::U8 == dt_input1 && DataType::U8 == dt_input2 && DataType::U8 == dt_output)
{
if(is_scale_255)
@@ -750,7 +855,7 @@ void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo
Iterator input2(_input2, slice_input2);
Iterator output(_output, slice);
- if(is_data_type_quantized(_input1->info()->data_type()))
+ if((_run_optimized_qasymm8) || (_func_quantized != nullptr))
{
if(_run_optimized_qasymm8)
{
diff --git a/tests/validation/CL/PixelWiseMultiplication.cpp b/tests/validation/CL/PixelWiseMultiplication.cpp
index 3b55e25f37..ff9101a997 100644
--- a/tests/validation/CL/PixelWiseMultiplication.cpp
+++ b/tests/validation/CL/PixelWiseMultiplication.cpp
@@ -127,14 +127,17 @@ using CLPixelWiseMultiplicationQuantizedFixture = PixelWiseMultiplicationValidat
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("Scale", { 1.f, 2.f })),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
- framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
- framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
- framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
+ framework::dataset::make("Scale", { 1.f, 2.f })),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
+ framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
+ framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
+ framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+ framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
@@ -142,14 +145,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<uint8
TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("Scale", { 1.f, 2.f })),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
- framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
- framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
- framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("Scale", { 1.f, 2.f })),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
+ framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
+ framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
+ framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+ framework::dataset::make("OUtQInfo", { QuantizationInfo(1.f / 255.f, 5) })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
@@ -157,26 +163,32 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int8_
TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE(QSYMM16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
- framework::dataset::make("Scale", { 1.f, 2.f })),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
- framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
- framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })),
- framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })),
- framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
+ framework::dataset::make("Scale", { 1.f, 2.f })),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
+ framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
+ framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })),
+ framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qsymm16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
- framework::dataset::make("Scale", { 1.f, 2.f })),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
- framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
- framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })),
- framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })),
- framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) })))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLPixelWiseMultiplicationQuantizedFixture<int16_t>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
+ framework::dataset::make("Scale", { 1.f, 2.f })),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
+ framework::dataset::make("RoundingPolicy", RoundingPolicy::TO_NEAREST_EVEN)),
+ framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0) })),
+ framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(5.f / 32768.f, 0) })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_qsymm16);
diff --git a/tests/validation/NEON/PixelWiseMultiplication.cpp b/tests/validation/NEON/PixelWiseMultiplication.cpp
index fd54e42083..6a75b00b9b 100644
--- a/tests/validation/NEON/PixelWiseMultiplication.cpp
+++ b/tests/validation/NEON/PixelWiseMultiplication.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -70,20 +70,6 @@ const auto PixelWiseMultiplicationPolicySTZDataset = combine(
// *INDENT-OFF*
// clang-format off
-#define PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(DT1, DT2, SCALE, RP) \
- DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, \
- combine(combine(combine(combine(combine( \
- concat(datasets::SmallShapes(), datasets::LargeShapes()), \
- framework::dataset::make("DataType1", DataType::DT1)), \
- framework::dataset::make("DataType2", DataType::DT2)), \
- framework::dataset::make("Scale", std::move(SCALE))), \
- datasets::ConvertPolicies()), \
- framework::dataset::make("RoundingPolicy", RoundingPolicy::RP)), \
- shape, dt1, dt2, scale, convert_policy, rounding_policy) \
- { \
- validate_configuration(shape, dt1, dt2, scale, convert_policy, rounding_policy); \
- }
-
#define PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(TEST_NAME, FIXTURE, MODE, SHAPES, DT1, DT2, SCALE, RP, VALIDATE) \
FIXTURE_DATA_TEST_CASE(TEST_NAME, NEPixelWiseMultiplication##FIXTURE, framework::DatasetMode::MODE, \
combine(combine(combine(combine(combine( \
@@ -99,38 +85,12 @@ const auto PixelWiseMultiplicationPolicySTZDataset = combine(
// *INDENT-ON*
// clang-format on
-
-void validate_configuration(TensorShape shape, DataType dt1, DataType dt2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
-{
- Tensor src1 = create_tensor<Tensor>(shape, dt1);
- Tensor src2 = create_tensor<Tensor>(shape, dt2);
- Tensor dst = create_tensor<Tensor>(shape, dt2);
-
- ARM_COMPUTE_EXPECT(src1.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(src2.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
-
- // Create and configure function
- NEPixelWiseMultiplication multiply;
- multiply.configure(&src1, &src2, &dst, scale, convert_policy, rounding_policy);
-
- // Validate valid region
- const ValidRegion valid_region = shape_to_valid_region(shape);
- validate(src1.info()->valid_region(), valid_region);
- validate(src2.info()->valid_region(), valid_region);
- validate(dst.info()->valid_region(), valid_region);
-
- // Validate padding
- const PaddingSize padding = PaddingCalculator(shape.x(), 16).required_padding();
- validate(src1.info()->padding(), padding);
- validate(src2.info()->padding(), padding);
- validate(dst.info()->padding(), padding);
-}
} // namespace
using NEPixelWiseMultiplicationQASYMM8Fixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, uint8_t, uint8_t>;
using NEPixelWiseMultiplicationQASYMM8SignedFixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, int8_t, int8_t>;
using NEPixelWiseMultiplicationQSYMM16Fixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, int16_t, int16_t>;
+using NEPixelWiseMultiplicationQSYMM16ToS32Fixture = PixelWiseMultiplicationValidationQuantizedFixture<Tensor, Accessor, NEPixelWiseMultiplication, int16_t, int16_t, int32_t>;
template <typename T>
using NEPixelWiseMultiplicationToU8Fixture = PixelWiseMultiplicationValidationFixture<Tensor, Accessor, NEPixelWiseMultiplication, T, uint8_t>;
template <typename T>
@@ -231,8 +191,10 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8_SIGNED)
TEST_SUITE(Scale255)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8SignedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8SignedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8_SIGNED)),
framework::dataset::make("Scale", { scale_unity })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -245,8 +207,10 @@ TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8)
TEST_SUITE(Scale255)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_255 })),
PixelWiseMultiplicationPolicySTNUDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -254,8 +218,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_255 })),
PixelWiseMultiplicationPolicySTNUDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -265,8 +231,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framew
}
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_unity })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -274,8 +242,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_unity })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -285,8 +255,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framew
}
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_other })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -294,8 +266,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQASYMM8Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQASYMM8Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeIn2", DataType::QASYMM8)),
+ framework::dataset::make("DataTypeOut", DataType::QASYMM8)),
framework::dataset::make("Scale", { scale_other })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQASYMM8QuantDataset))
@@ -307,8 +281,10 @@ TEST_SUITE_END() // ScaleOther
TEST_SUITE_END() // QASYMM8
TEST_SUITE(QSYMM16)
TEST_SUITE(Scale255)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_255 })),
PixelWiseMultiplicationPolicySTNUDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -316,8 +292,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qsymm16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_255 })),
PixelWiseMultiplicationPolicySTNUDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -327,8 +305,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framew
}
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_unity })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -336,8 +316,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qsymm16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_unity })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -347,8 +329,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framew
}
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_other })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -356,8 +340,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16Fixture, framew
// Validate output
validate(Accessor(_target), _reference, tolerance_qsymm16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::QSYMM16)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::QSYMM16)),
framework::dataset::make("Scale", { scale_other })),
PixelWiseMultiplicationPolicySTZDataset),
PixelWiseMultiplicationQSYMM16QuantDataset))
@@ -367,24 +353,34 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPixelWiseMultiplicationQSYMM16Fixture, framew
}
TEST_SUITE_END() // ScaleOther
TEST_SUITE_END() // QSYMM16
+TEST_SUITE(QSYMM16toS32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPixelWiseMultiplicationQSYMM16ToS32Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataTypeIn1", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeIn2", DataType::QSYMM16)),
+ framework::dataset::make("DataTypeOut", DataType::S32)),
+ framework::dataset::make("Scale", { scale_unity })),
+ PixelWiseMultiplicationPolicySTZDataset),
+ PixelWiseMultiplicationQSYMM16QuantDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // QSYMM16toS32
TEST_SUITE_END() // Quantized
TEST_SUITE(U8toU8)
TEST_SUITE(Scale255)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, U8, scale_255, TO_NEAREST_UP)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToU8Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, U8, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(uint8_t, 1))
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToU8Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, U8, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(uint8_t, 1))
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, U8, scale_unity, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToU8Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, U8, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToU8Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, U8, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, U8, scale_other, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToU8Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, U8, scale_other, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToU8Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, U8, scale_other, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleOther
@@ -394,19 +390,16 @@ TEST_SUITE_END() // U8toU8
TEST_SUITE(U8toS16)
TEST_SUITE(Scale255)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, S16, scale_255, TO_NEAREST_UP)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2))
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2))
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, S16, scale_unity, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(U8, S16, scale_other, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<uint8_t>, PRECOMMIT, SmallShapes(), U8, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<uint8_t>, NIGHTLY, LargeShapes(), U8, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleOther
@@ -416,19 +409,16 @@ TEST_SUITE_END() // U8toS16
TEST_SUITE(S16toS16)
TEST_SUITE(Scale255)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(S16, S16, scale_255, TO_NEAREST_UP)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<int16_t>, PRECOMMIT, SmallShapes(), S16, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2))
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<int16_t>, NIGHTLY, LargeShapes(), S16, S16, scale_255, TO_NEAREST_UP, WRAP_VALIDATE(int16_t, 2))
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(S16, S16, scale_unity, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<int16_t>, PRECOMMIT, SmallShapes(), S16, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<int16_t>, NIGHTLY, LargeShapes(), S16, S16, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(S16, S16, scale_other, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToS16Fixture<int16_t>, PRECOMMIT, SmallShapes(), S16, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToS16Fixture<int16_t>, NIGHTLY, LargeShapes(), S16, S16, scale_other, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleOther
@@ -448,19 +438,16 @@ TEST_SUITE_END() // F16toF16
TEST_SUITE(F32toF32)
TEST_SUITE(Scale255)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(F32, F32, scale_255, TO_NEAREST_UP)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF32Fixture<float>, PRECOMMIT, SmallShapes(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f))
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToF32Fixture<float>, NIGHTLY, LargeShapes(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f))
TEST_SUITE_END() // Scale255
TEST_SUITE(ScaleUnity)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(F32, F32, scale_unity, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF32Fixture<float>, PRECOMMIT, SmallShapes(), F32, F32, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToF32Fixture<float>, NIGHTLY, LargeShapes(), F32, F32, scale_unity, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleUnity
TEST_SUITE(ScaleOther)
-PIXEL_WISE_MULTIPLICATION_DATA_TEST_CASE(F32, F32, scale_other, TO_ZERO)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, ToF32Fixture<float>, PRECOMMIT, SmallShapes(), F32, F32, scale_other, TO_ZERO, DEFAULT_VALIDATE)
PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunLarge, ToF32Fixture<float>, NIGHTLY, LargeShapes(), F32, F32, scale_other, TO_ZERO, DEFAULT_VALIDATE)
TEST_SUITE_END() // ScaleOther
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index 9260686d56..858ee07d3e 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -407,7 +407,7 @@ protected:
if(peephole_opt)
{
- SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
+ SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, data_type);
forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
}
@@ -416,7 +416,7 @@ protected:
SimpleTensor<T> forget_layer_norm_w{ cell_bias_shape, data_type };
fill(forget_layer_norm_w, 23);
forget_gate = reference::mean_std_normalization_layer(forget_gate);
- forget_gate = reference::pixel_wise_multiplication(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ forget_gate = reference::pixel_wise_multiplication<T, T, T>(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
fill(forget_gate_bias, 7);
forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, forget_gate_bias, data_type, ConvertPolicy::SATURATE);
}
@@ -438,7 +438,7 @@ protected:
input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
- SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
}
if(use_layer_norm)
@@ -446,7 +446,7 @@ protected:
SimpleTensor<T> input_layer_norm_w{ cell_bias_shape, data_type };
fill(input_layer_norm_w, 22);
input_gate = reference::mean_std_normalization_layer(input_gate);
- input_gate = reference::pixel_wise_multiplication(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ input_gate = reference::pixel_wise_multiplication<T, T, T>(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
fill(input_gate_bias, 17);
input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, input_gate_bias, data_type, ConvertPolicy::SATURATE);
}
@@ -457,19 +457,19 @@ protected:
SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape);
transposed_weights = reference::transpose(recurrent_to_cell_w);
gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
- SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE);
if(use_layer_norm)
{
SimpleTensor<T> cell_layer_norm_w{ cell_bias_shape, data_type };
fill(cell_layer_norm_w, 24);
cell_state_out = reference::mean_std_normalization_layer(cell_state_out);
- cell_state_out = reference::pixel_wise_multiplication(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
fill(cell_bias, 8);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE);
}
cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
- cell_state_out = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
if(cell_threshold != 0.f)
{
@@ -483,7 +483,7 @@ protected:
output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
- pixelwise_mul = reference::pixel_wise_multiplication(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ pixelwise_mul = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
}
if(use_layer_norm)
@@ -491,7 +491,7 @@ protected:
SimpleTensor<T> output_layer_norm_w{ cell_bias_shape, data_type };
fill(output_layer_norm_w, 25);
output = reference::mean_std_normalization_layer(output);
- output = reference::pixel_wise_multiplication(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ output = reference::pixel_wise_multiplication<T, T, T>(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
fill(output_gate_bias, 9);
output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, output_gate_bias, data_type, ConvertPolicy::SATURATE);
}
@@ -499,7 +499,7 @@ protected:
// Compute output state
SimpleTensor<T> cell_state_activation = reference::activation_layer(cell_state_out, info);
- output_state_out = reference::pixel_wise_multiplication(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ output_state_out = reference::pixel_wise_multiplication<T, T, T>(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
if(projection_opt)
{
diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
index efdf5d078e..37359f421b 100644
--- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
+++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -39,7 +39,7 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2, typename T3 = T2>
class PixelWiseMultiplicationGenericValidationFixture : public framework::Fixture
{
public:
@@ -48,6 +48,7 @@ public:
const TensorShape &shape1,
DataType dt_in1,
DataType dt_in2,
+ DataType dt_out,
float scale,
ConvertPolicy convert_policy,
RoundingPolicy rounding_policy,
@@ -55,8 +56,8 @@ public:
QuantizationInfo qinfo1,
QuantizationInfo qinfo_out)
{
- _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
- _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
+ _target = compute_target(shape0, shape1, dt_in1, dt_in2, dt_out, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
+ _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, dt_out, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out);
}
protected:
@@ -66,14 +67,14 @@ protected:
library->fill_tensor_uniform(tensor, seed_offset);
}
- TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2,
+ TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out,
float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
// Create tensors
TensorType src1 = create_tensor<TensorType>(shape0, dt_in1, 1, qinfo0);
TensorType src2 = create_tensor<TensorType>(shape1, dt_in2, 1, qinfo1);
- TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_in2, 1, qinfo_out);
+ TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_out, 1, qinfo_out);
// Create and configure function
FunctionType multiply;
@@ -102,7 +103,7 @@ protected:
return dst;
}
- SimpleTensor<T2> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2,
+ SimpleTensor<T3> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out,
float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
@@ -114,24 +115,11 @@ protected:
fill(src1, 0);
fill(src2, 1);
- return reference::pixel_wise_multiplication<T1, T2>(src1, src2, scale, convert_policy, rounding_policy, qinfo_out);
+ return reference::pixel_wise_multiplication<T1, T2, T3>(src1, src2, scale, convert_policy, rounding_policy, dt_out, qinfo_out);
}
TensorType _target{};
- SimpleTensor<T2> _reference{};
-};
-
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
-class PixelWiseMultiplicationQuatizedValidationFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
-{
-public:
- template <typename...>
- void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- QuantizationInfo in1_qua_info, QuantizationInfo in2_qua_info, QuantizationInfo out_qua_info)
- {
- PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy,
- in1_qua_info, in2_qua_info, out_qua_info);
- }
+ SimpleTensor<T3> _reference{};
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
@@ -141,7 +129,7 @@ public:
template <typename...>
void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
- PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy,
+ PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy,
QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
}
};
@@ -153,21 +141,21 @@ public:
template <typename...>
void setup(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
- PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy,
+ PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape0, shape1, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy,
QuantizationInfo(), QuantizationInfo(), QuantizationInfo());
}
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
-class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2, typename T3 = T2>
+class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType dt, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
- PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt, dt, scale, convert_policy, rounding_policy,
- qinfo0, qinfo1, qinfo_out);
+ PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>::setup(shape, shape, dt_in1, dt_in2, dt_out, scale, convert_policy,
+ rounding_policy, qinfo0, qinfo1, qinfo_out);
}
};
} // namespace validation
diff --git a/tests/validation/reference/PixelWiseMultiplication.cpp b/tests/validation/reference/PixelWiseMultiplication.cpp
index 2b4c849c39..3e21fca72a 100644
--- a/tests/validation/reference/PixelWiseMultiplication.cpp
+++ b/tests/validation/reference/PixelWiseMultiplication.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -52,16 +52,16 @@ namespace
* @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate
* @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
*/
-template <typename T1, typename T2>
-T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+template <typename T1, typename T2, typename T3>
+T3 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
- using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type;
+ using intermediate_type = typename common_promoted_signed_type<T1, T2, T3>::intermediate_type;
const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale);
- if(is_floating_point<T2>::value)
+ if(is_floating_point<T3>::value)
{
- const auto result = static_cast<T2>(val);
+ const auto result = static_cast<T3>(val);
return result;
}
@@ -83,7 +83,7 @@ T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy,
ARM_COMPUTE_ERROR("Unsupported rounding policy");
}
- const auto result = static_cast<T2>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : rounded_val);
+ const auto result = static_cast<T3>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(rounded_val) : rounded_val);
return result;
}
@@ -92,8 +92,8 @@ T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy,
template <size_t dim>
struct BroadcastUnroll
{
- template <typename T1, typename T2>
- static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
+ template <typename T1, typename T2, typename T3>
+ static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst,
float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
{
@@ -117,23 +117,23 @@ struct BroadcastUnroll
template <>
struct BroadcastUnroll<0>
{
- template <typename T1, typename T2>
- static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
+ template <typename T1, typename T2, typename T3>
+ static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst,
float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
{
- dst[coord2index(dst.shape(), id_dst)] = mul(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
+ dst[coord2index(dst.shape(), id_dst)] = mul<T1, T2, T3>(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
}
};
} // namespace
-template <typename T1, typename T2>
-SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ DataType dt_out, const QuantizationInfo &qout)
{
ARM_COMPUTE_UNUSED(qout);
- SimpleTensor<T2> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type());
+ SimpleTensor<T3> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out);
if(scale < 0)
{
@@ -151,15 +151,15 @@ SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const S
template <>
SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+ DataType dt_out, const QuantizationInfo &qout)
{
- SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
+ SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8)
{
SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
- SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
+ SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
dst = convert_to_asymmetric<uint8_t>(dst_tmp, qout);
}
else
@@ -179,15 +179,15 @@ SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src
template <>
SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+ DataType dt_out, const QuantizationInfo &qout)
{
- SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
+ SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
if(src1.data_type() == DataType::QASYMM8_SIGNED && src2.data_type() == DataType::QASYMM8_SIGNED)
{
SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
- SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
+ SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
dst = convert_to_asymmetric<int8_t>(dst_tmp, qout);
}
else
@@ -207,15 +207,15 @@ SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1,
template <>
SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+ DataType dt_out, const QuantizationInfo &qout)
{
- SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
+ SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
if(src1.data_type() == DataType::QSYMM16 && src2.data_type() == DataType::QSYMM16)
{
SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1);
SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2);
- SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
+ SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
dst = convert_to_symmetric<int16_t>(dst_tmp, qout);
}
else
@@ -234,9 +234,10 @@ SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src
}
// *INDENT-OFF*
// clang-format off
-template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
-template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
-template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
+template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
+template SimpleTensor<int32_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
+template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
+template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
// clang-format on
// *INDENT-ON*
} // namespace reference
diff --git a/tests/validation/reference/PixelWiseMultiplication.h b/tests/validation/reference/PixelWiseMultiplication.h
index f5b8e777fb..f8afa0384b 100644
--- a/tests/validation/reference/PixelWiseMultiplication.h
+++ b/tests/validation/reference/PixelWiseMultiplication.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,9 +34,10 @@ namespace validation
{
namespace reference
{
-template <typename T1, typename T2>
-SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale,
- ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout = QuantizationInfo());
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale,
+ ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out,
+ const QuantizationInfo &qout = QuantizationInfo());
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/QLSTMLayerNormalization.cpp b/tests/validation/reference/QLSTMLayerNormalization.cpp
index 90d59b93ad..0e24de6584 100644
--- a/tests/validation/reference/QLSTMLayerNormalization.cpp
+++ b/tests/validation/reference/QLSTMLayerNormalization.cpp
@@ -41,7 +41,7 @@ namespace reference
SimpleTensor<float> qlstm_layer_normalization_float_compute(SimpleTensor<float> src, SimpleTensor<float> weight, SimpleTensor<float> bias)
{
SimpleTensor<float> output = mean_std_normalization_layer(src);
- output = pixel_wise_multiplication(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
+ output = pixel_wise_multiplication<float, float, float>(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, DataType::F32);
return arithmetic_operation(ArithmeticOperation::ADD, output, bias, DataType::F32, ConvertPolicy::SATURATE);
}