diff options
Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/ConvolutionLayerFixture.h | 38 |
1 files changed, 31 insertions, 7 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index a4db49fc8e..07790e84d9 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -69,8 +69,9 @@ public: public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, - DataType data_type, DataType weights_data_type, DataLayout data_layout, QuantizationInfo quantization_info, QuantizationInfo weight_quantization_info, ActivationLayerInfo act_info) + DataType data_type, DataType weights_data_type, DataLayout data_layout, QuantizationInfo quantization_info, QuantizationInfo weight_quantization_info, ActivationLayerInfo act_info, bool mixed_layout = false) { + _mixed_layout = mixed_layout; _data_type = data_type; _weights_data_type = weights_data_type; _is_quantized = is_data_type_quantized_asymmetric(data_type); @@ -86,6 +87,21 @@ public: } protected: + + void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst) + { + // Test Multi DataLayout graph cases, when the data layout changes after configure + src.info()->set_data_layout(_data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + dst.info()->set_data_layout(_data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); + + // Compute Convolution function + layer.run(); + + // Reinstating original data layout for the test suite to properly check the values + src.info()->set_data_layout(_data_layout); + dst.info()->set_data_layout(_data_layout); + } + void regularize_values(void *values, size_t size) { float *fvalues = static_cast<float *>(values); @@ -214,8 +230,15 @@ protected: fill(AccessorType(weights), 1); fill(AccessorType(bias), 2); - // Compute NEConvolutionLayer function - conv.run(); + if(_mixed_layout) + { + mix_layout(conv, src, dst); + } + else + { + // Compute Convolution function + conv.run(); + } return dst; } @@ -264,9 +287,10 @@ protected: QuantizationInfo _weight_quantization_info{}; bool _is_quantized = false; bool _is_bfloat16 = false; + bool _mixed_layout = false; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> class ConvolutionValidationFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T> { public: @@ -276,11 +300,11 @@ public: { ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type, data_type, data_layout, - QuantizationInfo(), QuantizationInfo(), act_info); + QuantizationInfo(), QuantizationInfo(), act_info, mixed_layout); } }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T> { public: @@ -289,7 +313,7 @@ public: DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info) { ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, - data_type, data_type, data_layout, quantization_info, quantization_info, act_info); + data_type, data_type, data_layout, quantization_info, quantization_info, act_info, mixed_layout); } }; |