aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ArithmeticSubtractionFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ArithmeticSubtractionFixture.h')
-rw-r--r--tests/validation/fixtures/ArithmeticSubtractionFixture.h22
1 files changed, 11 insertions, 11 deletions
diff --git a/tests/validation/fixtures/ArithmeticSubtractionFixture.h b/tests/validation/fixtures/ArithmeticSubtractionFixture.h
index 2c683d659f..9e65faef00 100644
--- a/tests/validation/fixtures/ArithmeticSubtractionFixture.h
+++ b/tests/validation/fixtures/ArithmeticSubtractionFixture.h
@@ -40,7 +40,7 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2 = T1, typename T3 = T1>
class ArithmeticSubtractionValidationFixedPointFixture : public framework::Fixture
{
public:
@@ -93,31 +93,31 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, int fixed_point_position)
+ SimpleTensor<T3> compute_reference(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, int fixed_point_position)
{
// Create reference
- SimpleTensor<T> ref_src1{ shape, data_type0, 1, fixed_point_position };
- SimpleTensor<T> ref_src2{ shape, data_type1, 1, fixed_point_position };
+ SimpleTensor<T1> ref_src1{ shape, data_type0, 1, fixed_point_position };
+ SimpleTensor<T2> ref_src2{ shape, data_type1, 1, fixed_point_position };
// Fill reference
fill(ref_src1, 0);
fill(ref_src2, 1);
- return reference::arithmetic_subtraction<T>(ref_src1, ref_src2, output_data_type, convert_policy);
+ return reference::arithmetic_subtraction<T1, T2, T3>(ref_src1, ref_src2, output_data_type, convert_policy);
}
- TensorType _target{};
- SimpleTensor<T> _reference{};
- int _fractional_bits{};
+ TensorType _target{};
+ SimpleTensor<T3> _reference{};
+ int _fractional_bits{};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ArithmeticSubtractionValidationFixture : public ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2 = T1, typename T3 = T1>
+class ArithmeticSubtractionValidationFixture : public ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>
{
public:
template <typename...>
void setup(TensorShape shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
{
- ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type0, data_type1, output_data_type, convert_policy, 0);
+ ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>::setup(shape, data_type0, data_type1, output_data_type, convert_policy, 0);
}
};
} // namespace validation