diff options
Diffstat (limited to 'tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h | 47 |
1 files changed, 39 insertions, 8 deletions
diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h index d9806b5c84..0aa43d82b4 100644 --- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h @@ -59,8 +59,9 @@ public: void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation, unsigned int depth_multiplier, DataType input_data_type, DataType weights_data_type, QuantizationInfo input_quantization_info, QuantizationInfo weights_quantization_info, QuantizationInfo output_quantization_info, - DataLayout data_layout, ActivationLayerInfo act_info) + DataLayout data_layout, ActivationLayerInfo act_info, bool mixed_layout = false) { + _mixed_layout = mixed_layout; _input_shape = in_shape; _input_data_type = input_data_type; _weights_data_type = weights_data_type; @@ -130,9 +131,16 @@ public: fill(AccessorType(_src), 0); fill(AccessorType(_weights), 1); fill(AccessorType(_biases), 2); - - // Compute function - _dwc.run(); + + if(_mixed_layout) + { + mix_layout(_dwc, _src, _target); + } + else + { + // Compute function + _dwc.run(); + } } void compute_reference() @@ -150,6 +158,21 @@ public: } protected: + + void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst) + { + // Test Multi DataLayout graph cases, when the data layout changes after configure + src.info()->set_data_layout(_data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + dst.info()->set_data_layout(_data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + + // Compute Convolution function + layer.run(); + + // Reinstating original data layout for the test suite to properly check the values + src.info()->set_data_layout(_data_layout); + dst.info()->set_data_layout(_data_layout); + } + template <typename U> void fill(U &&tensor, int i) { @@ -214,9 +237,10 @@ protected: ActivationLayerInfo _act_info{}; unsigned int _depth_multiplier{}; Size2D _dilation{}; + bool _mixed_layout{false}; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> class DepthwiseConvolutionLayerValidationFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T> { public: @@ -226,7 +250,7 @@ public: { DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(in_shape, kernel_size, pad_stride_info, dilation, depth_multiplier, data_type, data_type, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), - data_layout, act_info); + data_layout, act_info, mixed_layout); } }; @@ -434,8 +458,15 @@ public: fill(AccessorType(_weights), 1); fill(AccessorType(_biases), 2); + // Test Multi DataLayout graph cases, when the data layout changes after configure + _src.info()->set_data_layout(_data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + _target.info()->set_data_layout(_data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + // Compute function _dwc.run(); + + // Reinstating original data layout for the test suite to properly check the values + _target.info()->set_data_layout(_data_layout); } void compute_reference() @@ -496,7 +527,7 @@ protected: unsigned int _n0{}; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> class DepthwiseConvolutionLayerValidationQuantizedFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T> { public: @@ -506,7 +537,7 @@ public: { DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(in_shape, kernel_size, pad_stride_info, dilation, depth_multiplier, data_type, data_type, input_quantization_info, input_quantization_info, output_quantization_info, - data_layout, act_info); + data_layout, act_info, mixed_layout); } }; |