From ca62c6f53eb7244e6fed9f7e932608aa2496d9eb Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Tue, 23 Mar 2021 11:50:34 +0000 Subject: Mixed data-layout testing on high priority operators Change data layouts after the configure in validation tests for: - Scale - Pooling - FullyConnected - DepthwiseConvolution - DirectConvolution - FFTConvolution - WinogradConvolution - GEMMConvolution (Indirect GEMM included) Extending fixtures Fixes for new mixed data layout tests Resolves: COMPMID-4162 Change-Id: I2f2eb2075f7e24ab3872249d88cadb57b82c5dde Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5326 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas --- tests/validation/fixtures/PoolingLayerFixture.h | 40 ++++++++++++++++++++----- 1 file changed, 32 insertions(+), 8 deletions(-) (limited to 'tests/validation/fixtures/PoolingLayerFixture.h') diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h index af078d4ce3..ee81ff5538 100644 --- a/tests/validation/fixtures/PoolingLayerFixture.h +++ b/tests/validation/fixtures/PoolingLayerFixture.h @@ -47,14 +47,31 @@ class PoolingLayerValidationGenericFixture : public framework::Fixture public: template void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout, bool indices = false, - QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo()) + QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo(), bool mixed_layout = false) { + _mixed_layout = mixed_layout; _pool_info = pool_info; _target = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices); _reference = compute_reference(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices); } protected: + + void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst) + { + const DataLayout data_layout = src.info()->data_layout(); + // Test Multi DataLayout graph cases, when the data layout changes after configure + src.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + dst.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + + // Compute Convolution function + layer.run(); + + // Reinstating original data layout for the test suite to properly check the values + src.info()->set_data_layout(data_layout); + dst.info()->set_data_layout(data_layout); + } + template void fill(U &&tensor) { @@ -110,9 +127,15 @@ protected: // Fill tensors fill(AccessorType(src)); - // Compute function - pool_layer.run(); - + if(_mixed_layout) + { + mix_layout(pool_layer, src, dst); + } + else + { + // Compute function + pool_layer.run(); + } return dst; } @@ -129,6 +152,7 @@ protected: TensorType _target{}; SimpleTensor _reference{}; PoolingLayerInfo _pool_info{}; + bool _mixed_layout{false}; TensorType _target_indices{}; SimpleTensor _ref_indices{}; }; @@ -144,7 +168,7 @@ public: } }; -template +template class PoolingLayerValidationFixture : public PoolingLayerValidationGenericFixture { public: @@ -152,7 +176,7 @@ public: void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout) { PoolingLayerValidationGenericFixture::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding), - data_type, data_layout); + data_type, data_layout, false, mixed_layout); } }; @@ -168,7 +192,7 @@ public: } }; -template +template class PoolingLayerValidationQuantizedFixture : public PoolingLayerValidationGenericFixture { public: @@ -177,7 +201,7 @@ public: QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo()) { PoolingLayerValidationGenericFixture::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding), - data_type, data_layout, false, input_qinfo, output_qinfo); + data_type, data_layout, false, input_qinfo, output_qinfo, mixed_layout); } }; -- cgit v1.2.1