aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2019-06-14 16:11:10 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-20 16:02:39 +0000
commite16c8906a2aedf00e910754a01fca8bc4189cfc7 (patch)
treede9b88917bb00a76a9df68c9e92f05e38c5de817 /tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
parent0cbfda629dd8f684e625173341bab972f004222c (diff)
downloadComputeLibrary-e16c8906a2aedf00e910754a01fca8bc4189cfc7.tar.gz
COMPMID-2053: Fuse bias addition with CLGEMMMatrixMultiplyReshapedKernel
Change-Id: I5bfd38c94a6fd18a1cba2104f7e1b04e7bef6ec2 Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/1359 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp')
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp27
1 files changed, 12 insertions, 15 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index 23ae004912..133170e2d3 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -123,7 +123,7 @@ const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, fal
/** Transpose values to test with RHS matrix */
const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true, false });
-/**Broadcast bias from vector to matrix */
+/** Broadcast bias from vector to matrix */
const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", {false, true} );
/** Configuration test */
@@ -155,18 +155,15 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned
TensorInfo(rhs_shape_reshaped, 1, data_type),
gemm_info);
+ const TensorShape bias_shape(N,
+ broadcast_bias? 1 : M,
+ broadcast_bias? 1 : b_value);
+
// Create tensors
CLTensor lhs = create_tensor<CLTensor>(lhs_shape, data_type);
CLTensor rhs_reshaped = create_tensor<CLTensor>(rhs_shape_reshaped, data_type);
- CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type);
-
- TensorShape bias_shape = dst_shape;
- if (broadcast_bias)
- {
- bias_shape[1] = 1;
- bias_shape[2] = 1;
- }
CLTensor bias = create_tensor<CLTensor>(bias_shape, data_type);
+ CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type);
ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -257,7 +254,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -278,7 +275,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -300,7 +297,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values),
+ beta_values),
broadcast_bias_values))
{
// Validate output
@@ -321,7 +318,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values),
+ beta_values),
broadcast_bias_values))
{
// Validate output
@@ -343,7 +340,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
@@ -364,7 +361,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);