diff options
Diffstat (limited to 'tests/validation/CL/BatchNormalizationLayer.cpp')
-rw-r--r-- | tests/validation/CL/BatchNormalizationLayer.cpp | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/tests/validation/CL/BatchNormalizationLayer.cpp b/tests/validation/CL/BatchNormalizationLayer.cpp index ac30c638b5..69f8d7b635 100644 --- a/tests/validation/CL/BatchNormalizationLayer.cpp +++ b/tests/validation/CL/BatchNormalizationLayer.cpp @@ -43,9 +43,10 @@ namespace validation { namespace { -constexpr AbsoluteTolerance<float> tolerance_f(0.00001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ -constexpr AbsoluteTolerance<float> tolerance_qs8(3.0f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS8 */ -constexpr AbsoluteTolerance<float> tolerance_qs16(6.0f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS16 */ +constexpr AbsoluteTolerance<float> tolerance_f32(0.00001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */ +constexpr AbsoluteTolerance<float> tolerance_f16(0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */ +constexpr AbsoluteTolerance<float> tolerance_qs8(3.0f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS8 */ +constexpr AbsoluteTolerance<float> tolerance_qs16(6.0f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS16 */ } // namespace TEST_SUITE(CL) @@ -54,7 +55,7 @@ TEST_SUITE(BatchNormalizationLayer) template <typename T> using CLBatchNormalizationLayerFixture = BatchNormalizationLayerValidationFixture<CLTensor, CLAccessor, CLBatchNormalizationLayer, T>; -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::RandomBatchNormalizationLayerDataset(), framework::dataset::make("DataType", { DataType::QS8, DataType::QS16, DataType::F32 })), +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::RandomBatchNormalizationLayerDataset(), framework::dataset::make("DataType", { DataType::QS8, DataType::QS16, DataType::F16, DataType::F32 })), shape0, shape1, epsilon, dt) { // Set fixed point position data type allowed @@ -78,14 +79,25 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Ran } TEST_SUITE(Float) +TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(), framework::dataset::make("DataType", DataType::F32))) { // Validate output - validate(CLAccessor(_target), _reference, tolerance_f, 0); + validate(CLAccessor(_target), _reference, tolerance_f32, 0); } TEST_SUITE_END() +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(Random, CLBatchNormalizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::RandomBatchNormalizationLayerDataset(), + framework::dataset::make("DataType", DataType::F16))) +{ + // Validate output + validate(CLAccessor(_target), _reference, tolerance_f16, 0); +} +TEST_SUITE_END() +TEST_SUITE_END() + TEST_SUITE(Quantized) template <typename T> using CLBatchNormalizationLayerFixedPointFixture = BatchNormalizationLayerValidationFixedPointFixture<CLTensor, CLAccessor, CLBatchNormalizationLayer, T>; |