diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2019-07-29 14:27:16 +0100 |
---|---|---|
committer | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2019-08-01 09:25:15 +0000 |
commit | f3622becf1f0d6bf5147ebb7d6d0f14d5252860a (patch) | |
tree | 60c5a1de2d24f9828a7896d200759150b0f5eb42 /tests/validation/fixtures | |
parent | c4d5136707280d98f660a67219114f5ee5b10fb8 (diff) | |
download | ComputeLibrary-f3622becf1f0d6bf5147ebb7d6d0f14d5252860a.tar.gz |
COMPMID-1979: Fuse Activation Function in CLGEMM - part 4
Fused activation function in CLGEMM
Change-Id: I644fdf09349325c0b3a2cd5fef2a3ea2c974149d
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1640
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/GEMMFixture.h | 38 |
1 files changed, 26 insertions, 12 deletions
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index b36bb99246..a04a901b1c 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -44,7 +44,7 @@ namespace test { namespace validation { -template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_ouput_as_3d = false> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false> class GEMMValidationFixture : public framework::Fixture { public: @@ -87,7 +87,13 @@ protected: // The GEMMinfo includes the values of the depth in case of reinterpreted 3d output. // If the output shape has the same number of dimensions of the input the method called is a 2D matrix multiplication (depth_output_reinterpreted_as_3D = 0), // in the other case we have to use the reinterpreted version of GEMM (depth_output_reinterpreted_as_3D = depth of the 3D output). - gemm.configure(&a, &b, (disable_c) ? nullptr : &c, &dst, alpha, beta, GEMMInfo(false, false, false, (reinterpret_ouput_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d)); + gemm.configure(&a, + &b, + (disable_c) ? nullptr : &c, + &dst, + alpha, beta, + GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d, false, GEMMLowpOutputStageInfo(), false, (reinterpret_input_as_3d + || reinterpret_output_as_3d))); ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -122,6 +128,7 @@ protected: DataType data_type) { TensorShape shape_a_to_use = shape_a; + if(reinterpret_input_as_3d) { // Collapse the second and third dimension if the input is 3D @@ -131,22 +138,29 @@ protected: // Create reference SimpleTensor<T> a{ shape_a_to_use, data_type, 1 }; SimpleTensor<T> b{ shape_b, data_type, 1 }; - SimpleTensor<T> c{ shape_c, data_type, 1 }; + SimpleTensor<T> c{ output_shape, data_type, 1 }; // Fill reference fill(a, 0); fill(b, 1); - if(!disable_c) - { - fill(c, 2); - return reference::gemm<T>(a, b, c, alpha, beta); - } - else + fill(c, 2); + + if(reinterpret_input_as_3d || reinterpret_output_as_3d) { - // Setting beta to 0 will effectively disable C for the - // computation of the reference: alpha * A * B + 0 * C - return reference::gemm<T>(a, b, c, alpha, 0.f); + const int n = shape_b[0]; + const int m = reinterpret_output_as_3d ? output_shape[1] * output_shape[2] : output_shape[1]; + const int batch_size = reinterpret_output_as_3d ? output_shape[3] : output_shape[2]; + + // In case of broadcast, we need simply copy the first into the following "M" ones + for(int i = 1; i < m * batch_size; i++) + { + memcpy(c.data() + i * n, c.data(), n * sizeof(T)); + } } + + // Setting beta to 0 will effectively disable C for the + // computation of the reference: alpha * A * B + 0 * C + return reference::gemm<T>(a, b, c, alpha, disable_c ? 0.f : beta); } TensorType _target{}; |