diff options
Diffstat (limited to 'tests/validation/CL/BatchNormalizationLayer.cpp')
-rw-r--r-- | tests/validation/CL/BatchNormalizationLayer.cpp | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tests/validation/CL/BatchNormalizationLayer.cpp b/tests/validation/CL/BatchNormalizationLayer.cpp index 69f8d7b635..c29a400402 100644 --- a/tests/validation/CL/BatchNormalizationLayer.cpp +++ b/tests/validation/CL/BatchNormalizationLayer.cpp @@ -78,6 +78,43 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Ran validate(dst.info()->valid_region(), valid_region); } +// *INDENT-OFF* +// clang-format off +DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( + framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching data types + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching data types + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Invalid mean/var/beta/gamma shape + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point position + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + }), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F16), + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 3), + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + })), + framework::dataset::make("MVBGInfo",{ TensorInfo(TensorShape(2U), 1, DataType::F32), + TensorInfo(TensorShape(2U), 1, DataType::F16), + TensorInfo(TensorShape(2U), 1, DataType::F32), + TensorInfo(TensorShape(5U), 1, DataType::F32), + TensorInfo(TensorShape(2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(2U), 1, DataType::QS8, 2), + })), + framework::dataset::make("Expected", { false, true, true, true, true, false})), + input_info, output_info, mvbg_info, expected) +{ + auto mean_info = mvbg_info; + auto var_info = mvbg_info; + auto beta_info = mvbg_info; + auto gamma_info = mvbg_info; + bool has_error = bool(CLBatchNormalizationLayer::validate(&input_info, &output_info, &mean_info, &var_info, &beta_info, &gamma_info, 1.f)); + ARM_COMPUTE_EXPECT(has_error == expected, framework::LogLevel::ERRORS); +} +// clang-format on +// *INDENT-ON* + TEST_SUITE(Float) TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(), |