author    Manuel Bottini <manuel.bottini@arm.com> 2019-10-17 18:37:26 +0100
committer Pablo Marquez <pablo.tello@arm.com> 2019-10-24 09:11:40 +0000
commit    07263980e66059a91ce57612e4ca8f4b2a2a206a (patch)
tree      138dc3ecf835df9f38a60959379a52eca08f8b0f
parent    05069f07bcf95676597698a79926327555276362 (diff)
download  ComputeLibrary-07263980e66059a91ce57612e4ca8f4b2a2a206a.tar.gz
COMPMID-2501: Support multiplier > 1 during QASYMM8 requantization for Quantized LSTM
Change-Id: I7eddbdf77881f313b707b9e59428245f1330a2cf
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2119
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
-rw-r--r--  arm_compute/core/NEON/NESymm.h                        41
-rw-r--r--  arm_compute/core/utils/quantization/AsymmHelpers.h    13
-rw-r--r--  src/core/CL/cl_kernels/gemmlowp.cl                     4
-rw-r--r--  src/core/utils/quantization/AsymmHelpers.cpp          14
-rw-r--r--  src/runtime/CL/functions/CLLSTMLayerQuantized.cpp     12
-rw-r--r--  src/runtime/NEON/functions/NELSTMLayerQuantized.cpp   12
-rw-r--r--  tests/validation/CL/GEMMLowp.cpp                      30
-rw-r--r--  tests/validation/CL/LSTMLayerQuantized.cpp           152
-rw-r--r--  tests/validation/NEON/GEMMLowp.cpp                    34
-rw-r--r--  tests/validation/NEON/LSTMLayerQuantized.cpp         151
-rw-r--r--  tests/validation/reference/GEMMLowp.cpp                9
11 files changed, 432 insertions(+), 40 deletions(-)
diff --git a/arm_compute/core/NEON/NESymm.h b/arm_compute/core/NEON/NESymm.h
index a60d5d0fde..8345e0be91 100644
--- a/arm_compute/core/NEON/NESymm.h
+++ b/arm_compute/core/NEON/NESymm.h
@@ -54,13 +54,23 @@ int16x8_t finalize_quantization_int16(int32x4x2_t &in_s32,
int16x8_t min_s16,
int16x8_t max_s16)
{
- // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
- in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
- in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
- // Round to the nearest division by a power-of-two using result_shift_s32
- in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
- in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
+ if(result_shift < 0)
+ {
+     in_s32.val[0] = vmulq_n_s32(in_s32.val[0], (1 << -result_shift));
+     in_s32.val[1] = vmulq_n_s32(in_s32.val[1], (1 << -result_shift));
+     in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+     in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+ }
+ else
+ {
+     // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+     in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+     in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+     // Round to the nearest division by a power-of-two using result_shift_s32
+     in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
+     in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
+ }
// Convert S32 to S16
int16x8_t out_s16 = vcombine_s16(vqmovn_s32(in_s32.val[0]), vqmovn_s32(in_s32.val[1]));
@@ -90,13 +100,18 @@ template <bool is_bounded_relu>
inline int16_t finalize_quantization_int16(int32_t in_value, int result_fixedpoint_multiplier,
int32_t result_shift, int16_t min_s16, int16_t max_s16)
{
- int32x4_t in_s32 = vdupq_n_s32(in_value);
-
- // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
- in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
-
- // Shift value by result_shift_s32
- in_value = rounding_divide_by_pow2(in_value, result_shift);
+ if(result_shift < 0)
+ {
+ const int64_t in_64 = static_cast<int64_t>(in_value) * (1 << (-result_shift)) * static_cast<int64_t>(result_fixedpoint_multiplier);
+ in_value = static_cast<int32_t>((in_64 + (1 << 30)) >> 31);
+ }
+ else
+ {
+ // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+ const int64_t in_64 = static_cast<int64_t>(in_value) * static_cast<int64_t>(result_fixedpoint_multiplier);
+ // Shift value by result_shift_s32
+ in_value = rounding_divide_by_pow2(static_cast<int32_t>((in_64 + (1 << 30)) >> 31), result_shift);
+ }
// Bound the result
int16_t out_s16 = static_cast<int16_t>(std::max<int32_t>(-32768, std::min<int32_t>(32767, in_value)));
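
For reference, a minimal scalar sketch of the requantization logic above (illustrative names, assuming a Q31 fixed-point multiplier; not the library API). When the shift is negative the value is left-shifted before the saturating rounding doubling multiply-high; otherwise the multiply happens first and a rounding power-of-two division follows:

```cpp
#include <algorithm>
#include <cstdint>

// Round-to-nearest division by 2^exponent (ties away from zero), mirroring
// the behaviour of the library's rounding_divide_by_pow2 helper.
static int32_t rounding_divide_by_pow2_sketch(int32_t x, int exponent)
{
    const int32_t mask      = (1 << exponent) - 1;
    const int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
    return (x >> exponent) + (((x & mask) > threshold) ? 1 : 0);
}

// Scalar sketch of the finalize_quantization_int16 logic above. 'mul' is a
// Q31 multiplier; a negative 'shift' encodes a power-of-two *left* shift,
// i.e. a real multiplier greater than 1.
static int16_t requantize_s16_sketch(int32_t v, int32_t mul, int shift)
{
    if(shift < 0)
    {
        // Left-shift first, then the rounding doubling multiply-high by 'mul'
        const int64_t prod = static_cast<int64_t>(v) * (1 << -shift) * static_cast<int64_t>(mul);
        v = static_cast<int32_t>((prod + (1LL << 30)) >> 31);
    }
    else
    {
        // Multiply-high first, then round-to-nearest division by 2^shift
        const int64_t prod = static_cast<int64_t>(v) * static_cast<int64_t>(mul);
        v = rounding_divide_by_pow2_sketch(static_cast<int32_t>((prod + (1LL << 30)) >> 31), shift);
    }
    // Saturate to S16, as the code above does
    return static_cast<int16_t>(std::max<int32_t>(-32768, std::min<int32_t>(32767, v)));
}
```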
diff --git a/arm_compute/core/utils/quantization/AsymmHelpers.h b/arm_compute/core/utils/quantization/AsymmHelpers.h
index a0efe120f2..bc5b9dbdba 100644
--- a/arm_compute/core/utils/quantization/AsymmHelpers.h
+++ b/arm_compute/core/utils/quantization/AsymmHelpers.h
@@ -31,6 +31,15 @@ namespace arm_compute
{
namespace quantization
{
+/** Calculate quantized representation of multiplier.
+ *
+ * @param[in]  multiplier       Real multiplier.
+ * @param[out] quant_multiplier Integer multiplier.
+ * @param[out] shift            Bit shift. A negative value indicates a left shift, while a positive value indicates a right shift.
+ *
+ * @return a status
+ */
+Status calculate_quantized_multiplier(float multiplier, int *quant_multiplier, int *shift);
/** Calculate quantized representation of multiplier with value less than one.
*
* @param[in] multiplier Real multiplier.
@@ -39,7 +48,7 @@ namespace quantization
*
* @return a status
*/
-arm_compute::Status calculate_quantized_multiplier_less_than_one(float multiplier, int *quant_multiplier, int *right_shift);
+Status calculate_quantized_multiplier_less_than_one(float multiplier, int *quant_multiplier, int *right_shift);
/** Calculate quantized representation of multiplier having value greater than one.
*
* @param[in] multiplier Real multiplier.
@@ -48,7 +57,7 @@ arm_compute::Status calculate_quantized_multiplier_less_than_one(float multiplie
*
* @return a status
*/
-arm_compute::Status calculate_quantized_multiplier_greater_than_one(float multiplier, int *quantized_multiplier, int *left_shift);
+Status calculate_quantized_multiplier_greater_than_one(float multiplier, int *quantized_multiplier, int *left_shift);
/** Get minimum and maximum values for the input quantized data type
*
* @return min and max values for the quantized data type
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index fc90dbd16c..214c7a4825 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -1888,7 +1888,11 @@ __kernel void gemmlowp_output_stage_quantize_down_fixedpoint_qsymm16(TENSOR3D_DE
#endif // defined(ADD_BIAS)
// Multiply by result_mult_int and shift
+#if RESULT_SHIFT < 0
+ input_values = ASYMM_MULT(input_values * (1 << (-RESULT_SHIFT)), RESULT_FIXEDPOINT_MULTIPLIER, 4);
+#else // RESULT_SHIFT >= 0
input_values = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(input_values, RESULT_FIXEDPOINT_MULTIPLIER, RESULT_SHIFT, 4);
+#endif // RESULT_SHIFT < 0
short4 res = convert_short4_sat(input_values);
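
Both branches in this kernel go through ASYMM_MULT. As a point of reference, a scalar model of what that multiply computes per lane (a gemmlowp-style Q31 saturating rounding doubling multiply-high; a sketch of my reading, not the kernel source):

```cpp
#include <cstdint>
#include <limits>

// Scalar model of a Q31 saturating rounding doubling multiply-high:
// returns the high 32 bits of 2*a*b, rounded, saturating on overflow.
static int32_t saturating_rounding_doubling_high_mul(int32_t a, int32_t b)
{
    // The only case where 2*a*b overflows: both inputs are INT32_MIN
    if(a == std::numeric_limits<int32_t>::min() && b == std::numeric_limits<int32_t>::min())
    {
        return std::numeric_limits<int32_t>::max();
    }
    const int64_t ab    = static_cast<int64_t>(a) * static_cast<int64_t>(b);
    const int64_t nudge = (ab >= 0) ? (1LL << 30) : (1 - (1LL << 30));
    return static_cast<int32_t>((ab + nudge) >> 31);
}
```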
diff --git a/src/core/utils/quantization/AsymmHelpers.cpp b/src/core/utils/quantization/AsymmHelpers.cpp
index 59052449af..42bd84db47 100644
--- a/src/core/utils/quantization/AsymmHelpers.cpp
+++ b/src/core/utils/quantization/AsymmHelpers.cpp
@@ -34,6 +34,20 @@ namespace quantization
constexpr int64_t fixed_point_one_Q0 = (1LL << 31);
constexpr float epsilon = 0.00001f;
+Status calculate_quantized_multiplier(float multiplier, int *quant_multiplier, int *shift)
+{
+ if(multiplier > 1.f)
+ {
+ Status status = calculate_quantized_multiplier_greater_than_one(multiplier, quant_multiplier, shift);
+ *shift *= -1;
+ return status;
+ }
+ else
+ {
+ return calculate_quantized_multiplier_less_than_one(multiplier, quant_multiplier, shift);
+ }
+}
+
Status calculate_quantized_multiplier_less_than_one(float multiplier,
int *quant_multiplier,
int *right_shift)
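
To make the sign convention concrete: the new helper decomposes a real multiplier as multiplier ≈ quant_multiplier * 2^(-shift), with quant_multiplier in Q31 format, so negative shifts encode multipliers greater than 1. A hedged sketch of that decomposition (not the library code):

```cpp
#include <cmath>
#include <cstdint>

// Sketch: split 'multiplier' into a Q31 mantissa and a power-of-two shift,
// so that multiplier ~= quant_multiplier * 2^(-shift).
static void decompose_multiplier_sketch(float multiplier, int32_t *quant_multiplier, int *shift)
{
    int exponent = 0;
    const double mantissa = std::frexp(multiplier, &exponent); // mantissa in [0.5, 1)
    int64_t q = std::llround(mantissa * (1LL << 31));
    if(q == (1LL << 31)) // mantissa rounded up to 1.0: renormalise
    {
        q /= 2;
        ++exponent;
    }
    *quant_multiplier = static_cast<int32_t>(q);
    *shift            = -exponent; // negative => left shift (multiplier > 1)
}
```

For example, multiplier = 2.0 yields quant_multiplier = 1073741824 (0.5 in Q31) and shift = -2, i.e. a left shift by two.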
diff --git a/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp b/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp
index 4e6df1d1cb..e5f127825b 100644
--- a/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayerQuantized.cpp
@@ -159,8 +159,7 @@ void CLLSTMLayerQuantized::configure(const ICLTensor *input,
const float multiplier = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
int output_multiplier = 0;
int output_shift = 0;
-
- quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+ quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
_memory_group.manage(&_output_lowp);
_output_stage.configure(&_output_highp, &_bias, &_output_lowp, output_multiplier, output_shift);
@@ -361,12 +360,13 @@ Status CLLSTMLayerQuantized::validate(const ITensorInfo *input,
input_concatenated.set_quantization_info(QuantizationInfo(qasymm.uniform().scale, qasymm.uniform().offset));
weights_transposed.set_quantization_info(QuantizationInfo(qweights.uniform().scale, qweights.uniform().offset));
- // multiplier = (input_scale * weights_scale) / output_scale (2 ^ (-12))
const TensorInfo output_lowp(output_highp.tensor_shape(), 1, DataType::QSYMM16, qsymm_3);
- const float multiplier = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
- ARM_COMPUTE_UNUSED(multiplier);
- ARM_COMPUTE_RETURN_ERROR_ON(multiplier > 1.0f);
+ const float multiplier = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
+ int output_multiplier = 0;
+ int output_shift = 0;
+ ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
+
// _output_stage
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(&output_highp, &bias_concatenated, &output_lowp));
diff --git a/src/runtime/NEON/functions/NELSTMLayerQuantized.cpp b/src/runtime/NEON/functions/NELSTMLayerQuantized.cpp
index e325619ae4..cfd996b538 100644
--- a/src/runtime/NEON/functions/NELSTMLayerQuantized.cpp
+++ b/src/runtime/NEON/functions/NELSTMLayerQuantized.cpp
@@ -138,8 +138,7 @@ void NELSTMLayerQuantized::configure(const ITensor *input,
const float multiplier = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
int output_multiplier = 0;
int output_shift = 0;
-
- quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+ quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
_memory_group.manage(&_output_lowp);
_output_stage.configure(&_output_highp, &_bias, &_output_lowp, output_multiplier, output_shift);
@@ -340,12 +339,13 @@ Status NELSTMLayerQuantized::validate(const ITensorInfo *input,
input_concatenated.set_quantization_info(QuantizationInfo(qasymm.uniform().scale, qasymm.uniform().offset));
weights_transposed.set_quantization_info(QuantizationInfo(qweights.uniform().scale, qweights.uniform().offset));
- // multiplier = (input_scale * weights_scale) / output_scale (2 ^ (-12))
const TensorInfo output_lowp(output_highp.tensor_shape(), 1, DataType::QSYMM16, qsymm_3);
- const float multiplier = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
- ARM_COMPUTE_UNUSED(multiplier);
- ARM_COMPUTE_RETURN_ERROR_ON(multiplier > 1.0f);
+ const float multiplier = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
+ int output_multiplier = 0;
+ int output_shift = 0;
+ ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
+
// _output_stage
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint::validate(&output_highp, &bias_concatenated, &output_lowp));
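
To see why the old `multiplier > 1.0f` guard was too strict, a worked example using the quantization from the MultGreater1 integration tests below (a sketch; scales taken from those tests):

```cpp
#include <cassert>

int main()
{
    // Input scale 1/128 (qasymm), weights scale 1/16 (qweights); the GEMM
    // output stage scale is 2^-12 (qsymm_3), contributing the 4096 factor.
    const float multiplier = 4096.f * (1.f / 128.f) * (1.f / 16.f);
    assert(multiplier == 2.0f); // > 1: rejected before this patch, accepted now
    return 0;
}
```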
diff --git a/tests/validation/CL/GEMMLowp.cpp b/tests/validation/CL/GEMMLowp.cpp
index b8dfc030a2..f5bd871f90 100644
--- a/tests/validation/CL/GEMMLowp.cpp
+++ b/tests/validation/CL/GEMMLowp.cpp
@@ -305,6 +305,14 @@ const auto quantize_down_int32_to_int16_scale_by_fixedpoint_relu_cases = framewo
2)
* framework::dataset::make("min", -2, 0) * framework::dataset::make("max", 1, 3) * framework::dataset::make("addBias", { false, true });
+const auto quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_cases = framework::dataset::make("result_fixedpoint_multiplier", 1073741823, 1073741825) * framework::dataset::make("result_shift", -3,
+ -2)
+ * framework::dataset::make("min", 0) * framework::dataset::make("max", 0) * framework::dataset::make("addBias", { false, true });
+
+const auto quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_relu_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", -3,
+ -1)
+ * framework::dataset::make("min", -2, 0) * framework::dataset::make("max", 1, 3) * framework::dataset::make("addBias", { false, true });
+
using CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture =
GEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointValidationFixture<CLTensor, CLAccessor, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint>;
@@ -344,19 +352,41 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
}
// clang-format on
// *INDENT-ON*
+TEST_SUITE(NoRelu)
+TEST_SUITE(MultSmallerEq1)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
quantize_down_int32_to_int16_scale_by_fixedpoint_cases))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
+TEST_SUITE_END() // MultSmallerEq1
+TEST_SUITE(MultGreater1)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
+ quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_cases))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // MultGreater1
+TEST_SUITE_END() // NoRelu
TEST_SUITE(BoundedReLu)
+TEST_SUITE(MultSmallerEq1)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
quantize_down_int32_to_int16_scale_by_fixedpoint_relu_cases))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
+TEST_SUITE_END() // MultSmallerEq1
+TEST_SUITE(MultGreater1)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
+ quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_relu_cases))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END() // MultGreater1
TEST_SUITE_END() // BoundedReLu
TEST_SUITE_END() // QuantizeDownInt32ToInt16ScaleByFixedPoint
TEST_SUITE_END() // OutputStage
diff --git a/tests/validation/CL/LSTMLayerQuantized.cpp b/tests/validation/CL/LSTMLayerQuantized.cpp
index 1fc0af1ecb..686d6bcef8 100644
--- a/tests/validation/CL/LSTMLayerQuantized.cpp
+++ b/tests/validation/CL/LSTMLayerQuantized.cpp
@@ -72,13 +72,14 @@ TEST_SUITE(LSTMLayerQuantized)
// *INDENT-OFF*
// clang-format off
-TEST_CASE(IntegrationTestCaseSmall, framework::DatasetMode::PRECOMMIT)
+TEST_SUITE(IntegrationTestCase)
+TEST_SUITE(MultSmallerEq1)
+TEST_CASE(RunSmall, framework::DatasetMode::PRECOMMIT)
{
const int batch_size = 2;
const int input_size = 2;
const int output_size = 4;
-
QuantizationInfo qasymm(1.f / 128.f, 128);
QuantizationInfo qweights(1.f / 128.f, 128);
QuantizationInfo qsymm_3(8.f / 32768.f, 0);
@@ -211,7 +212,7 @@ TEST_CASE(IntegrationTestCaseSmall, framework::DatasetMode::PRECOMMIT)
validate(CLAccessor(output_state), expected_output);
}
-TEST_CASE(IntegrationTestCaseLarge, framework::DatasetMode::PRECOMMIT)
+TEST_CASE(RunLarge, framework::DatasetMode::PRECOMMIT)
{
const int batch_size = 16;
const int input_size = 8;
@@ -448,11 +449,154 @@ TEST_CASE(IntegrationTestCaseLarge, framework::DatasetMode::PRECOMMIT)
lstmq.run();
validate(CLAccessor(output_state), expected_output);
}
+TEST_SUITE_END() // MultSmallerEq1
+
+TEST_SUITE(MultGreater1)
+TEST_CASE(RunSmall, framework::DatasetMode::PRECOMMIT)
+{
+ // Input sequence length is 1
+ const int batch_size = 2;
+ const int input_size = 2;
+ const int output_size = 4;
+
+ QuantizationInfo qasymm(1.f / 128.f, 128);
+ QuantizationInfo qweights(1.f / 16.f, 16);
+ QuantizationInfo qsymm_3(8.f / 32768.f, 0);
+ QuantizationInfo qsymm_4(16.f / 32768.f, 0);
+
+ TensorShape input_shape{ input_size, batch_size };
+ TensorShape input_weights_shape{ input_size, output_size };
+ TensorShape recurrent_weights_shape{ output_size, output_size };
+ TensorShape output_shape{ output_size, batch_size};
+ TensorShape bias_shape{ output_size };
+
+ auto input_to_input_weights = create_tensor<CLTensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto input_to_forget_weights = create_tensor<CLTensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto input_to_cell_weights = create_tensor<CLTensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto input_to_output_weights = create_tensor<CLTensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto recurrent_to_input_weights = create_tensor<CLTensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto recurrent_to_forget_weights = create_tensor<CLTensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto recurrent_to_cell_weights = create_tensor<CLTensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto recurrent_to_output_weights = create_tensor<CLTensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto input_gate_bias = create_tensor<CLTensor>(bias_shape, DataType::S32);
+ auto forget_gate_bias = create_tensor<CLTensor>(bias_shape, DataType::S32);
+ auto cell_gate_bias = create_tensor<CLTensor>(bias_shape, DataType::S32);
+ auto output_gate_bias = create_tensor<CLTensor>(bias_shape, DataType::S32);
+
+ // LSTM input
+ auto input = create_tensor<CLTensor>(input_shape, DataType::QASYMM8, 1, qasymm);
+
+ // LSTM output state
+ auto output_state = create_tensor<CLTensor>(output_shape, DataType::QASYMM8, 1, qasymm);
+
+ // LSTM cell state
+ auto cell_state = create_tensor<CLTensor>(output_shape, DataType::QSYMM16, 1, qsymm_4);
+
+ CLLSTMLayerQuantized lstmq;
+
+ lstmq.configure(&input, &input_to_input_weights, &input_to_forget_weights, &input_to_cell_weights, &input_to_output_weights,
+ &recurrent_to_input_weights, &recurrent_to_forget_weights, &recurrent_to_cell_weights, &recurrent_to_output_weights,
+ &input_gate_bias, &forget_gate_bias, &cell_gate_bias, &output_gate_bias, &cell_state, &output_state, &cell_state, &output_state);
+
+ input.allocator()->allocate();
+ input_to_input_weights.allocator()->allocate();
+ input_to_forget_weights.allocator()->allocate();
+ input_to_cell_weights.allocator()->allocate();
+ input_to_output_weights.allocator()->allocate();
+ recurrent_to_input_weights.allocator()->allocate();
+ recurrent_to_forget_weights.allocator()->allocate();
+ recurrent_to_cell_weights.allocator()->allocate();
+ recurrent_to_output_weights.allocator()->allocate();
+ input_gate_bias.allocator()->allocate();
+ forget_gate_bias.allocator()->allocate();
+ cell_gate_bias.allocator()->allocate();
+ output_gate_bias.allocator()->allocate();
+ cell_state.allocator()->allocate();
+ output_state.allocator()->allocate();
+
+ // Fill weights and biases
+ fill_tensor(input_to_input_weights, std::vector<uint8_t>{ 122, 130,
+ 124, 134,
+ 120, 122,
+ 134, 134 });
+
+ fill_tensor(input_to_forget_weights, std::vector<uint8_t> { 204, 193,
+ 148, 59,
+ 113, 17,
+ 66, 197 });
+
+ fill_tensor(input_to_cell_weights, std::vector<uint8_t> { 172, 101,
+ 184, 209,
+ 165, 82,
+ 108, 209 });
+
+ fill_tensor(input_to_output_weights, std::vector<uint8_t> { 203, 244,
+ 219, 114,
+ 130, 16,
+ 163, 222 });
+
+ fill_tensor(recurrent_to_input_weights, std::vector<uint8_t> { 162, 168, 7, 95,
+ 91, 155, 108, 216,
+ 255, 100, 48, 188,
+ 58, 37, 186, 147 });
+
+ fill_tensor(recurrent_to_forget_weights, std::vector<uint8_t> { 46, 58, 47, 170,
+ 246, 96, 12, 99,
+ 68, 23, 186, 161,
+ 237, 164, 89, 6 });
+
+ fill_tensor(recurrent_to_cell_weights, std::vector<uint8_t> { 234, 99, 71, 206,
+ 205, 159, 64, 253,
+ 191, 148, 116, 8,
+ 209, 136, 59, 138 });
+
+ fill_tensor(recurrent_to_output_weights, std::vector<uint8_t> { 23, 241, 137, 36,
+ 206, 5, 227, 56,
+ 254, 176, 231, 47,
+ 18, 201, 161, 11 });
+
+ fill_tensor(input_gate_bias, std::vector<int> {-103038, 30525, 115255, -38154 });
+ fill_tensor(forget_gate_bias, std::vector<int> { -23428, 126970, 116806, 46307 });
+ fill_tensor(cell_gate_bias, std::vector<int> { 128006, 69949, -42808, 42568 });
+ fill_tensor(output_gate_bias, std::vector<int> { -67066, -53607, 47233, 7300 });
+
+ SimpleTensor<uint8_t> expected_output(output_shape, DataType::QASYMM8, 1, qasymm);
+
+ // Initialize state
+ fill_tensor(output_state, std::vector<uint8_t> { 128, 128, 128, 128,
+ 128, 128, 128, 128 });
+ fill_tensor(cell_state, std::vector<int16_t> { 0, 0, 0, 0,
+ 0, 0, 0, 0 });
+
+ // First input
+ fill_tensor(input, std::vector<uint8_t> { 106, 193,
+ 155, 150 });
+
+ fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 31, 128,
+ 128, 128, 31, 128 });
+
+ lstmq.run();
+ validate(CLAccessor(output_state), expected_output);
+
+ // Second input
+ fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 5, 128,
+ 128, 128, 5, 128 });
+ lstmq.run();
+ validate(CLAccessor(output_state), expected_output);
+
+ // Third input
+ fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 1, 128,
+ 128, 128, 1, 128, });
+ lstmq.run();
+ validate(CLAccessor(output_state), expected_output);
+}
+TEST_SUITE_END() // MultGreater1
+TEST_SUITE_END() // IntegrationTestCase
// clang-format on
// *INDENT-ON*
TEST_SUITE_END() // LSTMLayerQuantized
-TEST_SUITE_END() // NEON
+TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp
index 2f604c95ea..d79374efa7 100644
--- a/tests/validation/NEON/GEMMLowp.cpp
+++ b/tests/validation/NEON/GEMMLowp.cpp
@@ -417,6 +417,13 @@ const auto quantize_down_int32_to_int16_scale_by_fixedpoint_cases = framework::d
const auto quantize_down_int32_to_int16_scale_by_fixedpoint_relu_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", 1,
2)
* framework::dataset::make("min", -2, 0) * framework::dataset::make("max", 1, 3) * framework::dataset::make("addBias", { false, true });
+const auto quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_cases = framework::dataset::make("result_fixedpoint_multiplier", 1073741823, 1073741825) * framework::dataset::make("result_shift", -3,
+ -2)
+ * framework::dataset::make("min", 0) * framework::dataset::make("max", 0) * framework::dataset::make("addBias", { false, true });
+
+const auto quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_relu_cases = framework::dataset::make("result_fixedpoint_multiplier", 254601600, 254601602) * framework::dataset::make("result_shift", -3,
+ -1)
+ * framework::dataset::make("min", -2, 0) * framework::dataset::make("max", 1, 3) * framework::dataset::make("addBias", { false, true });
using NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture =
GEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointValidationFixture<Tensor, Accessor, NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPoint>;
@@ -499,27 +506,44 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Sma
validate(bias.info()->padding(), padding);
}
}
-
+TEST_SUITE(NoRelu)
+TEST_SUITE(MultSmallerEq1)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
quantize_down_int32_to_int16_scale_by_fixedpoint_cases))
{
// Validate output
validate(Accessor(_target), _reference);
}
-
+TEST_SUITE_END() // MultSmallerEq1
+TEST_SUITE(MultGreater1)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
+ quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_cases))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // MultGreater1
+TEST_SUITE_END() // NoRelu
TEST_SUITE(BoundedReLu)
+TEST_SUITE(MultSmallerEq1)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
quantize_down_int32_to_int16_scale_by_fixedpoint_relu_cases))
{
// Validate output
validate(Accessor(_target), _reference);
}
+TEST_SUITE_END() // MultSmallerEq1
+TEST_SUITE(MultGreater1)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointFixture, framework::DatasetMode::ALL, combine(datasets::SmallShapes(),
+ quantize_down_int32_to_int16_scale_by_fixedpoint_multgreat1_relu_cases))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // MultGreater1
TEST_SUITE_END() // BoundedReLu
-
TEST_SUITE_END() // QuantizeDownInt32ToInt16ScaleByFixedPoint
-
TEST_SUITE_END() // OutputStage
-
TEST_SUITE_END() // GEMMLowp
TEST_SUITE_END() // NEON
} // namespace validation
diff --git a/tests/validation/NEON/LSTMLayerQuantized.cpp b/tests/validation/NEON/LSTMLayerQuantized.cpp
index 0935165564..b57a8f7d26 100644
--- a/tests/validation/NEON/LSTMLayerQuantized.cpp
+++ b/tests/validation/NEON/LSTMLayerQuantized.cpp
@@ -77,7 +77,9 @@ TEST_SUITE(LSTMLayerQuantized)
// *INDENT-OFF*
// clang-format off
-TEST_CASE(IntegrationTestCaseSmall, framework::DatasetMode::PRECOMMIT)
+TEST_SUITE(IntegrationTestCase)
+TEST_SUITE(MultSmallerEq1)
+TEST_CASE(RunSmall, framework::DatasetMode::PRECOMMIT)
{
const int batch_size = 2;
const int input_size = 2;
@@ -216,7 +218,7 @@ TEST_CASE(IntegrationTestCaseSmall, framework::DatasetMode::PRECOMMIT)
validate(Accessor(output_state), expected_output, tolerance_qsymm16);
}
-TEST_CASE(IntegrationTestCaseLarge, framework::DatasetMode::PRECOMMIT)
+TEST_CASE(RunLarge, framework::DatasetMode::PRECOMMIT)
{
const int batch_size = 16;
const int input_size = 8;
@@ -453,11 +455,154 @@ TEST_CASE(IntegrationTestCaseLarge, framework::DatasetMode::PRECOMMIT)
lstmq.run();
validate(Accessor(output_state), expected_output, tolerance_qsymm16);
}
+TEST_SUITE_END() // MultSmallerEq1
+
+TEST_SUITE(MultGreater1)
+TEST_CASE(RunSmall, framework::DatasetMode::PRECOMMIT)
+{
+ // Input sequence length is 1
+ const int batch_size = 2;
+ const int input_size = 2;
+ const int output_size = 4;
+
+ QuantizationInfo qasymm(1.f / 128.f, 128);
+ QuantizationInfo qweights(1.f / 16.f, 16);
+ QuantizationInfo qsymm_3(8.f / 32768.f, 0);
+ QuantizationInfo qsymm_4(16.f / 32768.f, 0);
+
+ TensorShape input_shape{ input_size, batch_size };
+ TensorShape input_weights_shape{ input_size, output_size };
+ TensorShape recurrent_weights_shape{ output_size, output_size };
+ TensorShape output_shape{ output_size, batch_size};
+ TensorShape bias_shape{ output_size };
+
+ auto input_to_input_weights = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto input_to_forget_weights = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto input_to_cell_weights = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto input_to_output_weights = create_tensor<Tensor>(input_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto recurrent_to_input_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto recurrent_to_forget_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto recurrent_to_cell_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto recurrent_to_output_weights = create_tensor<Tensor>(recurrent_weights_shape, DataType::QASYMM8, 1, qweights);
+ auto input_gate_bias = create_tensor<Tensor>(bias_shape, DataType::S32);
+ auto forget_gate_bias = create_tensor<Tensor>(bias_shape, DataType::S32);
+ auto cell_gate_bias = create_tensor<Tensor>(bias_shape, DataType::S32);
+ auto output_gate_bias = create_tensor<Tensor>(bias_shape, DataType::S32);
+
+ // LSTM input
+ auto input = create_tensor<Tensor>(input_shape, DataType::QASYMM8, 1, qasymm);
+
+ // LSTM output state
+ auto output_state = create_tensor<Tensor>(output_shape, DataType::QASYMM8, 1, qasymm);
+
+ // LSTM cell state
+ auto cell_state = create_tensor<Tensor>(output_shape, DataType::QSYMM16, 1, qsymm_4);
+
+ NELSTMLayerQuantized lstmq;
+
+ lstmq.configure(&input, &input_to_input_weights, &input_to_forget_weights, &input_to_cell_weights, &input_to_output_weights,
+ &recurrent_to_input_weights, &recurrent_to_forget_weights, &recurrent_to_cell_weights, &recurrent_to_output_weights,
+ &input_gate_bias, &forget_gate_bias, &cell_gate_bias, &output_gate_bias, &cell_state, &output_state, &cell_state, &output_state);
+
+ input.allocator()->allocate();
+ input_to_input_weights.allocator()->allocate();
+ input_to_forget_weights.allocator()->allocate();
+ input_to_cell_weights.allocator()->allocate();
+ input_to_output_weights.allocator()->allocate();
+ recurrent_to_input_weights.allocator()->allocate();
+ recurrent_to_forget_weights.allocator()->allocate();
+ recurrent_to_cell_weights.allocator()->allocate();
+ recurrent_to_output_weights.allocator()->allocate();
+ input_gate_bias.allocator()->allocate();
+ forget_gate_bias.allocator()->allocate();
+ cell_gate_bias.allocator()->allocate();
+ output_gate_bias.allocator()->allocate();
+ cell_state.allocator()->allocate();
+ output_state.allocator()->allocate();
+
+ // Fill weights and biases
+ fill_tensor(input_to_input_weights, std::vector<uint8_t>{ 122, 130,
+ 124, 134,
+ 120, 122,
+ 134, 134 });
+
+ fill_tensor(input_to_forget_weights, std::vector<uint8_t> { 204, 193,
+ 148, 59,
+ 113, 17,
+ 66, 197 });
+
+ fill_tensor(input_to_cell_weights, std::vector<uint8_t> { 172, 101,
+ 184, 209,
+ 165, 82,
+ 108, 209 });
+
+ fill_tensor(input_to_output_weights, std::vector<uint8_t> { 203, 244,
+ 219, 114,
+ 130, 16,
+ 163, 222 });
+
+ fill_tensor(recurrent_to_input_weights, std::vector<uint8_t> { 162, 168, 7, 95,
+ 91, 155, 108, 216,
+ 255, 100, 48, 188,
+ 58, 37, 186, 147 });
+
+ fill_tensor(recurrent_to_forget_weights, std::vector<uint8_t> { 46, 58, 47, 170,
+ 246, 96, 12, 99,
+ 68, 23, 186, 161,
+ 237, 164, 89, 6 });
+
+ fill_tensor(recurrent_to_cell_weights, std::vector<uint8_t> { 234, 99, 71, 206,
+ 205, 159, 64, 253,
+ 191, 148, 116, 8,
+ 209, 136, 59, 138 });
+
+ fill_tensor(recurrent_to_output_weights, std::vector<uint8_t> { 23, 241, 137, 36,
+ 206, 5, 227, 56,
+ 254, 176, 231, 47,
+ 18, 201, 161, 11 });
+
+ fill_tensor(input_gate_bias, std::vector<int> {-103038, 30525, 115255, -38154 });
+ fill_tensor(forget_gate_bias, std::vector<int> { -23428, 126970, 116806, 46307 });
+ fill_tensor(cell_gate_bias, std::vector<int> { 128006, 69949, -42808, 42568 });
+ fill_tensor(output_gate_bias, std::vector<int> { -67066, -53607, 47233, 7300 });
+
+ SimpleTensor<uint8_t> expected_output(output_shape, DataType::QASYMM8, 1, qasymm);
+
+ // Initialize state
+ fill_tensor(output_state, std::vector<uint8_t> { 128, 128, 128, 128,
+ 128, 128, 128, 128 });
+ fill_tensor(cell_state, std::vector<int16_t> { 0, 0, 0, 0,
+ 0, 0, 0, 0 });
+
+ // First input
+ fill_tensor(input, std::vector<uint8_t> { 106, 193,
+ 155, 150 });
+
+ fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 31, 128,
+ 128, 128, 31, 128 });
+
+ lstmq.run();
+ validate(Accessor(output_state), expected_output);
+
+ // Second input
+ fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 5, 128,
+ 128, 128, 5, 128 });
+ lstmq.run();
+ validate(Accessor(output_state), expected_output);
+
+ // Third input
+ fill_tensor(expected_output, std::vector<uint8_t> { 128, 128, 1, 128,
+ 128, 128, 1, 128, });
+ lstmq.run();
+ validate(Accessor(output_state), expected_output);
+}
+TEST_SUITE_END() // MultGreater1
+TEST_SUITE_END() // IntegrationTestCase
// clang-format on
// *INDENT-ON*
TEST_SUITE_END() // LSTMLayerQuantized
-TEST_SUITE_END() // CL
+TEST_SUITE_END() // NEON
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/reference/GEMMLowp.cpp b/tests/validation/reference/GEMMLowp.cpp
index 97d05327e7..4283cb5bac 100644
--- a/tests/validation/reference/GEMMLowp.cpp
+++ b/tests/validation/reference/GEMMLowp.cpp
@@ -112,7 +112,14 @@ void quantize_down_int32_to_int16_scale_by_fixedpoint(const SimpleTensor<T> *in,
}
// Fixed point multiplication
- result = asymm_rounding_divide_by_pow2(asymm_int_mult(result, result_fixedpoint_multiplier), result_shift);
+ if(result_shift < 0)
+ {
+ result = asymm_int_mult(result * (1 << (-result_shift)), result_fixedpoint_multiplier);
+ }
+ else
+ {
+ result = asymm_rounding_divide_by_pow2(asymm_int_mult(result, result_fixedpoint_multiplier), result_shift);
+ }
// Bounded ReLu
if(min != max)
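
A quick numerical check of the negative-shift path above, using a scalar model of asymm_int_mult (a Q31 rounding doubling multiply-high); the values are hypothetical, not taken from the test suites:

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of asymm_int_mult: high 32 bits of 2*a*b, rounded.
static int32_t q31_mult(int32_t a, int32_t b)
{
    const int64_t ab    = static_cast<int64_t>(a) * static_cast<int64_t>(b);
    const int64_t nudge = (ab >= 0) ? (1LL << 30) : (1 - (1LL << 30));
    return static_cast<int32_t>((ab + nudge) >> 31);
}

int main()
{
    const int32_t mul   = 1 << 30; // 0.5 in Q31
    const int     shift = -2;      // left shift by 2 => net scale 0.5 * 4 = 2
    assert(q31_mult(100 * (1 << -shift), mul) == 200);
    return 0;
}
```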