diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/PixelWiseMultiplication.cpp | 57 | ||||
-rw-r--r-- | tests/validation/reference/PixelWiseMultiplication.h | 9 | ||||
-rw-r--r-- | tests/validation/reference/QLSTMLayerNormalization.cpp | 2 |
3 files changed, 35 insertions, 33 deletions
diff --git a/tests/validation/reference/PixelWiseMultiplication.cpp b/tests/validation/reference/PixelWiseMultiplication.cpp index 2b4c849c39..3e21fca72a 100644 --- a/tests/validation/reference/PixelWiseMultiplication.cpp +++ b/tests/validation/reference/PixelWiseMultiplication.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -52,16 +52,16 @@ namespace * @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even. */ -template <typename T1, typename T2> -T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) +template <typename T1, typename T2, typename T3> +T3 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { - using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type; + using intermediate_type = typename common_promoted_signed_type<T1, T2, T3>::intermediate_type; const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale); - if(is_floating_point<T2>::value) + if(is_floating_point<T3>::value) { - const auto result = static_cast<T2>(val); + const auto result = static_cast<T3>(val); return result; } @@ -83,7 +83,7 @@ T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, ARM_COMPUTE_ERROR("Unsupported rounding policy"); } - const auto result = static_cast<T2>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : rounded_val); + const auto result = static_cast<T3>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(rounded_val) : rounded_val); return result; } @@ -92,8 +92,8 @@ T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, template <size_t dim> struct BroadcastUnroll { - template <typename T1, typename T2> - static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst, + template <typename T1, typename T2, typename T3> + static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) { @@ -117,23 +117,23 @@ struct BroadcastUnroll template <> struct BroadcastUnroll<0> { - template <typename T1, typename T2> - static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst, + template <typename T1, typename T2, typename T3> + static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) { - dst[coord2index(dst.shape(), id_dst)] = mul(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy); + dst[coord2index(dst.shape(), id_dst)] = mul<T1, T2, T3>(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy); } }; } // namespace -template <typename T1, typename T2> -SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, - const QuantizationInfo &qout) +template <typename T1, typename T2, typename T3> +SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, + DataType dt_out, const QuantizationInfo &qout) { ARM_COMPUTE_UNUSED(qout); - SimpleTensor<T2> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type()); + SimpleTensor<T3> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out); if(scale < 0) { @@ -151,15 +151,15 @@ SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const S template <> SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, - const QuantizationInfo &qout) + DataType dt_out, const QuantizationInfo &qout) { - SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout); + SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout); if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8) { SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1); SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2); - SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout); + SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout); dst = convert_to_asymmetric<uint8_t>(dst_tmp, qout); } else @@ -179,15 +179,15 @@ SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src template <> SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, - const QuantizationInfo &qout) + DataType dt_out, const QuantizationInfo &qout) { - SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout); + SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout); if(src1.data_type() == DataType::QASYMM8_SIGNED && src2.data_type() == DataType::QASYMM8_SIGNED) { SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1); SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2); - SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout); + SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout); dst = convert_to_asymmetric<int8_t>(dst_tmp, qout); } else @@ -207,15 +207,15 @@ SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1, template <> SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, - const QuantizationInfo &qout) + DataType dt_out, const QuantizationInfo &qout) { - SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout); + SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout); if(src1.data_type() == DataType::QSYMM16 && src2.data_type() == DataType::QSYMM16) { SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1); SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2); - SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout); + SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout); dst = convert_to_symmetric<int16_t>(dst_tmp, qout); } else @@ -234,9 +234,10 @@ SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src } // *INDENT-OFF* // clang-format off -template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout); -template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout); -template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout); +template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); +template SimpleTensor<int32_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); +template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); +template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout); // clang-format on // *INDENT-ON* } // namespace reference diff --git a/tests/validation/reference/PixelWiseMultiplication.h b/tests/validation/reference/PixelWiseMultiplication.h index f5b8e777fb..f8afa0384b 100644 --- a/tests/validation/reference/PixelWiseMultiplication.h +++ b/tests/validation/reference/PixelWiseMultiplication.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -34,9 +34,10 @@ namespace validation { namespace reference { -template <typename T1, typename T2> -SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, - ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout = QuantizationInfo()); +template <typename T1, typename T2, typename T3> +SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, + ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, + const QuantizationInfo &qout = QuantizationInfo()); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/QLSTMLayerNormalization.cpp b/tests/validation/reference/QLSTMLayerNormalization.cpp index 90d59b93ad..0e24de6584 100644 --- a/tests/validation/reference/QLSTMLayerNormalization.cpp +++ b/tests/validation/reference/QLSTMLayerNormalization.cpp @@ -41,7 +41,7 @@ namespace reference SimpleTensor<float> qlstm_layer_normalization_float_compute(SimpleTensor<float> src, SimpleTensor<float> weight, SimpleTensor<float> bias) { SimpleTensor<float> output = mean_std_normalization_layer(src); - output = pixel_wise_multiplication(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO); + output = pixel_wise_multiplication<float, float, float>(output, weight, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO, DataType::F32); return arithmetic_operation(ArithmeticOperation::ADD, output, bias, DataType::F32, ConvertPolicy::SATURATE); } |