diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2018-08-13 11:20:41 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:54 +0000 |
commit | 916d1bcee42051721a82cfb46b52855c2fe56646 (patch) | |
tree | e3e38a8deddc558cabeda6fb7d14b2d45c8db2c4 /tests/validation/fixtures/ConvolutionLayerFixture.h | |
parent | 61de78aba1b405663c6620be15418873a2ee914a (diff) | |
download | ComputeLibrary-916d1bcee42051721a82cfb46b52855c2fe56646.tar.gz |
COMPMID-1498 - Enable grouping in CLGEMMConvolutionLayer
Change-Id: I15c7df21773145b03f42b6f78bd7ad2e5b8a5219
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144126
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/ConvolutionLayerFixture.h | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index 4a6326480c..3b420eac09 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -102,6 +102,10 @@ protected: TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const PadStrideInfo &info, bool reshape_weights, const Size2D &dilation, const ActivationLayerInfo act_info) { + ARM_COMPUTE_ERROR_ON((input_shape[2] % weights_shape[2]) != 0); + + const unsigned int num_groups = input_shape[2] / weights_shape[2]; + if(_data_layout == DataLayout::NHWC) { permute(input_shape, PermutationVector(2U, 0U, 1U)); @@ -123,7 +127,7 @@ protected: // Create and configure function FunctionType conv; - conv.configure(&src, &weights, &bias, &dst, info, weights_info, dilation, act_info); + conv.configure(&src, &weights, &bias, &dst, info, weights_info, dilation, act_info, num_groups); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -155,6 +159,10 @@ protected: SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info, const Size2D &dilation, const ActivationLayerInfo act_info) { + ARM_COMPUTE_ERROR_ON((input_shape[2] % weights_shape[2]) != 0); + + const unsigned int num_groups = input_shape[2] / weights_shape[2]; + // Create reference SimpleTensor<T> src{ input_shape, _data_type, 1, _quantization_info }; SimpleTensor<T> weights{ weights_shape, _data_type, 1, _quantization_info }; @@ -165,9 +173,9 @@ protected: fill(weights, 1); fill(bias, 2); - return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation), + return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups), act_info) : - reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation); + reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups); } TensorType _target{}; @@ -187,7 +195,8 @@ public: void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type, DataLayout data_layout, ActivationLayerInfo act_info) { - ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type, data_layout, + ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, + data_type, data_layout, QuantizationInfo(), act_info); } }; |