Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp')
 tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp | 27 ++++++++++++---------------
 1 file changed, 12 insertions(+), 15 deletions(-)
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index 23ae004912..133170e2d3 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -123,7 +123,7 @@ const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, fal
/** Transpose values to test with RHS matrix */
const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true, false });
-/**Broadcast bias from vector to matrix */
+/** Broadcast bias from vector to matrix */
const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", {false, true} );
/** Configuration test */
@@ -155,18 +155,15 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned
TensorInfo(rhs_shape_reshaped, 1, data_type),
gemm_info);
+ const TensorShape bias_shape(N,
+ broadcast_bias? 1 : M,
+ broadcast_bias? 1 : b_value);
+
// Create tensors
CLTensor lhs = create_tensor<CLTensor>(lhs_shape, data_type);
CLTensor rhs_reshaped = create_tensor<CLTensor>(rhs_shape_reshaped, data_type);
- CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type);
-
- TensorShape bias_shape = dst_shape;
- if (broadcast_bias)
- {
- bias_shape[1] = 1;
- bias_shape[2] = 1;
- }
CLTensor bias = create_tensor<CLTensor>(bias_shape, data_type);
+ CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type);
ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -257,7 +254,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -278,7 +275,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -300,7 +297,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values),
+ beta_values),
broadcast_bias_values))
{
// Validate output
@@ -321,7 +318,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values),
+ beta_values),
broadcast_bias_values))
{
// Validate output
@@ -343,7 +340,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
@@ -364,7 +361,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);