From 64f23003b90f23ba22a6b5bf377f347fcb3e235f Mon Sep 17 00:00:00 2001
From: David Svantesson-Yeung
Date: Wed, 27 Mar 2024 12:55:45 +0000
Subject: Runtime checks for bf16 fixed format tests

Add checks for bf16 support to the bf16 fixed format tests. This ensures
the tests pass in a multi_isa setting where the library was compiled with
bf16 support, even on systems that do not support bf16. Also adds a
runtime check to GEMMConvolutionLayer/Float/BFLOAT16/RunSmall.

Resolves: COMPMID-6922
Signed-off-by: David Svantesson-Yeung
Change-Id: Ic0f09ba34b5a2c64be8bfc848a4457a6b1c4d1c3
Signed-off-by: David Svantesson-Yeung
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11408
Reviewed-by: Gunes Bayir
Comments-Addressed: Arm Jenkins
Tested-by: Arm Jenkins
Benchmark: Arm Jenkins
---
 tests/validation/NEON/ConvolutionLayer.cpp | 56 ++++++++++++++++++++++--------
 1 file changed, 42 insertions(+), 14 deletions(-)

(limited to 'tests')

diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 7a9230d37a..1f76925d96 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -767,21 +767,33 @@ FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixtureNoFastMath
 }
 
 #if defined(ARM_COMPUTE_ENABLE_BF16)
-
+// These tests currently only work with SVE length 256
+// If another SVE length is used a kernel will fail to be found
+// This needs to be addressed in order to ensure it doesn't revert to FP32 kernels for systems with SVE length other than 256
 FIXTURE_DATA_TEST_CASE(UC2_2_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
                        combine(framework::dataset::make("DataType", { DataType::F32 }),
                                framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo8i4_bf16 })))
 {
-    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
-    ARM_COMPUTE_EXPECT_EQUAL(_computed_weight_format, arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+    if(Scheduler::get().cpu_info().has_bf16() && (arm_gemm::utils::get_vector_length<float>() == 8)){
+        ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT_EQUAL(_computed_weight_format, arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+    }
+    else{
+        ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+    }
 }
 
 FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
                        combine(framework::dataset::make("DataType", { DataType::F32 }),
                                framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo8i4_bf16 })))
 {
-    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
-    ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+    if(Scheduler::get().cpu_info().has_bf16() && (arm_gemm::utils::get_vector_length<float>() == 8)){
+        ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+    }
+    else{
+        ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+    }
 }
 
 #endif // ARM_COMPUTE_ENABLE_BF16
@@ -852,20 +864,36 @@ FIXTURE_DATA_TEST_CASE(UC3_2_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath,
                        framework::DatasetMode::ALL,
                        combine(framework::dataset::make("DataType", { DataType::F32 }),
                                framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
 {
-    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
-    ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
-    ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
-    ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+    if(Scheduler::get().cpu_info().has_bf16()){
+        ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+    }
+    else{
+        ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+    }
 }
 
 #endif // ARM_COMPUTE_ENABLE_BF16
@@ -1141,7 +1169,7 @@ TEST_SUITE(Float)
 TEST_SUITE(BFLOAT16)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
                                                                                                                     framework::dataset::make("ReshapeWeights", { true })),
-                                                                                                                    framework::dataset::make("DataType", DataType::BFLOAT16)),
+                                                                                                                    framework::dataset::make("DataType", Scheduler::get().cpu_info().has_bf16() ? DataType::BFLOAT16 : DataType::F32)),
                                                                                                                     framework::dataset::make("DataLayout", { DataLayout::NHWC })),
                                                                                                                     ActivationFunctionsDataset))
 {
--
cgit v1.2.1
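
The whole change hinges on a single runtime query, CPUInfo::has_bf16(), reached through the scheduler. The snippet below is an illustrative sketch only and is not part of the commit; it assumes the public Scheduler/CPUInfo API already used in the diff above and simply reports whether the executing CPU supports bf16, which is what decides the expected kernel selection in a multi_isa build.

#include "arm_compute/runtime/Scheduler.h"

#include <iostream>

int main()
{
    // Query the CPU the program is actually running on, independently of the
    // ISA features the library was compiled with (the multi_isa case the
    // tests above now handle).
    const bool bf16_supported = arm_compute::Scheduler::get().cpu_info().has_bf16();
    std::cout << "Runtime bf16 support: " << (bf16_supported ? "yes" : "no") << std::endl;
    return 0;
}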