diff options
Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp')
-rw-r--r-- | tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp | 27 |
1 files changed, 12 insertions, 15 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp index 23ae004912..133170e2d3 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp @@ -123,7 +123,7 @@ const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, fal /** Transpose values to test with RHS matrix */ const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true, false }); -/**Broadcast bias from vector to matrix */ +/** Broadcast bias from vector to matrix */ const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", {false, true} ); /** Configuration test */ @@ -155,18 +155,15 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned TensorInfo(rhs_shape_reshaped, 1, data_type), gemm_info); + const TensorShape bias_shape(N, + broadcast_bias? 1 : M, + broadcast_bias? 1 : b_value); + // Create tensors CLTensor lhs = create_tensor<CLTensor>(lhs_shape, data_type); CLTensor rhs_reshaped = create_tensor<CLTensor>(rhs_shape_reshaped, data_type); - CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type); - - TensorShape bias_shape = dst_shape; - if (broadcast_bias) - { - bias_shape[1] = 1; - bias_shape[2] = 1; - } CLTensor bias = create_tensor<CLTensor>(bias_shape, data_type); + CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type); ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -257,7 +254,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< t_values_rhs), framework::dataset::make("DataType", DataType::F32)), a_values), - b_values)) + beta_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); @@ -278,7 +275,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< t_values_rhs), framework::dataset::make("DataType", DataType::F32)), a_values), - b_values)) + beta_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); @@ -300,7 +297,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half t_values_rhs), framework::dataset::make("DataType", DataType::F16)), a_values), - b_values), + beta_values), broadcast_bias_values)) { // Validate output @@ -321,7 +318,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half t_values_rhs), framework::dataset::make("DataType", DataType::F16)), a_values), - b_values), + beta_values), broadcast_bias_values)) { // Validate output @@ -343,7 +340,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< t_values_rhs), framework::dataset::make("DataType", DataType::F16)), a_values), - b_values)) + beta_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); @@ -364,7 +361,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< t_values_rhs), framework::dataset::make("DataType", DataType::F16)), a_values), - b_values)) + beta_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); |