aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/validation/CL/ArithmeticSubtraction.cpp106
-rw-r--r--tests/validation/CPP/ArithmeticSubtraction.cpp13
-rw-r--r--tests/validation/CPP/ArithmeticSubtraction.h4
-rw-r--r--tests/validation/NEON/ArithmeticSubtraction.cpp102
-rw-r--r--tests/validation/fixtures/ArithmeticSubtractionFixture.h22
5 files changed, 197 insertions, 50 deletions
diff --git a/tests/validation/CL/ArithmeticSubtraction.cpp b/tests/validation/CL/ArithmeticSubtraction.cpp
index 817a31fbd9..5e7d741009 100644
--- a/tests/validation/CL/ArithmeticSubtraction.cpp
+++ b/tests/validation/CL/ArithmeticSubtraction.cpp
@@ -47,8 +47,14 @@ namespace
const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
framework::dataset::make("DataType",
DataType::U8));
-const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
+const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
+ framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionS16U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::U8)),
+ framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8S16S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::S16)),
+ framework::dataset::make("DataType", DataType::S16));
const auto ArithmeticSubtractionQS8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS8), framework::dataset::make("DataType", DataType::QS8)),
framework::dataset::make("DataType", DataType::QS8));
const auto ArithmeticSubtractionQS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS16), framework::dataset::make("DataType", DataType::QS16)),
@@ -62,8 +68,8 @@ const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset
TEST_SUITE(CL)
TEST_SUITE(ArithmeticSubtraction)
-template <typename T>
-using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T1, T2, T3>;
TEST_SUITE(U8)
DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
@@ -89,7 +95,8 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::da
validate(dst.info()->padding(), padding);
}
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+ ArithmeticSubtractionU8Dataset),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
@@ -97,14 +104,19 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framew
}
TEST_SUITE_END()
+template <typename T1, typename T2 = T1>
+using CLArithmeticSubtractionToS16Fixture = CLArithmeticSubtractionFixture<T1, T2, int16_t>;
+
TEST_SUITE(S16)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()),
+ framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+ framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
- shape, data_type, policy)
+ shape, data_type1, data_type2, policy)
{
// Create tensors
- CLTensor ref_src1 = create_tensor<CLTensor>(shape, data_type);
- CLTensor ref_src2 = create_tensor<CLTensor>(shape, DataType::S16);
+ CLTensor ref_src1 = create_tensor<CLTensor>(shape, data_type1);
+ CLTensor ref_src2 = create_tensor<CLTensor>(shape, data_type2);
CLTensor dst = create_tensor<CLTensor>(shape, DataType::S16);
// Create and Configure function
@@ -121,28 +133,86 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(frame
validate(ref_src2.info()->padding(), padding);
validate(dst.info()->padding(), padding);
}
+TEST_SUITE(S16_S16_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
+TEST_SUITE_END()
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+TEST_SUITE(U8_U8_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+ ArithmeticSubtractionU8U8S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+ ArithmeticSubtractionU8U8S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
TEST_SUITE_END()
-template <typename T>
-using CLArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
+TEST_SUITE(S16_U8_S16)
+using CLAriSubS16U8ToS16Fixture = CLArithmeticSubtractionToS16Fixture<int16_t, uint8_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, CLAriSubS16U8ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+ ArithmeticSubtractionS16U8S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLAriSubS16U8ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+ ArithmeticSubtractionS16U8S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(U8_S16_S16)
+using CLAriSubU8S16ToS16Fixture = CLArithmeticSubtractionToS16Fixture<uint8_t, int16_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, CLAriSubU8S16ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+ ArithmeticSubtractionU8S16S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLAriSubU8S16ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+ ArithmeticSubtractionU8S16S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using CLArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T1, T2, T3>;
TEST_SUITE(Quantized)
TEST_SUITE(QS8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(),
+ ArithmeticSubtractionQS8Dataset),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
framework::dataset::make("FractionalBits", 1, 7)))
{
@@ -150,7 +220,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int8_t
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+ ArithmeticSubtractionQS8Dataset),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
framework::dataset::make("FractionalBits", 1, 7)))
{
@@ -169,7 +240,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int16_
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+ ArithmeticSubtractionQS16Dataset),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
framework::dataset::make("FractionalBits", 1, 15)))
{
diff --git a/tests/validation/CPP/ArithmeticSubtraction.cpp b/tests/validation/CPP/ArithmeticSubtraction.cpp
index 80bdb15a49..bed2d37090 100644
--- a/tests/validation/CPP/ArithmeticSubtraction.cpp
+++ b/tests/validation/CPP/ArithmeticSubtraction.cpp
@@ -34,23 +34,26 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy)
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy)
{
- SimpleTensor<T> result(src1.shape(), dst_data_type);
+ SimpleTensor<T3> result(src1.shape(), dst_data_type);
- using intermediate_type = typename common_promoted_signed_type<T>::intermediate_type;
+ using intermediate_type = typename common_promoted_signed_type<typename std::conditional<sizeof(T1) >= sizeof(T2), T1, T2>::type >::intermediate_type;
for(int i = 0; i < src1.num_elements(); ++i)
{
intermediate_type val = static_cast<intermediate_type>(src1[i]) - static_cast<intermediate_type>(src2[i]);
- result[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T>(val) : static_cast<T>(val);
+ result[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(val) : static_cast<T3>(val);
}
return result;
}
template SimpleTensor<uint8_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<int8_t> arithmetic_subtraction(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<half> arithmetic_subtraction(const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
diff --git a/tests/validation/CPP/ArithmeticSubtraction.h b/tests/validation/CPP/ArithmeticSubtraction.h
index 18b0d121a0..9308314bda 100644
--- a/tests/validation/CPP/ArithmeticSubtraction.h
+++ b/tests/validation/CPP/ArithmeticSubtraction.h
@@ -35,8 +35,8 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/NEON/ArithmeticSubtraction.cpp b/tests/validation/NEON/ArithmeticSubtraction.cpp
index dcaf9d987b..fcd415b130 100644
--- a/tests/validation/NEON/ArithmeticSubtraction.cpp
+++ b/tests/validation/NEON/ArithmeticSubtraction.cpp
@@ -49,6 +49,12 @@ const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::
DataType::U8));
const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
+ framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionS16U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::U8)),
+ framework::dataset::make("DataType", DataType::S16));
+const auto ArithmeticSubtractionU8S16S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::S16)),
+ framework::dataset::make("DataType", DataType::S16));
const auto ArithmeticSubtractionQS8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS8), framework::dataset::make("DataType", DataType::QS8)),
framework::dataset::make("DataType", DataType::QS8));
const auto ArithmeticSubtractionQS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS16), framework::dataset::make("DataType", DataType::QS16)),
@@ -64,8 +70,8 @@ const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset
TEST_SUITE(NEON)
TEST_SUITE(ArithmeticSubtraction)
-template <typename T>
-using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T1, T2, T3>;
TEST_SUITE(U8)
DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
@@ -99,14 +105,19 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<uint8_t>, framew
}
TEST_SUITE_END()
+template <typename T1, typename T2 = T1>
+using NEArithmeticSubtractionToS16Fixture = NEArithmeticSubtractionFixture<T1, T2, int16_t>;
+
TEST_SUITE(S16)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()),
+ framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
+ framework::dataset::make("DataType", { DataType::U8, DataType::S16 })),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
- shape, data_type, policy)
+ shape, data_type1, data_type2, policy)
{
// Create tensors
- Tensor ref_src1 = create_tensor<Tensor>(shape, data_type);
- Tensor ref_src2 = create_tensor<Tensor>(shape, DataType::S16);
+ Tensor ref_src1 = create_tensor<Tensor>(shape, data_type1);
+ Tensor ref_src2 = create_tensor<Tensor>(shape, data_type2);
Tensor dst = create_tensor<Tensor>(shape, DataType::S16);
// Create and Configure function
@@ -124,27 +135,86 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(frame
validate(dst.info()->padding(), padding);
}
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+TEST_SUITE(S16_S16_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(U8_U8_S16)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+ ArithmeticSubtractionU8U8S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+ ArithmeticSubtractionU8U8S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END()
+
+TEST_SUITE(S16_U8_S16)
+using NEAriSubS16U8ToS16Fixture = NEArithmeticSubtractionToS16Fixture<int16_t, uint8_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, NEAriSubS16U8ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+ ArithmeticSubtractionS16U8S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEAriSubS16U8ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+ ArithmeticSubtractionS16U8S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
validate(Accessor(_target), _reference);
}
TEST_SUITE_END()
-template <typename T>
-using NEArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
+TEST_SUITE(U8_S16_S16)
+using NEAriSubU8S16ToS16Fixture = NEArithmeticSubtractionToS16Fixture<uint8_t, int16_t>;
+FIXTURE_DATA_TEST_CASE(RunSmall, NEAriSubU8S16ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+ ArithmeticSubtractionU8S16S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEAriSubU8S16ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+ ArithmeticSubtractionU8S16S16Dataset),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END()
+TEST_SUITE_END()
+
+template <typename T1, typename T2 = T1, typename T3 = T1>
+using NEArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<Tensor, Accessor, NEArithmeticSubtraction, T1, T2, T3>;
TEST_SUITE(Quantized)
TEST_SUITE(QS8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(),
+ ArithmeticSubtractionQS8Dataset),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
framework::dataset::make("FractionalBits", 1, 7)))
{
@@ -152,7 +222,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int8_t
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS8Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+ ArithmeticSubtractionQS8Dataset),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
framework::dataset::make("FractionalBits", 1, 7)))
{
@@ -171,7 +242,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int16_
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(),
+ ArithmeticSubtractionQS16Dataset),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
framework::dataset::make("FractionalBits", 1, 15)))
{
diff --git a/tests/validation/fixtures/ArithmeticSubtractionFixture.h b/tests/validation/fixtures/ArithmeticSubtractionFixture.h
index 2c683d659f..9e65faef00 100644
--- a/tests/validation/fixtures/ArithmeticSubtractionFixture.h
+++ b/tests/validation/fixtures/ArithmeticSubtractionFixture.h
@@ -40,7 +40,7 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2 = T1, typename T3 = T1>
class ArithmeticSubtractionValidationFixedPointFixture : public framework::Fixture
{
public:
@@ -93,31 +93,31 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, int fixed_point_position)
+ SimpleTensor<T3> compute_reference(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, int fixed_point_position)
{
// Create reference
- SimpleTensor<T> ref_src1{ shape, data_type0, 1, fixed_point_position };
- SimpleTensor<T> ref_src2{ shape, data_type1, 1, fixed_point_position };
+ SimpleTensor<T1> ref_src1{ shape, data_type0, 1, fixed_point_position };
+ SimpleTensor<T2> ref_src2{ shape, data_type1, 1, fixed_point_position };
// Fill reference
fill(ref_src1, 0);
fill(ref_src2, 1);
- return reference::arithmetic_subtraction<T>(ref_src1, ref_src2, output_data_type, convert_policy);
+ return reference::arithmetic_subtraction<T1, T2, T3>(ref_src1, ref_src2, output_data_type, convert_policy);
}
- TensorType _target{};
- SimpleTensor<T> _reference{};
- int _fractional_bits{};
+ TensorType _target{};
+ SimpleTensor<T3> _reference{};
+ int _fractional_bits{};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ArithmeticSubtractionValidationFixture : public ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2 = T1, typename T3 = T1>
+class ArithmeticSubtractionValidationFixture : public ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>
{
public:
template <typename...>
void setup(TensorShape shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
{
- ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type0, data_type1, output_data_type, convert_policy, 0);
+ ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>::setup(shape, data_type0, data_type1, output_data_type, convert_policy, 0);
}
};
} // namespace validation