diff options
Diffstat (limited to 'tests/validation/fixtures/PixelWiseMultiplicationFixture.h')
-rw-r--r-- | tests/validation/fixtures/PixelWiseMultiplicationFixture.h | 46 |
1 files changed, 17 insertions, 29 deletions
diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h index efdf5d078e..37359f421b 100644 --- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h +++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -39,7 +39,7 @@ namespace test { namespace validation { -template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2, typename T3 = T2> class PixelWiseMultiplicationGenericValidationFixture : public framework::Fixture { public: @@ -48,6 +48,7 @@ public: const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, @@ -55,8 +56,8 @@ public: QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { - _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out); - _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out); + _target = compute_target(shape0, shape1, dt_in1, dt_in2, dt_out, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out); + _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, dt_out, scale, convert_policy, rounding_policy, qinfo0, qinfo1, qinfo_out); } protected: @@ -66,14 +67,14 @@ protected: library->fill_tensor_uniform(tensor, seed_offset); } - TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { // Create tensors TensorType src1 = create_tensor<TensorType>(shape0, dt_in1, 1, qinfo0); TensorType src2 = create_tensor<TensorType>(shape1, dt_in2, 1, qinfo1); - TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_in2, 1, qinfo_out); + TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_out, 1, qinfo_out); // Create and configure function FunctionType multiply; @@ -102,7 +103,7 @@ protected: return dst; } - SimpleTensor<T2> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + SimpleTensor<T3> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { @@ -114,24 +115,11 @@ protected: fill(src1, 0); fill(src2, 1); - return reference::pixel_wise_multiplication<T1, T2>(src1, src2, scale, convert_policy, rounding_policy, qinfo_out); + return reference::pixel_wise_multiplication<T1, T2, T3>(src1, src2, scale, convert_policy, rounding_policy, dt_out, qinfo_out); } TensorType _target{}; - SimpleTensor<T2> _reference{}; -}; - -template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> -class PixelWiseMultiplicationQuatizedValidationFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2> -{ -public: - template <typename...> - void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, - QuantizationInfo in1_qua_info, QuantizationInfo in2_qua_info, QuantizationInfo out_qua_info) - { - PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy, - in1_qua_info, in2_qua_info, out_qua_info); - } + SimpleTensor<T3> _reference{}; }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> @@ -141,7 +129,7 @@ public: template <typename...> void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { - PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy, + PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo()); } }; @@ -153,21 +141,21 @@ public: template <typename...> void setup(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { - PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy, + PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape0, shape1, dt_in1, dt_in2, dt_in2, scale, convert_policy, rounding_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo()); } }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> -class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2, typename T3 = T2> +class PixelWiseMultiplicationValidationQuantizedFixture : public PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3> { public: template <typename...> - void setup(const TensorShape &shape, DataType dt, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, + void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, DataType dt_out, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { - PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt, dt, scale, convert_policy, rounding_policy, - qinfo0, qinfo1, qinfo_out); + PixelWiseMultiplicationGenericValidationFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>::setup(shape, shape, dt_in1, dt_in2, dt_out, scale, convert_policy, + rounding_policy, qinfo0, qinfo1, qinfo_out); } }; } // namespace validation |