aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ConvolutionLayerFixture.h
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2018-08-13 11:20:41 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit916d1bcee42051721a82cfb46b52855c2fe56646 (patch)
treee3e38a8deddc558cabeda6fb7d14b2d45c8db2c4 /tests/validation/fixtures/ConvolutionLayerFixture.h
parent61de78aba1b405663c6620be15418873a2ee914a (diff)
downloadComputeLibrary-916d1bcee42051721a82cfb46b52855c2fe56646.tar.gz
COMPMID-1498 - Enable grouping in CLGEMMConvolutionLayer
Change-Id: I15c7df21773145b03f42b6f78bd7ad2e5b8a5219 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144126 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ConvolutionLayerFixture.h17
1 files changed, 13 insertions, 4 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 4a6326480c..3b420eac09 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -102,6 +102,10 @@ protected:
TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const PadStrideInfo &info,
bool reshape_weights, const Size2D &dilation, const ActivationLayerInfo act_info)
{
+ ARM_COMPUTE_ERROR_ON((input_shape[2] % weights_shape[2]) != 0);
+
+ const unsigned int num_groups = input_shape[2] / weights_shape[2];
+
if(_data_layout == DataLayout::NHWC)
{
permute(input_shape, PermutationVector(2U, 0U, 1U));
@@ -123,7 +127,7 @@ protected:
// Create and configure function
FunctionType conv;
- conv.configure(&src, &weights, &bias, &dst, info, weights_info, dilation, act_info);
+ conv.configure(&src, &weights, &bias, &dst, info, weights_info, dilation, act_info, num_groups);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -155,6 +159,10 @@ protected:
SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
const Size2D &dilation, const ActivationLayerInfo act_info)
{
+ ARM_COMPUTE_ERROR_ON((input_shape[2] % weights_shape[2]) != 0);
+
+ const unsigned int num_groups = input_shape[2] / weights_shape[2];
+
// Create reference
SimpleTensor<T> src{ input_shape, _data_type, 1, _quantization_info };
SimpleTensor<T> weights{ weights_shape, _data_type, 1, _quantization_info };
@@ -165,9 +173,9 @@ protected:
fill(weights, 1);
fill(bias, 2);
- return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation),
+ return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups),
act_info) :
- reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation);
+ reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups);
}
TensorType _target{};
@@ -187,7 +195,8 @@ public:
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
DataLayout data_layout, ActivationLayerInfo act_info)
{
- ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type, data_layout,
+ ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights,
+ data_type, data_layout,
QuantizationInfo(), act_info);
}
};