author    Manuel Bottini <manuel.bottini@arm.com>    2021-03-23 11:50:34 +0000
committer Manuel Bottini <manuel.bottini@arm.com>    2021-04-06 11:28:16 +0000
commit    ca62c6f53eb7244e6fed9f7e932608aa2496d9eb (patch)
tree      e5c7630c40d9f009e9baef4e849c6c7cc6ca90a7 /tests/validation/fixtures/FullyConnectedLayerFixture.h
parent    4ed7b39dbbe8ccc6267a9eacefca51717c3b3e10 (diff)
Mixed data-layout testing on high-priority operators
Change data layouts after the configure step in the validation tests for:
- Scale
- Pooling
- FullyConnected
- DepthwiseConvolution
- DirectConvolution
- FFTConvolution
- WinogradConvolution
- GEMMConvolution (Indirect GEMM included)
Extend the fixtures accordingly and add fixes for the new mixed data
layout tests (the core layout-toggling idea is sketched below, after the
commit trailers).
Resolves: COMPMID-4162
Change-Id: I2f2eb2075f7e24ab3872249d88cadb57b82c5dde
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5326
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
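The layout-toggling idea behind this change can be shown in isolation: after a function has been configured, flip the layout that the tensors report, run the function, and then restore the original layout so the reference comparison still sees what it expects. Below is a minimal standalone sketch with stand-in types (plain C++; these are not the arm_compute classes, whose real counterparts appear in the diff that follows):

#include <cassert>
#include <iostream>

// Stand-ins for arm_compute types, for illustration only.
enum class DataLayout { NCHW, NHWC };

struct TensorInfo
{
    DataLayout layout{DataLayout::NCHW};
    DataLayout data_layout() const { return layout; }
    void set_data_layout(DataLayout dl) { layout = dl; }
};

struct Tensor
{
    TensorInfo meta;
    TensorInfo *info() { return &meta; }
};

struct Layer
{
    // A real layer caches its layout decisions at configure() time, so run()
    // must not depend on the layout the tensors report afterwards.
    void run() { std::cout << "layer executed\n"; }
};

// Mirrors the fixture's mix_layout(): flip layouts, run, restore.
void mix_layout(Layer &layer, Tensor &src, Tensor &dst)
{
    const DataLayout original = src.info()->data_layout();
    const DataLayout flipped  = (original == DataLayout::NCHW) ? DataLayout::NHWC : DataLayout::NCHW;

    // Simulate a graph whose data layout changes after configure().
    src.info()->set_data_layout(flipped);
    dst.info()->set_data_layout(flipped);

    layer.run();

    // Restore so the validation code compares against the reference correctly.
    src.info()->set_data_layout(original);
    dst.info()->set_data_layout(original);
    assert(src.info()->data_layout() == original);
}

int main()
{
    Tensor src, dst;
    Layer fc;
    mix_layout(fc, src, dst);
}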
Diffstat (limited to 'tests/validation/fixtures/FullyConnectedLayerFixture.h')
-rw-r--r--   tests/validation/fixtures/FullyConnectedLayerFixture.h   39
1 file changed, 32 insertions(+), 7 deletions(-)
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index 3760cfb8b7..8f38aae187 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -56,11 +56,12 @@ public:
 public:
     template <typename...>
     void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights,
-               DataType data_type, QuantizationInfo quantization_info, ActivationLayerInfo activation_info)
+               DataType data_type, QuantizationInfo quantization_info, ActivationLayerInfo activation_info, bool mixed_layout = false)
     {
         ARM_COMPUTE_UNUSED(weights_shape);
         ARM_COMPUTE_UNUSED(bias_shape);
 
+        _mixed_layout      = mixed_layout;
         _data_type         = data_type;
         _bias_data_type    = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
         _quantization_info = quantization_info;
@@ -71,6 +72,22 @@ public:
     }
 
 protected:
+
+    void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst)
+    {
+        const DataLayout data_layout = src.info()->data_layout();
+        // Test Multi DataLayout graph cases, when the data layout changes after configure
+        src.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW);
+        dst.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW);
+
+        // Compute Convolution function
+        layer.run();
+
+        // Reinstating original data layout for the test suite to properly check the values
+        src.info()->set_data_layout(data_layout);
+        dst.info()->set_data_layout(data_layout);
+    }
+
     template <typename U>
     void fill(U &&tensor, int i)
     {
@@ -189,8 +206,15 @@ protected:
             fill(AccessorType(weights), 1);
         }
 
-        // Compute NEFullyConnectedLayer function
-        fc.run();
+        if(_mixed_layout)
+        {
+            mix_layout(fc, src, dst);
+        }
+        else
+        {
+            // Compute NEFullyConnectedLayer function
+            fc.run();
+        }
 
         return dst;
     }
@@ -214,11 +238,12 @@ protected:
     SimpleTensor<T>     _reference{};
     DataType            _data_type{};
     DataType            _bias_data_type{};
+    bool                _mixed_layout{false};
     QuantizationInfo    _quantization_info{};
     ActivationLayerInfo _activation_info{};
 };
 
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
 class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
@@ -228,11 +253,11 @@ public:
     {
         FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type,
-                                                                                                      QuantizationInfo(), activation_info);
+                                                                                                      QuantizationInfo(), activation_info, mixed_layout);
     }
 };
 
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
 class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
@@ -242,7 +267,7 @@ public:
    {
         FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type,
-                                                                                                      quantization_info, activation_info);
+                                                                                                      quantization_info, activation_info, mixed_layout);
     }
 };
 } // namespace validation
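For context on how a suite opts in: selection happens at compile time through the new trailing mixed_layout template parameter, so existing instantiations keep the default of false and compile unchanged. A hedged sketch of a Neon-side alias (the alias name is illustrative, not from this patch; Tensor, Accessor and NEFullyConnectedLayer are the usual Neon test types):

// Hypothetical alias: instantiating the fixture with mixed_layout = true makes
// the generic fixture call mix_layout() instead of a plain fc.run().
template <typename T>
using NEFullyConnectedLayerMixedDataLayoutFixture =
    FullyConnectedLayerValidationFixture<Tensor, Accessor, NEFullyConnectedLayer, T, true>;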