diff options
Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiplyNative.cpp')
-rw-r--r-- | tests/validation/CL/GEMMMatrixMultiplyNative.cpp | 76 |
1 files changed, 70 insertions, 6 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyNative.cpp b/tests/validation/CL/GEMMMatrixMultiplyNative.cpp index c060a8bc05..0ddf43766f 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyNative.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyNative.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,12 +21,12 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h" #include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/CL/CLTensorAllocator.h" +#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h" #include "tests/CL/CLAccessor.h" #include "tests/CL/Helper.h" #include "tests/PaddingCalculator.h" @@ -44,9 +44,10 @@ namespace test namespace validation { using namespace arm_compute::misc::shape_calculator; +using namespace arm_compute::opencl::kernels; -// Create function for CLGEMMMatrixMultiplyNativeKernel -using CLGEMMMatrixMultiplyNative = CLSynthetizeFunction<CLGEMMMatrixMultiplyNativeKernel>; +// Create function for ClGemmMatrixMultiplyNativeKernel +using CLGEMMMatrixMultiplyNative = CLSynthetizeOperator<ClGemmMatrixMultiplyNativeKernel>; // Fixture for CLGEMMMatrixMultiplyNative template <typename T> @@ -90,8 +91,8 @@ const auto b_values = framework::dataset::make("batch_size", 1, 3); /** Activation values to test */ const auto act_values = framework::dataset::make("Activation", { - ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f), + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::ELU), }); /** M0 values to test - Precommit */ @@ -118,6 +119,28 @@ const auto k0_values_nightly = framework::dataset::make("K0", { 2, 3, 4, 8 }); /** Broadcast bias from vector to matrix */ const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } ); +/** Boundary handling cases for testing partial/non-partial (full) block dimensions, resulting from different combinations + * of M, M0, N and N0 values. + * M0 and N0 are kept constant, while the different test cases need to vary M and N. + * + * Eg. M = 64 and N = 33 result in a block dimension that has no partial blocks (all full blocks) in Y dimension and + * parital blocks in X dimension. + */ +const auto boundary_handling_cases = combine(combine(combine(combine(combine(combine(combine(combine(combine( + // Large k to force potential out-of-bound reads on input0 + framework::dataset::make("K", 315), + // Batch size == 1 to force potential out-of-bound reads on input0 + framework::dataset::make("batch_size", 1)), + framework::dataset::make("M0", 4)), + framework::dataset::make("N0", 4)), + framework::dataset::make("K0", 4)), + // Only need to test F32 as F16 shares identical boundary handling logics + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("alpha", -0.75f )), + framework::dataset::make("beta", -0.35f )), + broadcast_bias_values), + framework::dataset::make("Activation", ActivationLayerInfo())); + /** Configuration test */ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, bool broadcast_bias, DataType data_type, const ActivationLayerInfo &act_info) { @@ -162,7 +185,7 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned // Create and configure function CLGEMMMatrixMultiplyNative gemm; - gemm.configure(&lhs, &rhs, &bias, &dst, 1.0f, 1.0f, lhs_info, rhs_info, kernel_info); + gemm.configure(lhs.info(), rhs.info(), bias.info(), dst.info(), 1.0f, 1.0f, lhs_info, rhs_info, kernel_info); } } // namespace @@ -185,6 +208,46 @@ m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, broadcast_bias validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, broadcast_bias, DataType::F32, act_value); } +FIXTURE_DATA_TEST_CASE(RunSmallBoundaryHandlingPartialInXPartialInY, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL, + combine(combine( + framework::dataset::make("M", 3), + framework::dataset::make("N", 1)), + boundary_handling_cases)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunSmallBoundaryHandlingPartialInXFullInY, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL, + combine(combine( + framework::dataset::make("M", 64), + framework::dataset::make("N", 51)), + boundary_handling_cases)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunSmallBoundaryHandlingFullInXFullInY, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL, + combine(combine( + framework::dataset::make("M", 64), + framework::dataset::make("N", 32)), + boundary_handling_cases)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunSmallBoundaryHandlingFullInXPartialInY, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL, + combine(combine( + framework::dataset::make("M", 37), + framework::dataset::make("N", 32)), + boundary_handling_cases)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyNativeFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, @@ -260,6 +323,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyNative3DFixture<float>, f // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); } + TEST_SUITE_END() // FP32 TEST_SUITE_END() // Float TEST_SUITE_END() // GEMMMatrixMulipltyNative |