diff options
author | Manuel Bottini <manuel.bottini@arm.com> | 2021-03-23 11:50:34 +0000 |
---|---|---|
committer | Manuel Bottini <manuel.bottini@arm.com> | 2021-04-06 11:28:16 +0000 |
commit | ca62c6f53eb7244e6fed9f7e932608aa2496d9eb (patch) | |
tree | e5c7630c40d9f009e9baef4e849c6c7cc6ca90a7 /tests/validation/fixtures/ScaleFixture.h | |
parent | 4ed7b39dbbe8ccc6267a9eacefca51717c3b3e10 (diff) | |
download | ComputeLibrary-ca62c6f53eb7244e6fed9f7e932608aa2496d9eb.tar.gz |
Mixed data-layout testing on high priority operators
Change data layouts after the configure in validation tests for:
- Scale
- Pooling
- FullyConnected
- DepthwiseConvolution
- DirectConvolution
- FFTConvolution
- WinogradConvolution
- GEMMConvolution (Indirect GEMM included)
Extending fixtures
Fixes for new mixed data layout tests
Resolves: COMPMID-4162
Change-Id: I2f2eb2075f7e24ab3872249d88cadb57b82c5dde
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5326
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ScaleFixture.h')
-rw-r--r-- | tests/validation/fixtures/ScaleFixture.h | 42 |
1 files changed, 34 insertions, 8 deletions
diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h index dd521470e6..9e0f620abe 100644 --- a/tests/validation/fixtures/ScaleFixture.h +++ b/tests/validation/fixtures/ScaleFixture.h @@ -46,7 +46,7 @@ class ScaleValidationGenericFixture : public framework::Fixture public: template <typename...> void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, - bool align_corners) + bool align_corners, bool mixed_layout) { _shape = shape; _policy = policy; @@ -55,6 +55,7 @@ public: _data_type = data_type; _quantization_info = quantization_info; _align_corners = align_corners; + _mixed_layout = mixed_layout; generate_scale(shape); @@ -67,6 +68,22 @@ public: } protected: + + void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst) + { + const DataLayout data_layout = src.info()->data_layout(); + // Test Multi DataLayout graph cases, when the data layout changes after configure + src.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + dst.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + + // Compute Convolution function + layer.run(); + + // Reinstating original data layout for the test suite to properly check the values + src.info()->set_data_layout(data_layout); + dst.info()->set_data_layout(data_layout); + } + void generate_scale(const TensorShape &shape) { static constexpr float _min_scale{ 0.25f }; @@ -155,9 +172,15 @@ protected: // Fill tensors fill(AccessorType(src)); - // Compute function - scale.run(); - + if(_mixed_layout) + { + mix_layout(scale, src, dst); + } + else + { + // Compute function + scale.run(); + } return dst; } @@ -182,11 +205,12 @@ protected: DataType _data_type{}; QuantizationInfo _quantization_info{}; bool _align_corners{ false }; + bool _mixed_layout{ false }; float _scale_x{ 1.f }; float _scale_y{ 1.f }; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> class ScaleValidationQuantizedFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: @@ -201,10 +225,11 @@ public: policy, border_mode, sampling_policy, - align_corners); + align_corners, + mixed_layout); } }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> class ScaleValidationFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T> { public: @@ -218,7 +243,8 @@ public: policy, border_mode, sampling_policy, - align_corners); + align_corners, + mixed_layout); } }; } // namespace validation |