diff options
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/ArithmeticAdditionFixture.h | 42 |
1 files changed, 25 insertions, 17 deletions
diff --git a/tests/validation/fixtures/ArithmeticAdditionFixture.h b/tests/validation/fixtures/ArithmeticAdditionFixture.h index 6d529a843c..8b14485aca 100644 --- a/tests/validation/fixtures/ArithmeticAdditionFixture.h +++ b/tests/validation/fixtures/ArithmeticAdditionFixture.h @@ -46,10 +46,10 @@ class ArithmeticAdditionGenericFixture : public framework::Fixture public: template <typename...> void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, - QuantizationInfo quantization_info) + QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { - _target = compute_target(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, quantization_info); - _reference = compute_reference(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, quantization_info); + _target = compute_target(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out); + _reference = compute_reference(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out); } protected: @@ -60,12 +60,12 @@ protected: } TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, - QuantizationInfo quantization_info) + QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { // Create tensors - TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type0, 1, quantization_info); - TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type1, 1, quantization_info); - TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, quantization_info); + TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type0, 1, qinfo0); + TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type1, 1, qinfo1); + TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out); // Create and configure function FunctionType add; @@ -94,18 +94,20 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, - QuantizationInfo quantization_info) + SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, + DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, + QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { // Create reference - SimpleTensor<T> ref_src1{ shape0, data_type0, 1, quantization_info }; - SimpleTensor<T> ref_src2{ shape1, data_type1, 1, quantization_info }; + SimpleTensor<T> ref_src1{ shape0, data_type0, 1, qinfo0 }; + SimpleTensor<T> ref_src2{ shape1, data_type1, 1, qinfo1 }; + SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out }; // Fill reference fill(ref_src1, 0); fill(ref_src2, 1); - return reference::arithmetic_addition<T>(ref_src1, ref_src2, output_data_type, convert_policy); + return reference::arithmetic_addition<T>(ref_src1, ref_src2, ref_dst, convert_policy); } TensorType _target{}; @@ -119,7 +121,8 @@ public: template <typename...> void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy) { - ArithmeticAdditionGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, QuantizationInfo()); + ArithmeticAdditionGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, data_type0, data_type1, + output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo()); } }; @@ -130,7 +133,8 @@ public: template <typename...> void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy) { - ArithmeticAdditionGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, QuantizationInfo()); + ArithmeticAdditionGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, data_type0, data_type1, + output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo()); } }; @@ -141,7 +145,8 @@ public: template <typename...> void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy) { - ArithmeticAdditionGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, shape, data_type0, data_type1, output_data_type, convert_policy, QuantizationInfo()); + ArithmeticAdditionGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, shape, data_type0, data_type1, + output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo()); } }; @@ -161,9 +166,12 @@ class ArithmeticAdditionValidationQuantizedFixture : public ArithmeticAdditionGe { public: template <typename...> - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, QuantizationInfo quantization_info) + void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, + QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) + { - ArithmeticAdditionGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, shape, data_type0, data_type1, output_data_type, convert_policy, quantization_info); + ArithmeticAdditionGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, shape, data_type0, data_type1, + output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out); } }; } // namespace validation |