From 916d1bcee42051721a82cfb46b52855c2fe56646 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 13 Aug 2018 11:20:41 +0100 Subject: COMPMID-1498 - Enable grouping in CLGEMMConvolutionLayer Change-Id: I15c7df21773145b03f42b6f78bd7ad2e5b8a5219 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144126 Tested-by: Jenkins Reviewed-by: Giorgio Arena Reviewed-by: Georgios Pinitas --- tests/validation/fixtures/ConvolutionLayerFixture.h | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h') diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index 4a6326480c..3b420eac09 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -102,6 +102,10 @@ protected: TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const PadStrideInfo &info, bool reshape_weights, const Size2D &dilation, const ActivationLayerInfo act_info) { + ARM_COMPUTE_ERROR_ON((input_shape[2] % weights_shape[2]) != 0); + + const unsigned int num_groups = input_shape[2] / weights_shape[2]; + if(_data_layout == DataLayout::NHWC) { permute(input_shape, PermutationVector(2U, 0U, 1U)); @@ -123,7 +127,7 @@ protected: // Create and configure function FunctionType conv; - conv.configure(&src, &weights, &bias, &dst, info, weights_info, dilation, act_info); + conv.configure(&src, &weights, &bias, &dst, info, weights_info, dilation, act_info, num_groups); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -155,6 +159,10 @@ protected: SimpleTensor compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info, const Size2D &dilation, const ActivationLayerInfo act_info) { + ARM_COMPUTE_ERROR_ON((input_shape[2] % weights_shape[2]) != 0); + + const unsigned int num_groups = input_shape[2] / weights_shape[2]; + // Create reference SimpleTensor src{ input_shape, _data_type, 1, _quantization_info }; SimpleTensor weights{ weights_shape, _data_type, 1, _quantization_info }; @@ -165,9 +173,9 @@ protected: fill(weights, 1); fill(bias, 2); - return (act_info.enabled()) ? reference::activation_layer(reference::convolution_layer(src, weights, bias, output_shape, info, dilation), + return (act_info.enabled()) ? reference::activation_layer(reference::convolution_layer(src, weights, bias, output_shape, info, dilation, num_groups), act_info) : - reference::convolution_layer(src, weights, bias, output_shape, info, dilation); + reference::convolution_layer(src, weights, bias, output_shape, info, dilation, num_groups); } TensorType _target{}; @@ -187,7 +195,8 @@ public: void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type, DataLayout data_layout, ActivationLayerInfo act_info) { - ConvolutionValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type, data_layout, + ConvolutionValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, + data_type, data_layout, QuantizationInfo(), act_info); } }; -- cgit v1.2.1