aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2019-01-11 11:30:55 +0000
committerGian Marco Iodice <gianmarco.iodice@arm.com>2019-01-15 15:06:12 +0000
commitbacfec5ecc3bd737c3d4eb2b0c165e0e55cc61f0 (patch)
tree12d115e3e158ac8bf434319fa9af44d75dc47785 /tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
parent55e167814d462a803dbac82db17603cbe1258b4f (diff)
downloadComputeLibrary-bacfec5ecc3bd737c3d4eb2b0c165e0e55cc61f0.tar.gz
COMPMID-1687: Optimize CLGEMMMatrixMultiplyKernel (part 1)
Extended CLGEMMMatrixMultiplyReshapedKernel to support more parameters Change-Id: I4a27f986e3fe2dd071a4ccba5cfa0565f3db39ad Reviewed-on: https://review.mlplatform.org/495 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp')
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp35
1 files changed, 16 insertions, 19 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index 1a41e459bd..564d3f4c2f 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -47,10 +47,10 @@ namespace validation
using namespace arm_compute::misc::shape_calculator;
// Create function for CLGEMMReshapeLHSMatrixKernel
-using CLGEMMReshapeLHSMatrix = CLSynthetizeFunctionInitOutputWithZeroAndWithZeroConstantBorder<CLGEMMReshapeLHSMatrixKernel, 16>;
+using CLGEMMReshapeLHSMatrix = CLSynthetizeFunction<CLGEMMReshapeLHSMatrixKernel>;
// Create function for CLGEMMReshapeRHSMatrixKernel
-using CLGEMMReshapeRHSMatrix = CLSynthetizeFunctionInitOutputWithZeroAndWithZeroConstantBorder<CLGEMMReshapeRHSMatrixKernel, 16>;
+using CLGEMMReshapeRHSMatrix = CLSynthetizeFunction<CLGEMMReshapeRHSMatrixKernel>;
// Create function for CLGEMMMatrixMultiplyReshapedKernel
using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction<CLGEMMMatrixMultiplyReshapedKernel>;
@@ -74,7 +74,7 @@ RelativeTolerance<half> rel_tolerance_f16(half(0.2));
constexpr float tolerance_num_f16 = 0.02f;
/** Alpha values to test - Precommit */
-const auto a_values_precommit = framework::dataset::make("alpha", {1.0f, -0.75f} );
+const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
/** M values to test */
const auto m_values = framework::dataset::make("M", 37);
@@ -89,7 +89,7 @@ const auto m_h_values = framework::dataset::make("M_H", 7);
const auto n_values = framework::dataset::make("N", 51);
/** K values to test */
-const auto k_values = framework::dataset::make("K", 43);
+const auto k_values = framework::dataset::make("K", 23);
/** Batch size values to test */
const auto b_values = framework::dataset::make("batch_size", 1, 3);
@@ -109,17 +109,14 @@ const auto v0_values_precommit = framework::dataset::make("V0", 1, 3);
/** H0 values to test - Precommit */
const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);
-/** Alpha values to test - Nightly */
-const auto a_values_nightly = framework::dataset::make("alpha", {1.0f, -0.75f, 0.85f} );
-
/** M0 values to test - Nightly */
-const auto m0_values_nightly = framework::dataset::make("M0", 2, 8);
+const auto m0_values_nightly = framework::dataset::make("M0", 2, 7);
/** N0 values to test - Nightly */
-const auto n0_values_nightly = framework::dataset::make("N0", { 2, 4, 8, 16 });
+const auto n0_values_nightly = framework::dataset::make("N0", { 2, 3, 4, 8 });
/** K0 values to test - Nightly */
-const auto k0_values_nightly = framework::dataset::make("K0", { 4, 8, 16 });
+const auto k0_values_nightly = framework::dataset::make("K0", { 2, 3, 4, 8 });
/** V0 values to test - Nightly */
const auto v0_values_nightly = framework::dataset::make("V0", 1, 4);
@@ -219,7 +216,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, fra
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values_precommit))
+ a_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -239,7 +236,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, fra
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values_nightly))
+ a_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -260,7 +257,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values_precommit))
+ a_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -281,7 +278,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values_nightly))
+ a_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -303,7 +300,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, fram
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values_precommit))
+ a_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
@@ -323,7 +320,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, fram
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values_nightly))
+ a_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
@@ -344,7 +341,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values_precommit))
+ a_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
@@ -365,7 +362,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values_nightly))
+ a_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);