author    Gian Marco Iodice <gianmarco.iodice@arm.com>    2019-07-29 14:27:16 +0100
committer Gian Marco Iodice <gianmarco.iodice@arm.com>    2019-08-01 09:25:15 +0000
commit    f3622becf1f0d6bf5147ebb7d6d0f14d5252860a (patch)
tree      60c5a1de2d24f9828a7896d200759150b0f5eb42 /tests/validation
parent    c4d5136707280d98f660a67219114f5ee5b10fb8 (diff)
download  ComputeLibrary-f3622becf1f0d6bf5147ebb7d6d0f14d5252860a.tar.gz
COMPMID-1979: Fuse Activation Function in CLGEMM - part 4
Fused activation function in CLGEMM

Change-Id: I644fdf09349325c0b3a2cd5fef2a3ea2c974149d
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1640
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
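For context, this series routes the activation to be fused into GEMMInfo rather than requiring a separate CLActivationLayer call. The sketch below shows how a caller might request a fused bounded ReLU; the trailing activation argument of GEMMInfo and its exact position are assumptions based on this patch series, not something the diff below confirms, and tensor allocation / CL scheduler setup are omitted.

// Hypothetical usage sketch, not part of this patch.
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLGEMM.h"

using namespace arm_compute;

void gemm_with_fused_relu(CLTensor &a, CLTensor &b, CLTensor &bias, CLTensor &dst)
{
    // Assumed GEMMInfo parameter order: is_a_reshaped, is_b_reshaped, reshape_b_only_on_first_run,
    // depth_output_gemm3d, reinterpret_input_as_3d, retain_internal_weights, gemmlowp_output_stage,
    // fp_mixed_precision, broadcast_bias, activation to fuse.
    GEMMInfo info(false, false, false, 0, false, false, GEMMLowpOutputStageInfo(), false, true,
                  ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f));

    CLGEMM gemm;
    gemm.configure(&a, &b, &bias, &dst, 1.f, 1.f, info); // dst = act(1 * A * B + 1 * bias)
    gemm.run();
}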
Diffstat (limited to 'tests/validation')
-rw-r--r--  tests/validation/CL/GEMMMatrixMultiply.cpp                       |  2
-rw-r--r--  tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp  |  2
-rw-r--r--  tests/validation/fixtures/GEMMFixture.h                          | 38
3 files changed, 28 insertions, 14 deletions
diff --git a/tests/validation/CL/GEMMMatrixMultiply.cpp b/tests/validation/CL/GEMMMatrixMultiply.cpp
index 21fd7125ec..8f7c0aaef1 100644
--- a/tests/validation/CL/GEMMMatrixMultiply.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiply.cpp
@@ -67,7 +67,7 @@ RelativeTolerance<half> rel_tolerance_f16(half(0.2));
constexpr float tolerance_num_f16 = 0.02f;
/** Alpha values to test - Precommit */
-const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} );
+const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
/** Beta values to test - Precommit */
const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} );
diff --git a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
index cae94b2e15..5d21cf4f34 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
@@ -77,7 +77,7 @@ RelativeTolerance<half> rel_tolerance_f16(half(0.2));
constexpr float tolerance_num_f16 = 0.02f;
/** Alpha values to test - Precommit */
-const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} );
+const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
/** Beta values to test - Precommit */
const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} );
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index b36bb99246..a04a901b1c 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -44,7 +44,7 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_ouput_as_3d = false>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false>
class GEMMValidationFixture : public framework::Fixture
{
public:
@@ -87,7 +87,13 @@ protected:
// The GEMMinfo includes the values of the depth in case of reinterpreted 3d output.
// If the output shape has the same number of dimensions of the input the method called is a 2D matrix multiplication (depth_output_reinterpreted_as_3D = 0),
// in the other case we have to use the reinterpreted version of GEMM (depth_output_reinterpreted_as_3D = depth of the 3D output).
- gemm.configure(&a, &b, (disable_c) ? nullptr : &c, &dst, alpha, beta, GEMMInfo(false, false, false, (reinterpret_ouput_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d));
+ gemm.configure(&a,
+ &b,
+ (disable_c) ? nullptr : &c,
+ &dst,
+ alpha, beta,
+ GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d, false, GEMMLowpOutputStageInfo(), false, (reinterpret_input_as_3d
+ || reinterpret_output_as_3d)));
ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -122,6 +128,7 @@ protected:
DataType data_type)
{
TensorShape shape_a_to_use = shape_a;
+
if(reinterpret_input_as_3d)
{
// Collapse the second and third dimension if the input is 3D
@@ -131,22 +138,29 @@ protected:
// Create reference
SimpleTensor<T> a{ shape_a_to_use, data_type, 1 };
SimpleTensor<T> b{ shape_b, data_type, 1 };
- SimpleTensor<T> c{ shape_c, data_type, 1 };
+ SimpleTensor<T> c{ output_shape, data_type, 1 };
// Fill reference
fill(a, 0);
fill(b, 1);
- if(!disable_c)
- {
- fill(c, 2);
- return reference::gemm<T>(a, b, c, alpha, beta);
- }
- else
+ fill(c, 2);
+
+ if(reinterpret_input_as_3d || reinterpret_output_as_3d)
{
- // Setting beta to 0 will effectively disable C for the
- // computation of the reference: alpha * A * B + 0 * C
- return reference::gemm<T>(a, b, c, alpha, 0.f);
+ const int n = shape_b[0];
+ const int m = reinterpret_output_as_3d ? output_shape[1] * output_shape[2] : output_shape[1];
+ const int batch_size = reinterpret_output_as_3d ? output_shape[3] : output_shape[2];
+
+ // In case of broadcast, we need simply copy the first into the following "M" ones
+ for(int i = 1; i < m * batch_size; i++)
+ {
+ memcpy(c.data() + i * n, c.data(), n * sizeof(T));
+ }
}
+
+ // Setting beta to 0 will effectively disable C for the
+ // computation of the reference: alpha * A * B + 0 * C
+ return reference::gemm<T>(a, b, c, alpha, disable_c ? 0.f : beta);
}
TensorType _target{};
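The reference path above now fills a C tensor with the full output shape and, in the 3D-reinterpreted cases, replicates its first row so that reference::gemm computes alpha * A * B + beta * C with the same bias vector on every output row (with beta forced to 0 when C is disabled). A self-contained sketch of that row broadcast, outside any ComputeLibrary types and purely for illustration:

// Standalone illustration of the fixture's memcpy-based bias broadcast (hypothetical helper).
#include <cstring>
#include <vector>

template <typename T>
void broadcast_bias_rows(std::vector<T> &c, int n, int m, int batch_size)
{
    // Row 0 holds the N-element bias vector; copy it into the remaining m * batch_size - 1 rows.
    for(int i = 1; i < m * batch_size; ++i)
    {
        std::memcpy(c.data() + i * n, c.data(), n * sizeof(T));
    }
}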