diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/NEON/ConvolutionLayer.cpp | 56 |
1 files changed, 42 insertions, 14 deletions
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp index 7a9230d37a..1f76925d96 100644 --- a/tests/validation/NEON/ConvolutionLayer.cpp +++ b/tests/validation/NEON/ConvolutionLayer.cpp @@ -767,21 +767,33 @@ FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixtureNoFastMath } #if defined(ARM_COMPUTE_ENABLE_BF16) - +// These tests currently only works with SVE length 256 +// If other SVE length is used a kernel will fail to be found +// This needs to be addressed in order to ensure it doesn't revert to FP32 kernels for systems with SVE length other than 256 FIXTURE_DATA_TEST_CASE(UC2_2_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::F32 }), framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo8i4_bf16 }))) { - ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT_EQUAL(_computed_weight_format, arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS); + if(Scheduler::get().cpu_info().has_bf16() && (arm_gemm::utils::get_vector_length<float>() == 8)){ + ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT_EQUAL(_computed_weight_format, arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS); + } + else{ + ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS); + } } FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::F32 }), framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo8i4_bf16 }))) { - ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS); + if(Scheduler::get().cpu_info().has_bf16() && (arm_gemm::utils::get_vector_length<float>() == 8)){ + ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS); + } + else{ + ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS); + } } #endif // ARM_COMPUTE_ENABLE_BF16 @@ -852,20 +864,36 @@ FIXTURE_DATA_TEST_CASE(UC3_2_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath<c combine(framework::dataset::make("DataType", { DataType::F32 }), framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY }))) { - ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS); + if(Scheduler::get().cpu_info().has_bf16()){ + ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS); + } + else{ + ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS); + } } FIXTURE_DATA_TEST_CASE(UC3_2_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL, combine(framework::dataset::make("DataType", { DataType::F32 }), framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY }))) { - ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS); + if(Scheduler::get().cpu_info().has_bf16()){ + ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS); + } + else{ + ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS); + } } #endif // ARM_COMPUTE_ENABLE_BF16 @@ -1141,7 +1169,7 @@ TEST_SUITE(Float) TEST_SUITE(BFLOAT16) FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::BFLOAT16)), + framework::dataset::make("DataType", Scheduler::get().cpu_info().has_bf16() ? DataType::BFLOAT16 : DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), ActivationFunctionsDataset)) { |