From beabe3bdf47306d0940ddf2ddf52ada6903a0875 Mon Sep 17 00:00:00 2001 From: Moritz Pflanzer Date: Thu, 31 Aug 2017 14:56:32 +0100 Subject: COMPMID-481: Add AArch64 GEMM Change-Id: I34f94f99cb05f0eabafee13c5e623ee779b72360 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/83741 Tested-by: Kaizen Reviewed-by: Anthony Barbier Reviewed-by: Pablo Tello --- tests/validation/fixtures/ConvolutionLayerFixture.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h') diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index dd2df727e9..fcaf4ef42b 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -26,6 +26,7 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" +#include "arm_compute/runtime/NEON/NEScheduler.h" #include "tests/AssetsLibrary.h" #include "tests/Globals.h" #include "tests/IAccessor.h" @@ -39,6 +40,8 @@ namespace arm_compute { +class NEConvolutionLayer; + namespace test { namespace validation @@ -85,6 +88,8 @@ protected: { // Check if its a "fully connected" convolution const bool is_fully_connected_convolution = (output_shape.x() == 1 && output_shape.y() == 1); + const bool is_optimised = std::is_same::value && NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && data_type == DataType::F32; + reshaped_weights_shape.collapse(3); if(bias_shape.total_size() > 0) @@ -92,7 +97,7 @@ protected: reshaped_weights_shape.set(0, reshaped_weights_shape.x() + 1); } - if(is_fully_connected_convolution) + if(is_fully_connected_convolution || is_optimised) { const size_t shape_x = reshaped_weights_shape.x(); reshaped_weights_shape.set(0, reshaped_weights_shape.y()); @@ -138,6 +143,7 @@ protected: if(!reshape_weights) { const bool is_fully_connected_convolution = (output_shape.x() == 1 && output_shape.y() == 1); + const bool is_optimised = std::is_same::value && NEScheduler::get().cpu_info().CPU >= CPUTarget::ARMV8 && data_type == DataType::F32; TensorShape tmp_weights_shape(weights_shape); SimpleTensor tmp_weights(tmp_weights_shape, data_type, 1, fixed_point_position); @@ -149,7 +155,7 @@ protected: tmp_weights = linearise_weights(tmp_weights, &tmp_bias); - if(!is_fully_connected_convolution) + if(!is_fully_connected_convolution && !is_optimised) { // Transpose with interleave const int interleave_size = 16 / tmp_weights.element_size(); -- cgit v1.2.1