diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-05-21 13:32:43 +0100 |
---|---|---|
committer | Giuseppe Rossini <giuseppe.rossini@arm.com> | 2019-06-04 15:58:08 +0000 |
commit | b0f342ec315397e4b87d3a9cc3d12f3645c153bc (patch) | |
tree | 3bfd95d4196f6c45feb368b0a020f3bb304e79cd /tests/validation/fixtures | |
parent | bbac660f1959ed2ab58b31a8d5db524883da1754 (diff) | |
download | ComputeLibrary-b0f342ec315397e4b87d3a9cc3d12f3645c153bc.tar.gz |
COMPMID-2171: Fuse bias addition with CLGEMMMatrixMultiplyReshapedOnlyRHSKernel
Change-Id: I1d1e1f28fe7022309d72900893e8368820ca0f89
Signed-off-by: giuros01 <giuseppe.rossini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1259
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/GEMMFixture.h | 88 |
1 files changed, 68 insertions, 20 deletions
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index b7976104aa..34f9bd848c 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -390,7 +390,7 @@ class GEMMMatrixMultiplyReshapedOnlyRHSValidationFixture : public framework::Fix public: template <typename...> void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0, - bool interleave_rhs, bool transpose_rhs, DataType data_type, float alpha) + bool interleave_rhs, bool transpose_rhs, DataType data_type, float alpha, float beta, bool broadcast_bias) { GEMMLHSMatrixInfo lhs_info; lhs_info.m0 = m0; @@ -407,8 +407,18 @@ public: const TensorShape lhs_shape(k, m, batch_size); const TensorShape rhs_shape(n, k, batch_size); - _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type, alpha); - _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha); + TensorShape bias_shape; + if(broadcast_bias) + { + bias_shape = TensorShape(n, 1, 1); + } + else + { + bias_shape = TensorShape(n, m, batch_size); + } + + _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias); + _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, broadcast_bias); } protected: @@ -423,11 +433,13 @@ protected: library->fill_borders_with_garbage(tensor, distribution_inf, i); } - TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, DataType data_type, float alpha) + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + DataType data_type, float alpha, float beta, bool broadcast_bias) { // Create tensors - TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1); - TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1); + TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1); + TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1); + TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1); TensorType rhs_reshaped; TensorType dst; @@ -441,7 +453,7 @@ protected: ReshapeRHSFunctionType reshape_rhs; GEMMFunctionType gemm; reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info); - gemm.configure(&lhs, &rhs_reshaped, &dst, alpha, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K)); + gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, 0, false, broadcast_bias)); ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -450,6 +462,7 @@ protected: lhs.allocator()->allocate(); rhs.allocator()->allocate(); rhs_reshaped.allocator()->allocate(); + bias.allocator()->allocate(); dst.allocator()->allocate(); ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -460,6 +473,7 @@ protected: // Fill tensors fill(AccessorType(lhs), 0); fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); // Compute GEMM reshape_rhs.run(); @@ -468,7 +482,7 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha) + SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, bool broadcast_bias) { TensorShape dst_shape = lhs_shape; dst_shape[0] = rhs_shape[0]; @@ -477,13 +491,31 @@ protected: // Create reference SimpleTensor<T> lhs{ lhs_shape, data_type, 1 }; SimpleTensor<T> rhs{ rhs_shape, data_type, 1 }; - SimpleTensor<T> c{ dst_shape, data_type, 1 }; + SimpleTensor<T> bias{ dst_shape, data_type, 1 }; + + const int n = rhs_shape[0]; + const int m = lhs_shape[1]; + const int batch_size = lhs_shape[2]; // Fill reference fill(lhs, 0); fill(rhs, 1); - return reference::gemm<T>(lhs, rhs, c, alpha, 0.0f); + if(broadcast_bias) + { + SimpleTensor<T> tmp{ bias_shape, data_type, 1 }; + fill(tmp, 2); + for(int i = 0; i < m * batch_size; i++) + { + memcpy(bias.data() + i * n, tmp.data(), n * sizeof(T)); + } + } + else + { + fill(bias, 2); + } + + return (reference::gemm<T>(lhs, rhs, bias, alpha, beta)); } TensorType _target{}; @@ -590,7 +622,7 @@ class GEMMMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framework::F public: template <typename...> void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0, - bool interleave_rhs, bool transpose_rhs, DataType data_type, float alpha) + bool interleave_rhs, bool transpose_rhs, DataType data_type, float alpha, float beta) { GEMMLHSMatrixInfo lhs_info; lhs_info.m0 = m0; @@ -609,9 +641,10 @@ public: // Set the tensor shapes for LHS and RHS matrices const TensorShape lhs_shape(k, m, batch_size); const TensorShape rhs_shape(n, k, batch_size); + const TensorShape bias_shape(n, 1, 1); - _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type, alpha, m_h); - _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, m_h); + _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, m_h); + _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, m_h); } protected: @@ -622,12 +655,14 @@ protected: library->fill(tensor, distribution, i); } - TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, DataType data_type, float alpha, + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + DataType data_type, float alpha, float beta, unsigned int m_h) { // Create tensors - TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1); - TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1); + TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1); + TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1); + TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1); TensorType rhs_reshaped; TensorType dst; @@ -641,7 +676,7 @@ protected: ReshapeRHSFunctionType reshape_rhs; GEMMFunctionType gemm; reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info); - gemm.configure(&lhs, &rhs_reshaped, &dst, alpha, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h)); + gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h, false, true)); ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -650,6 +685,7 @@ protected: lhs.allocator()->allocate(); rhs.allocator()->allocate(); rhs_reshaped.allocator()->allocate(); + bias.allocator()->allocate(); dst.allocator()->allocate(); ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -660,6 +696,7 @@ protected: // Fill tensors fill(AccessorType(lhs), 0); fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); // Compute GEMM reshape_rhs.run(); @@ -668,7 +705,7 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha, unsigned int m_h) + SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, unsigned int m_h) { TensorShape dst_shape = lhs_shape; dst_shape.set(0, rhs_shape[0]); @@ -679,13 +716,24 @@ protected: // Create reference SimpleTensor<T> lhs{ lhs_shape, data_type, 1 }; SimpleTensor<T> rhs{ rhs_shape, data_type, 1 }; - SimpleTensor<T> c{ dst_shape, data_type, 1 }; + SimpleTensor<T> bias{ dst_shape, data_type, 1 }; + + const int n = rhs_shape[0]; + const int m = lhs_shape[1]; + const int batch_size = lhs_shape[2]; // Fill reference fill(lhs, 0); fill(rhs, 1); - return reference::gemm<T>(lhs, rhs, c, alpha, 0.0f); + SimpleTensor<T> tmp{ bias_shape, data_type, 1 }; + fill(tmp, 2); + for(int i = 0; i < m * batch_size; i++) + { + memcpy(bias.data() + i * n, tmp.data(), n * sizeof(T)); + } + + return reference::gemm<T>(lhs, rhs, bias, alpha, beta); } TensorType _target{}; |