diff options
author | Isabella Gottardi <isabella.gottardi@arm.com> | 2017-10-30 15:28:13 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | b5908c257d554009a00de3aaa95b3721000ed185 (patch) | |
tree | 2c9fecbb241061e10bd9886bfe36056c4e6cf211 | |
parent | 05078ec491da8f282f4597b4cf1fe79cc16f4b22 (diff) | |
download | ComputeLibrary-b5908c257d554009a00de3aaa95b3721000ed185.tar.gz |
COMPMID-653 - Arithmetic Subtraction, add support different datatype
Change-Id: I2b3d65c8d8a85ad67b9972713d06f047f5bcd1ae
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/93693
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
-rw-r--r-- | tests/validation/CL/ArithmeticSubtraction.cpp | 106 | ||||
-rw-r--r-- | tests/validation/CPP/ArithmeticSubtraction.cpp | 13 | ||||
-rw-r--r-- | tests/validation/CPP/ArithmeticSubtraction.h | 4 | ||||
-rw-r--r-- | tests/validation/NEON/ArithmeticSubtraction.cpp | 102 | ||||
-rw-r--r-- | tests/validation/fixtures/ArithmeticSubtractionFixture.h | 22 |
5 files changed, 197 insertions, 50 deletions
diff --git a/tests/validation/CL/ArithmeticSubtraction.cpp b/tests/validation/CL/ArithmeticSubtraction.cpp index 817a31fbd9..5e7d741009 100644 --- a/tests/validation/CL/ArithmeticSubtraction.cpp +++ b/tests/validation/CL/ArithmeticSubtraction.cpp @@ -47,8 +47,14 @@ namespace const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)), framework::dataset::make("DataType", DataType::U8)); -const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), +const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::S16)), framework::dataset::make("DataType", DataType::S16)); +const auto ArithmeticSubtractionU8U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)), + framework::dataset::make("DataType", DataType::S16)); +const auto ArithmeticSubtractionS16U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::U8)), + framework::dataset::make("DataType", DataType::S16)); +const auto ArithmeticSubtractionU8S16S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::S16)), + framework::dataset::make("DataType", DataType::S16)); const auto ArithmeticSubtractionQS8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS8), framework::dataset::make("DataType", DataType::QS8)), framework::dataset::make("DataType", DataType::QS8)); const auto ArithmeticSubtractionQS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS16), framework::dataset::make("DataType", DataType::QS16)), @@ -62,8 +68,8 @@ const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset TEST_SUITE(CL) TEST_SUITE(ArithmeticSubtraction) -template <typename T> -using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>; +template <typename T1, typename T2 = T1, typename T3 = T1> +using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T1, T2, T3>; TEST_SUITE(U8) DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), @@ -89,7 +95,8 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::da validate(dst.info()->padding(), padding); } -FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + ArithmeticSubtractionU8Dataset), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output @@ -97,14 +104,19 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framew } TEST_SUITE_END() +template <typename T1, typename T2 = T1> +using CLArithmeticSubtractionToS16Fixture = CLArithmeticSubtractionFixture<T1, T2, int16_t>; + TEST_SUITE(S16) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("DataType", { DataType::U8, DataType::S16 })), +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), + framework::dataset::make("DataType", { DataType::U8, DataType::S16 })), + framework::dataset::make("DataType", { DataType::U8, DataType::S16 })), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), - shape, data_type, policy) + shape, data_type1, data_type2, policy) { // Create tensors - CLTensor ref_src1 = create_tensor<CLTensor>(shape, data_type); - CLTensor ref_src2 = create_tensor<CLTensor>(shape, DataType::S16); + CLTensor ref_src1 = create_tensor<CLTensor>(shape, data_type1); + CLTensor ref_src2 = create_tensor<CLTensor>(shape, data_type2); CLTensor dst = create_tensor<CLTensor>(shape, DataType::S16); // Create and Configure function @@ -121,28 +133,86 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(frame validate(ref_src2.info()->padding(), padding); validate(dst.info()->padding(), padding); } +TEST_SUITE(S16_S16_S16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} -FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset), - framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output validate(CLAccessor(_target), _reference); } +TEST_SUITE_END() -FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset), - framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +TEST_SUITE(U8_U8_S16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + ArithmeticSubtractionU8U8S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), + ArithmeticSubtractionU8U8S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output validate(CLAccessor(_target), _reference); } TEST_SUITE_END() -template <typename T> -using CLArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>; +TEST_SUITE(S16_U8_S16) +using CLAriSubS16U8ToS16Fixture = CLArithmeticSubtractionToS16Fixture<int16_t, uint8_t>; +FIXTURE_DATA_TEST_CASE(RunSmall, CLAriSubS16U8ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + ArithmeticSubtractionS16U8S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLAriSubS16U8ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), + ArithmeticSubtractionS16U8S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} +TEST_SUITE_END() + +TEST_SUITE(U8_S16_S16) +using CLAriSubU8S16ToS16Fixture = CLArithmeticSubtractionToS16Fixture<uint8_t, int16_t>; +FIXTURE_DATA_TEST_CASE(RunSmall, CLAriSubU8S16ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + ArithmeticSubtractionU8S16S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLAriSubU8S16ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), + ArithmeticSubtractionU8S16S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} +TEST_SUITE_END() +TEST_SUITE_END() + +template <typename T1, typename T2 = T1, typename T3 = T1> +using CLArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T1, T2, T3>; TEST_SUITE(Quantized) TEST_SUITE(QS8) -FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQS8Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), + ArithmeticSubtractionQS8Dataset), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), framework::dataset::make("FractionalBits", 1, 7))) { @@ -150,7 +220,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int8_t validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS8Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), + ArithmeticSubtractionQS8Dataset), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), framework::dataset::make("FractionalBits", 1, 7))) { @@ -169,7 +240,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixedPointFixture<int16_ validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS16Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), + ArithmeticSubtractionQS16Dataset), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), framework::dataset::make("FractionalBits", 1, 15))) { diff --git a/tests/validation/CPP/ArithmeticSubtraction.cpp b/tests/validation/CPP/ArithmeticSubtraction.cpp index 80bdb15a49..bed2d37090 100644 --- a/tests/validation/CPP/ArithmeticSubtraction.cpp +++ b/tests/validation/CPP/ArithmeticSubtraction.cpp @@ -34,23 +34,26 @@ namespace validation { namespace reference { -template <typename T> -SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy) +template <typename T1, typename T2, typename T3> +SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy) { - SimpleTensor<T> result(src1.shape(), dst_data_type); + SimpleTensor<T3> result(src1.shape(), dst_data_type); - using intermediate_type = typename common_promoted_signed_type<T>::intermediate_type; + using intermediate_type = typename common_promoted_signed_type<typename std::conditional<sizeof(T1) >= sizeof(T2), T1, T2>::type >::intermediate_type; for(int i = 0; i < src1.num_elements(); ++i) { intermediate_type val = static_cast<intermediate_type>(src1[i]) - static_cast<intermediate_type>(src2[i]); - result[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T>(val) : static_cast<T>(val); + result[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(val) : static_cast<T3>(val); } return result; } template SimpleTensor<uint8_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<int8_t> arithmetic_subtraction(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<half> arithmetic_subtraction(const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, ConvertPolicy convert_policy); diff --git a/tests/validation/CPP/ArithmeticSubtraction.h b/tests/validation/CPP/ArithmeticSubtraction.h index 18b0d121a0..9308314bda 100644 --- a/tests/validation/CPP/ArithmeticSubtraction.h +++ b/tests/validation/CPP/ArithmeticSubtraction.h @@ -35,8 +35,8 @@ namespace validation { namespace reference { -template <typename T> -SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template <typename T1, typename T2, typename T3> +SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/NEON/ArithmeticSubtraction.cpp b/tests/validation/NEON/ArithmeticSubtraction.cpp index dcaf9d987b..fcd415b130 100644 --- a/tests/validation/NEON/ArithmeticSubtraction.cpp +++ b/tests/validation/NEON/ArithmeticSubtraction.cpp @@ -49,6 +49,12 @@ const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset:: DataType::U8)); const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), framework::dataset::make("DataType", DataType::S16)); +const auto ArithmeticSubtractionU8U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)), + framework::dataset::make("DataType", DataType::S16)); +const auto ArithmeticSubtractionS16U8S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::S16), framework::dataset::make("DataType", DataType::U8)), + framework::dataset::make("DataType", DataType::S16)); +const auto ArithmeticSubtractionU8S16S16Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::S16)), + framework::dataset::make("DataType", DataType::S16)); const auto ArithmeticSubtractionQS8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS8), framework::dataset::make("DataType", DataType::QS8)), framework::dataset::make("DataType", DataType::QS8)); const auto ArithmeticSubtractionQS16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QS16), framework::dataset::make("DataType", DataType::QS16)), @@ -64,8 +70,8 @@ const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset TEST_SUITE(NEON) TEST_SUITE(ArithmeticSubtraction) -template <typename T> -using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T>; +template <typename T1, typename T2 = T1, typename T3 = T1> +using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T1, T2, T3>; TEST_SUITE(U8) DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), @@ -99,14 +105,19 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<uint8_t>, framew } TEST_SUITE_END() +template <typename T1, typename T2 = T1> +using NEArithmeticSubtractionToS16Fixture = NEArithmeticSubtractionFixture<T1, T2, int16_t>; + TEST_SUITE(S16) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("DataType", { DataType::U8, DataType::S16 })), +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), + framework::dataset::make("DataType", { DataType::U8, DataType::S16 })), + framework::dataset::make("DataType", { DataType::U8, DataType::S16 })), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), - shape, data_type, policy) + shape, data_type1, data_type2, policy) { // Create tensors - Tensor ref_src1 = create_tensor<Tensor>(shape, data_type); - Tensor ref_src2 = create_tensor<Tensor>(shape, DataType::S16); + Tensor ref_src1 = create_tensor<Tensor>(shape, data_type1); + Tensor ref_src2 = create_tensor<Tensor>(shape, data_type2); Tensor dst = create_tensor<Tensor>(shape, DataType::S16); // Create and Configure function @@ -124,27 +135,86 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(frame validate(dst.info()->padding(), padding); } -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset), - framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +TEST_SUITE(S16_S16_S16) +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset), - framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionToS16Fixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() + +TEST_SUITE(U8_U8_S16) +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + ArithmeticSubtractionU8U8S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(Accessor(_target), _reference); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionToS16Fixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), + ArithmeticSubtractionU8U8S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() + +TEST_SUITE(S16_U8_S16) +using NEAriSubS16U8ToS16Fixture = NEArithmeticSubtractionToS16Fixture<int16_t, uint8_t>; +FIXTURE_DATA_TEST_CASE(RunSmall, NEAriSubS16U8ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + ArithmeticSubtractionS16U8S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(Accessor(_target), _reference); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, NEAriSubS16U8ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), + ArithmeticSubtractionS16U8S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output validate(Accessor(_target), _reference); } TEST_SUITE_END() -template <typename T> -using NEArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<Tensor, Accessor, NEArithmeticSubtraction, T>; +TEST_SUITE(U8_S16_S16) +using NEAriSubU8S16ToS16Fixture = NEArithmeticSubtractionToS16Fixture<uint8_t, int16_t>; +FIXTURE_DATA_TEST_CASE(RunSmall, NEAriSubU8S16ToS16Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), + ArithmeticSubtractionU8S16S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(Accessor(_target), _reference); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, NEAriSubU8S16ToS16Fixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), + ArithmeticSubtractionU8S16S16Dataset), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) +{ + // Validate output + validate(Accessor(_target), _reference); +} +TEST_SUITE_END() +TEST_SUITE_END() + +template <typename T1, typename T2 = T1, typename T3 = T1> +using NEArithmeticSubtractionFixedPointFixture = ArithmeticSubtractionValidationFixedPointFixture<Tensor, Accessor, NEArithmeticSubtraction, T1, T2, T3>; TEST_SUITE(Quantized) TEST_SUITE(QS8) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQS8Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), + ArithmeticSubtractionQS8Dataset), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), framework::dataset::make("FractionalBits", 1, 7))) { @@ -152,7 +222,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int8_t validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS8Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), + ArithmeticSubtractionQS8Dataset), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), framework::dataset::make("FractionalBits", 1, 7))) { @@ -171,7 +242,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixedPointFixture<int16_ validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionQS16Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), + ArithmeticSubtractionQS16Dataset), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), framework::dataset::make("FractionalBits", 1, 15))) { diff --git a/tests/validation/fixtures/ArithmeticSubtractionFixture.h b/tests/validation/fixtures/ArithmeticSubtractionFixture.h index 2c683d659f..9e65faef00 100644 --- a/tests/validation/fixtures/ArithmeticSubtractionFixture.h +++ b/tests/validation/fixtures/ArithmeticSubtractionFixture.h @@ -40,7 +40,7 @@ namespace test { namespace validation { -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2 = T1, typename T3 = T1> class ArithmeticSubtractionValidationFixedPointFixture : public framework::Fixture { public: @@ -93,31 +93,31 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, int fixed_point_position) + SimpleTensor<T3> compute_reference(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, int fixed_point_position) { // Create reference - SimpleTensor<T> ref_src1{ shape, data_type0, 1, fixed_point_position }; - SimpleTensor<T> ref_src2{ shape, data_type1, 1, fixed_point_position }; + SimpleTensor<T1> ref_src1{ shape, data_type0, 1, fixed_point_position }; + SimpleTensor<T2> ref_src2{ shape, data_type1, 1, fixed_point_position }; // Fill reference fill(ref_src1, 0); fill(ref_src2, 1); - return reference::arithmetic_subtraction<T>(ref_src1, ref_src2, output_data_type, convert_policy); + return reference::arithmetic_subtraction<T1, T2, T3>(ref_src1, ref_src2, output_data_type, convert_policy); } - TensorType _target{}; - SimpleTensor<T> _reference{}; - int _fractional_bits{}; + TensorType _target{}; + SimpleTensor<T3> _reference{}; + int _fractional_bits{}; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class ArithmeticSubtractionValidationFixture : public ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2 = T1, typename T3 = T1> +class ArithmeticSubtractionValidationFixture : public ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T1, T2, T3> { public: template <typename...> void setup(TensorShape shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy) { - ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type0, data_type1, output_data_type, convert_policy, 0); + ArithmeticSubtractionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T1, T2, T3>::setup(shape, data_type0, data_type1, output_data_type, convert_policy, 0); } }; } // namespace validation |