diff options
Diffstat (limited to 'tests/validation/fixtures/ArithmeticOperationsFixture.h')
-rw-r--r-- | tests/validation/fixtures/ArithmeticOperationsFixture.h | 96 |
1 files changed, 42 insertions, 54 deletions
diff --git a/tests/validation/fixtures/ArithmeticOperationsFixture.h b/tests/validation/fixtures/ArithmeticOperationsFixture.h index 1dfc2ce579..7aa716d676 100644 --- a/tests/validation/fixtures/ArithmeticOperationsFixture.h +++ b/tests/validation/fixtures/ArithmeticOperationsFixture.h @@ -46,15 +46,14 @@ class ArithmeticOperationGenericFixture : public framework::Fixture { public: template <typename...> - void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1, - DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, + void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info, bool in_place) { _op = op; _act_info = act_info; _in_place = in_place; - _target = compute_target(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out); - _reference = compute_reference(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out); + _target = compute_target(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out); + _reference = compute_reference(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out); } protected: @@ -64,13 +63,13 @@ protected: library->fill_tensor_uniform(tensor, i); } - TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, + TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { // Create tensors - TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type0, 1, qinfo0); - TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type1, 1, qinfo1); - TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out); + TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type, 1, qinfo0); + TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type, 1, qinfo1); + TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), data_type, 1, qinfo_out); TensorType *dst_to_use = _in_place ? &ref_src1 : &dst; // Create and configure function @@ -104,8 +103,7 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, - DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, + SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { // current in-place implementation only supports same metadata of input and output tensors. @@ -113,9 +111,9 @@ protected: QuantizationInfo output_qinfo = _in_place ? qinfo0 : qinfo_out; // Create reference - SimpleTensor<T> ref_src1{ shape0, data_type0, 1, qinfo0 }; - SimpleTensor<T> ref_src2{ shape1, data_type1, 1, qinfo1 }; - SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, output_qinfo }; + SimpleTensor<T> ref_src1{ shape0, data_type, 1, qinfo0 }; + SimpleTensor<T> ref_src2{ shape1, data_type, 1, qinfo1 }; + SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), data_type, 1, output_qinfo }; // Fill reference fill(ref_src1, 0); @@ -137,10 +135,10 @@ class ArithmeticAdditionBroadcastValidationFixture : public ArithmeticOperationG { public: template <typename...> - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy) + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1, - output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false); + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy, + QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false); } }; @@ -149,10 +147,10 @@ class ArithmeticAdditionValidationFixture : public ArithmeticOperationGenericFix { public: template <typename...> - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1, - output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false); + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy, + QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false); } }; @@ -161,10 +159,10 @@ class ArithmeticAdditionBroadcastValidationFloatFixture : public ArithmeticOpera { public: template <typename...> - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info) + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1, - output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false); + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy, + QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false); } }; @@ -173,10 +171,10 @@ class ArithmeticAdditionValidationFloatFixture : public ArithmeticOperationGener { public: template <typename...> - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1, - output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false); + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy, + QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false); } }; @@ -185,12 +183,11 @@ class ArithmeticAdditionValidationQuantizedFixture : public ArithmeticOperationG { public: template <typename...> - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, - QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1, - output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false); + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy, + qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false); } }; @@ -199,11 +196,9 @@ class ArithmeticAdditionValidationQuantizedBroadcastFixture : public ArithmeticO { public: template <typename...> - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, - ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false); } }; @@ -213,10 +208,9 @@ class ArithmeticSubtractionBroadcastValidationFixture : public ArithmeticOperati { public: template <typename...> - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place) + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, bool in_place) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place); } }; @@ -226,11 +220,10 @@ class ArithmeticSubtractionBroadcastValidationFloatFixture : public ArithmeticOp { public: template <typename...> - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool in_place) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, in_place); } }; @@ -240,10 +233,9 @@ class ArithmeticSubtractionValidationFixture : public ArithmeticOperationGeneric { public: template <typename...> - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, bool in_place) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place); } }; @@ -253,10 +245,9 @@ class ArithmeticSubtractionValidationFloatFixture : public ArithmeticOperationGe { public: template <typename...> - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool in_place) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool in_place) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, in_place); } }; @@ -266,13 +257,11 @@ class ArithmeticSubtractionValidationQuantizedFixture : public ArithmeticOperati { public: template <typename...> - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, - QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, - data_type0, data_type1, output_data_type, - convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place); + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy, + qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place); } }; @@ -281,11 +270,10 @@ class ArithmeticSubtractionValidationQuantizedBroadcastFixture : public Arithmet { public: template <typename...> - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, - ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place) + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, + bool in_place) { - ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place); } }; |