aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/NEON/ElementwisePower.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/NEON/ElementwisePower.cpp')
-rw-r--r--tests/validation/NEON/ElementwisePower.cpp29
1 files changed, 21 insertions, 8 deletions
diff --git a/tests/validation/NEON/ElementwisePower.cpp b/tests/validation/NEON/ElementwisePower.cpp
index 4305387c5f..9ac9eec280 100644
--- a/tests/validation/NEON/ElementwisePower.cpp
+++ b/tests/validation/NEON/ElementwisePower.cpp
@@ -51,6 +51,8 @@ const auto ElementwisePowerFP16Dataset = combine(combine(framework:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
const auto ElementwisePowerFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("DataType", DataType::F32));
+const auto InPlaceDataSet = framework::dataset::make("InPlace", { false, true });
+const auto OutOfPlaceDataSet = framework::dataset::make("InPlace", { false });
} // namespace
TEST_SUITE(NEON)
@@ -91,7 +93,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(F16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEElementwisePowerFixture<half>, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), ElementwisePowerFP16Dataset))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEElementwisePowerFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), ElementwisePowerFP16Dataset),
+ InPlaceDataSet))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp16, 0.01);
@@ -101,13 +104,15 @@ TEST_SUITE_END() // F16
TEST_SUITE(F32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEElementwisePowerFixture<float>, framework::DatasetMode::ALL, combine(datasets::SmallShapes(), ElementwisePowerFP32Dataset))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEElementwisePowerFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), ElementwisePowerFP32Dataset),
+ InPlaceDataSet))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp32, 0.01);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEElementwisePowerFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapes(), ElementwisePowerFP32Dataset))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEElementwisePowerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ElementwisePowerFP32Dataset),
+ InPlaceDataSet))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp32, 0.01);
@@ -116,15 +121,23 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEElementwisePowerFixture<float>, framework::Da
template <typename T>
using NEElementwisePowerBroadcastFixture = ElementwisePowerBroadcastValidationFixture<Tensor, Accessor, NEElementwisePower, T>;
-FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEElementwisePowerBroadcastFixture<float>, framework::DatasetMode::ALL, combine(datasets::SmallShapesBroadcast(),
- ElementwisePowerFP32Dataset))
+FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEElementwisePowerBroadcastFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapesBroadcast(),
+ ElementwisePowerFP32Dataset),
+ OutOfPlaceDataSet))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp32, 0.01);
}
-
-FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, NEElementwisePowerBroadcastFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeShapesBroadcast(),
- ElementwisePowerFP32Dataset))
+FIXTURE_DATA_TEST_CASE(RunTinyBroadcastInPlace, NEElementwisePowerBroadcastFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::TinyShapesBroadcastInplace(),
+ ElementwisePowerFP32Dataset),
+ InPlaceDataSet))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_fp32, 0.01);
+}
+FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, NEElementwisePowerBroadcastFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapesBroadcast(),
+ ElementwisePowerFP32Dataset),
+ OutOfPlaceDataSet))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fp32, 0.01);