diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2019-06-14 16:11:10 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-06-20 16:02:39 +0000 |
commit | e16c8906a2aedf00e910754a01fca8bc4189cfc7 (patch) | |
tree | de9b88917bb00a76a9df68c9e92f05e38c5de817 /tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp | |
parent | 0cbfda629dd8f684e625173341bab972f004222c (diff) | |
download | ComputeLibrary-e16c8906a2aedf00e910754a01fca8bc4189cfc7.tar.gz |
COMPMID-2053: Fuse bias addition with CLGEMMMatrixMultiplyReshapedKernel
Change-Id: I5bfd38c94a6fd18a1cba2104f7e1b04e7bef6ec2
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1359
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp')
-rw-r--r-- | tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp | 27 |
1 files changed, 12 insertions, 15 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp index 23ae004912..133170e2d3 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp @@ -123,7 +123,7 @@ const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, fal /** Transpose values to test with RHS matrix */ const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true, false }); -/**Broadcast bias from vector to matrix */ +/** Broadcast bias from vector to matrix */ const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", {false, true} ); /** Configuration test */ @@ -155,18 +155,15 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned TensorInfo(rhs_shape_reshaped, 1, data_type), gemm_info); + const TensorShape bias_shape(N, + broadcast_bias? 1 : M, + broadcast_bias? 1 : b_value); + // Create tensors CLTensor lhs = create_tensor<CLTensor>(lhs_shape, data_type); CLTensor rhs_reshaped = create_tensor<CLTensor>(rhs_shape_reshaped, data_type); - CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type); - - TensorShape bias_shape = dst_shape; - if (broadcast_bias) - { - bias_shape[1] = 1; - bias_shape[2] = 1; - } CLTensor bias = create_tensor<CLTensor>(bias_shape, data_type); + CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type); ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -257,7 +254,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< t_values_rhs), framework::dataset::make("DataType", DataType::F32)), a_values), - b_values)) + beta_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); @@ -278,7 +275,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< t_values_rhs), framework::dataset::make("DataType", DataType::F32)), a_values), - b_values)) + beta_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); @@ -300,7 +297,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half t_values_rhs), framework::dataset::make("DataType", DataType::F16)), a_values), - b_values), + beta_values), broadcast_bias_values)) { // Validate output @@ -321,7 +318,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half t_values_rhs), framework::dataset::make("DataType", DataType::F16)), a_values), - b_values), + beta_values), broadcast_bias_values)) { // Validate output @@ -343,7 +340,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< t_values_rhs), framework::dataset::make("DataType", DataType::F16)), a_values), - b_values)) + beta_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); @@ -364,7 +361,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture< t_values_rhs), framework::dataset::make("DataType", DataType::F16)), a_values), - b_values)) + beta_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16); |