From 68a3f56627b04acdefebe67d645727dd83889766 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 26 Jul 2018 11:44:03 +0100 Subject: COMPMID-1276 - Allow GEMM to work with 3D input tensor Skipped im2col in CLGEMMConvolutionLayer for 1x1 convolutions with NHWC data layout Change-Id: I894e6b952ed8605e8f3ffc0ffc25c24730d4664c Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141909 Tested-by: Jenkins Reviewed-by: Anthony Barbier Reviewed-by: Georgios Pinitas --- tests/validation/fixtures/GEMMFixture.h | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) (limited to 'tests/validation/fixtures') diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index e4762cc5be..255b12c0ed 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -42,8 +42,8 @@ namespace test { namespace validation { -template -class GEMMValidationFixedPointFixture : public framework::Fixture +template +class GEMMValidationFixture : public framework::Fixture { public: template @@ -87,10 +87,7 @@ protected: // The GEMMinfo includes the values of the depth in case of reinterpreted 3d output. // If the output shape has the same number of dimensions of the input the method called is a 2D matrix multiplication (depth_output_reinterpreted_as_3D = 1), // in the other case we have to use the reinterpreted version of GEMM (depth_output_reinterpreted_as_3D = depth of the 3D output). - bool is_output_reinterpreted_as_3D = output_shape.num_dimensions() > shape_a.num_dimensions(); - gemm.configure(&a, &b, &c, &dst, alpha, beta, - GEMMInfo(false, false, false, is_output_reinterpreted_as_3D ? output_shape[2] : 1)); - + gemm.configure(&a, &b, &c, &dst, alpha, beta, GEMMInfo(false, false, false, (reinterpret_ouput_as_3d ? output_shape[2] : 1), reinterpret_input_as_3d)); ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -121,8 +118,15 @@ protected: SimpleTensor compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &output_shape, float alpha, float beta, DataType data_type) { + TensorShape shape_a_to_use = shape_a; + if(reinterpret_input_as_3d) + { + // Collapse the second and third dimension if the input is 3D + shape_a_to_use.collapse(2U, 1U); + } + // Create reference - SimpleTensor a{ shape_a, data_type, 1 }; + SimpleTensor a{ shape_a_to_use, data_type, 1 }; SimpleTensor b{ shape_b, data_type, 1 }; SimpleTensor c{ shape_c, data_type, 1 }; @@ -139,16 +143,6 @@ protected: DataType _data_type{}; }; -template -class GEMMValidationFixture : public GEMMValidationFixedPointFixture -{ -public: - template - void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, DataType data_type) - { - GEMMValidationFixedPointFixture::setup(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type); - } -}; } // namespace validation } // namespace test } // namespace arm_compute -- cgit v1.2.1