From f3622becf1f0d6bf5147ebb7d6d0f14d5252860a Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Mon, 29 Jul 2019 14:27:16 +0100
Subject: COMPMID-1979: Fuse Activation Function in CLGEMM - part 4

Fused activation function in CLGEMM

Change-Id: I644fdf09349325c0b3a2cd5fef2a3ea2c974149d
Signed-off-by: Gian Marco Iodice
Reviewed-on: https://review.mlplatform.org/c/1640
Comments-Addressed: Arm Jenkins
Reviewed-by: Georgios Pinitas
Tested-by: Arm Jenkins
---
 tests/datasets/LargeGEMMDataset.h                  | 28 ++++++++--------
 tests/datasets/SmallGEMMDataset.h                  | 24 +++++++-------
 tests/validation/CL/GEMMMatrixMultiply.cpp         |  2 +-
 .../CL/GEMMMatrixMultiplyInterleavedTransposed.cpp |  2 +-
 tests/validation/fixtures/GEMMFixture.h            | 38 +++++++++++++++-------
 5 files changed, 54 insertions(+), 40 deletions(-)

(limited to 'tests')

diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h
index 0876ae1d2c..0ca0b04460 100644
--- a/tests/datasets/LargeGEMMDataset.h
+++ b/tests/datasets/LargeGEMMDataset.h
@@ -55,13 +55,13 @@ class LargeGEMMOutput3DDataset final : public GEMMDataset
 public:
     LargeGEMMOutput3DDataset()
     {
-        add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(681U, 1025U), TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
-        add_config(TensorShape(364U, 3025U), TensorShape(96U, 364U), TensorShape(96U, 605U, 5U), TensorShape(96U, 605U, 5U), 1.0f, 0.0f);
-        add_config(TensorShape(1201U, 729U), TensorShape(128U, 1201U), TensorShape(128U, 243U, 3U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(2305U, 169U), TensorShape(384U, 2305U), TensorShape(384U, 13U, 13U), TensorShape(384U, 13U, 13U), 1.0f, 0.0f);
-        add_config(TensorShape(1729U, 170U), TensorShape(192U, 1729U), TensorShape(192U, 85U, 2U), TensorShape(192U, 85U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(1729U, 170U), TensorShape(128U, 1729U), TensorShape(128U, 17U, 10U), TensorShape(128U, 17U, 10U), 1.0f, 0.0f);
+        add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(681U, 1025U), TensorShape(213U, 681U), TensorShape(213U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(364U, 3025U), TensorShape(96U, 364U), TensorShape(96U), TensorShape(96U, 605U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(1201U, 729U), TensorShape(128U, 1201U), TensorShape(128U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(2305U, 169U), TensorShape(384U, 2305U), TensorShape(384U), TensorShape(384U, 13U, 13U), 1.0f, 0.0f);
+        add_config(TensorShape(1729U, 170U), TensorShape(192U, 1729U), TensorShape(192U), TensorShape(192U, 85U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(1729U, 170U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U), 1.0f, 0.0f);
     }
 };
 
@@ -70,13 +70,13 @@ class LargeGEMMInputOutput3DDataset final : public GEMMDataset
 public:
     LargeGEMMInputOutput3DDataset()
     {
-        add_config(TensorShape(923U, 143U, 3U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(681U, 205U, 5U), TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
-        add_config(TensorShape(364U, 605U, 5U), TensorShape(96U, 364U), TensorShape(96U, 605U, 5U), TensorShape(96U, 605U, 5U), 0.2f, 1.2f);
-        add_config(TensorShape(1201U, 243U, 3U), TensorShape(128U, 1201U), TensorShape(128U, 243U, 3U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(2305U, 13U, 13U), TensorShape(384U, 2305U), TensorShape(384U, 13U, 13U), TensorShape(384U, 13U, 13U), 0.4f, 0.7f);
-        add_config(TensorShape(1729U, 85U, 2U, 2U), TensorShape(192U, 1729U), TensorShape(192U, 85U, 2U, 2U), TensorShape(192U, 85U, 2U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U, 17U, 10U, 3U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f);
+        add_config(TensorShape(923U, 143U, 3U), TensorShape(871U, 923U), TensorShape(871U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(681U, 205U, 5U), TensorShape(213U, 681U), TensorShape(213U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(364U, 605U, 5U), TensorShape(96U, 364U), TensorShape(96U), TensorShape(96U, 605U, 5U), 0.2f, 1.2f);
+        add_config(TensorShape(1201U, 243U, 3U), TensorShape(128U, 1201U), TensorShape(128U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(2305U, 13U, 13U), TensorShape(384U, 2305U), TensorShape(384U), TensorShape(384U, 13U, 13U), 0.4f, 0.7f);
+        add_config(TensorShape(1729U, 85U, 2U, 2U), TensorShape(192U, 1729U), TensorShape(192U), TensorShape(192U, 85U, 2U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f);
     }
 };
 } // namespace datasets
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index ae3c3ed86d..45d1a1e07e 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -55,12 +55,12 @@ class SmallGEMMOutput3DDataset final : public GEMMDataset
 public:
     SmallGEMMOutput3DDataset()
     {
-        add_config(TensorShape(21U, 14U), TensorShape(34U, 21U), TensorShape(34U, 7U, 2U), TensorShape(34U, 7U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U, 1U), TensorShape(23U, 1U, 1U), 1.0f, 0.0f);
-        add_config(TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U, 4U, 3U), TensorShape(21U, 4U, 3U), 0.2f, 1.2f);
-        add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U, 1U, 1U), TensorShape(17U, 1U, 1U), 0.4f, 0.7f);
-        add_config(TensorShape(16U, 16U), TensorShape(8U, 16U), TensorShape(8U, 8U, 2U), TensorShape(8U, 8U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(16U, 16U, 5U), TensorShape(8U, 16U, 5U), TensorShape(8U, 8U, 2U, 5U), TensorShape(8U, 8U, 2U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(21U, 14U), TensorShape(34U, 21U), TensorShape(34U), TensorShape(34U, 7U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U), TensorShape(23U, 1U, 1U), 1.0f, 0.0f);
+        add_config(TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U), TensorShape(21U, 4U, 3U), 0.2f, 1.2f);
+        add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U), TensorShape(17U, 1U, 1U), 0.4f, 0.7f);
+        add_config(TensorShape(16U, 16U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 8U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(16U, 16U, 5U), TensorShape(8U, 16U, 5U), TensorShape(8U), TensorShape(8U, 8U, 2U, 5U), 1.0f, 0.0f);
     }
 };
 
@@ -69,12 +69,12 @@ class SmallGEMMInputOutput3DDataset final : public GEMMDataset
 public:
     SmallGEMMInputOutput3DDataset()
    {
-        add_config(TensorShape(21U, 14U, 13U), TensorShape(34U, 21U), TensorShape(34U, 14U, 13U), TensorShape(34U, 14U, 13U), 1.0f, 0.0f);
-        add_config(TensorShape(31U, 1U, 3U), TensorShape(23U, 31U), TensorShape(23U, 1U, 3U), TensorShape(23U, 1U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(38U, 12U, 2U), TensorShape(21U, 38U), TensorShape(21U, 12U, 2U), TensorShape(21U, 12U, 2U), 0.2f, 1.2f);
-        add_config(TensorShape(32U, 1U, 4U, 3U), TensorShape(17U, 32U), TensorShape(17U, 1U, 4U, 3U), TensorShape(17U, 1U, 4U, 3U), 0.4f, 0.7f);
-        add_config(TensorShape(16U, 16U, 3U, 2U), TensorShape(8U, 16U), TensorShape(8U, 16U, 3U, 2U), TensorShape(8U, 16U, 3U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(16U, 16U, 5U, 3U), TensorShape(8U, 16U), TensorShape(8U, 16U, 5U, 3U), TensorShape(8U, 16U, 5U, 3U), 1.0f, 0.3f);
+        add_config(TensorShape(21U, 14U, 13U), TensorShape(34U, 21U), TensorShape(34U), TensorShape(34U, 14U, 13U), 1.0f, 0.0f);
+        add_config(TensorShape(31U, 1U, 3U), TensorShape(23U, 31U), TensorShape(23U), TensorShape(23U, 1U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(38U, 12U, 2U), TensorShape(21U, 38U), TensorShape(21U), TensorShape(21U, 12U, 2U), 0.2f, 1.2f);
+        add_config(TensorShape(32U, 1U, 4U, 3U), TensorShape(17U, 32U), TensorShape(17U), TensorShape(17U, 1U, 4U, 3U), 0.4f, 0.7f);
+        add_config(TensorShape(16U, 16U, 3U, 2U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 16U, 3U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(16U, 16U, 5U, 3U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 16U, 5U, 3U), 1.0f, 0.3f);
     }
 };
 } // namespace datasets
diff --git a/tests/validation/CL/GEMMMatrixMultiply.cpp b/tests/validation/CL/GEMMMatrixMultiply.cpp
index 21fd7125ec..8f7c0aaef1 100644
--- a/tests/validation/CL/GEMMMatrixMultiply.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiply.cpp
@@ -67,7 +67,7 @@ RelativeTolerance rel_tolerance_f16(half(0.2));
 constexpr float tolerance_num_f16 = 0.02f;
 
 /** Alpha values to test - Precommit */
-const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} );
+const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
 
 /** Beta values to test - Precommit */
 const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} );
diff --git a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
index cae94b2e15..5d21cf4f34 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
@@ -77,7 +77,7 @@ RelativeTolerance rel_tolerance_f16(half(0.2));
 constexpr float tolerance_num_f16 = 0.02f;
 
 /** Alpha values to test - Precommit */
-const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} );
+const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
 
 /** Beta values to test - Precommit */
 const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} );
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index b36bb99246..a04a901b1c 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -44,7 +44,7 @@ namespace test
 {
 namespace validation
 {
-template 
+template 
 class GEMMValidationFixture : public framework::Fixture
 {
 public:
@@ -87,7 +87,13 @@ protected:
         // The GEMMinfo includes the values of the depth in case of reinterpreted 3d output.
         // If the output shape has the same number of dimensions of the input the method called is a 2D matrix multiplication (depth_output_reinterpreted_as_3D = 0),
         // in the other case we have to use the reinterpreted version of GEMM (depth_output_reinterpreted_as_3D = depth of the 3D output).
-        gemm.configure(&a, &b, (disable_c) ? nullptr : &c, &dst, alpha, beta, GEMMInfo(false, false, false, (reinterpret_ouput_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d));
+        gemm.configure(&a,
+                       &b,
+                       (disable_c) ? nullptr : &c,
+                       &dst,
+                       alpha, beta,
+                       GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d, false, GEMMLowpOutputStageInfo(), false, (reinterpret_input_as_3d
+                                || reinterpret_output_as_3d)));
         ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -122,6 +128,7 @@ protected:
                                       DataType data_type)
     {
         TensorShape shape_a_to_use = shape_a;
+
         if(reinterpret_input_as_3d)
         {
             // Collapse the second and third dimension if the input is 3D
@@ -131,22 +138,29 @@
         // Create reference
         SimpleTensor<T> a{ shape_a_to_use, data_type, 1 };
         SimpleTensor<T> b{ shape_b, data_type, 1 };
-        SimpleTensor<T> c{ shape_c, data_type, 1 };
+        SimpleTensor<T> c{ output_shape, data_type, 1 };
 
         // Fill reference
         fill(a, 0);
         fill(b, 1);
-        if(!disable_c)
-        {
-            fill(c, 2);
-            return reference::gemm(a, b, c, alpha, beta);
-        }
-        else
+        fill(c, 2);
+
+        if(reinterpret_input_as_3d || reinterpret_output_as_3d)
         {
-            // Setting beta to 0 will effectively disable C for the
-            // computation of the reference: alpha * A * B + 0 * C
-            return reference::gemm(a, b, c, alpha, 0.f);
+            const int n          = shape_b[0];
+            const int m          = reinterpret_output_as_3d ? output_shape[1] * output_shape[2] : output_shape[1];
+            const int batch_size = reinterpret_output_as_3d ? output_shape[3] : output_shape[2];
+
+            // In case of broadcast, we need simply copy the first into the following "M" ones
+            for(int i = 1; i < m * batch_size; i++)
+            {
+                memcpy(c.data() + i * n, c.data(), n * sizeof(T));
+            }
         }
+
+        // Setting beta to 0 will effectively disable C for the
+        // computation of the reference: alpha * A * B + 0 * C
+        return reference::gemm(a, b, c, alpha, disable_c ? 0.f : beta);
     }
 
     TensorType _target{};
-- 
cgit v1.2.1
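
For readers skimming the GEMMFixture.h hunk above, the following standalone C++ sketch (not part of the commit; sizes and variable names are made up for illustration) shows the reference-side idea the new fixture code implements: the bias input C is now a single row of n values that gets broadcast across all m * batch_size output rows, and disabling C simply means evaluating the reference with an effective beta of 0, i.e. alpha * A * B + 0 * C.

#include <cstring>
#include <iostream>
#include <vector>

int main()
{
    // Illustrative sizes only: n = output width, m = output height per batch.
    const int n          = 4;
    const int m          = 3;
    const int batch_size = 2;

    // Bias buffer sized like the full output; only the first row is filled here,
    // the remaining rows are populated by the broadcast below.
    std::vector<float> c(static_cast<size_t>(n) * m * batch_size);
    for(int j = 0; j < n; ++j)
    {
        c[j] = static_cast<float>(j + 1);
    }

    // Broadcast: copy the first row into the following m * batch_size - 1 rows,
    // mirroring the memcpy loop added to compute_reference() in GEMMFixture.h.
    for(int i = 1; i < m * batch_size; ++i)
    {
        std::memcpy(c.data() + i * n, c.data(), n * sizeof(float));
    }

    // With the bias broadcast in place, the reference is dst = alpha * A * B + beta * C,
    // and "no bias" is expressed by forcing beta to 0 rather than by a separate code path.
    const bool  disable_c      = false;
    const float beta           = 0.5f;
    const float effective_beta = disable_c ? 0.0f : beta;

    std::cout << "effective beta: " << effective_beta << "\n";
    for(int i = 0; i < m * batch_size; ++i)
    {
        for(int j = 0; j < n; ++j)
        {
            std::cout << c[i * n + j] << (j + 1 == n ? '\n' : ' ');
        }
    }
    return 0;
}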