From 1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Mon, 26 Mar 2018 16:20:05 +0100 Subject: COMPMID-812 Add NHWC data format support for NEON depthwise convolution (optimized case). Change-Id: Icdfd6c02ed526daf4f59a4b76c7bbc1bc48fde74 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125938 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- .../fixtures/DepthwiseConvolutionLayerFixture.h | 37 +++++++++++++--------- 1 file changed, 22 insertions(+), 15 deletions(-) (limited to 'tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h') diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h index df5436fcf7..ccdd443999 100644 --- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h @@ -52,15 +52,22 @@ public: public: template - void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info) + void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout) { _quantization_info = quantization_info; _data_type = data_type; const TensorShape biases_shape(weights_shape[2]); const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; - _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info); - _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info); + if(data_layout == DataLayout::NHWC) + { + permute(in_shape, PermutationVector(2U, 0U, 1U)); + permute(weights_shape, PermutationVector(2U, 0U, 1U)); + permute(out_shape, PermutationVector(2U, 0U, 1U)); + } + + _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info, data_layout); + _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info, data_layout); } protected: @@ -94,13 +101,13 @@ protected: } TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &output_shape, PadStrideInfo &pad_stride_info, - const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info) + const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info, const DataLayout data_layout) { // Create tensors - TensorType src = create_tensor(input_shape, data_type, 1, 0, quantization_info); - TensorType weights = create_tensor(weights_shape, data_type, 1, 0, quantization_info); - TensorType biases = create_tensor(biases_shape, bias_data_type, 1, 0, quantization_info); - TensorType dst = create_tensor(output_shape, data_type, 1, 0, quantization_info); + TensorType src = create_tensor(input_shape, data_type, 1, 0, quantization_info, data_layout); + TensorType weights = create_tensor(weights_shape, data_type, 1, 0, quantization_info, data_layout); + TensorType biases = create_tensor(biases_shape, bias_data_type, 1, 0, quantization_info, data_layout); + TensorType dst = create_tensor(output_shape, data_type, 1, 0, quantization_info, data_layout); // Create Depthwise Convolution configure function FunctionType dwc; @@ -134,11 +141,11 @@ protected: } SimpleTensor compute_reference(const TensorShape &in_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &out_shape, const PadStrideInfo &pad_stride_info, - const DataType data_type, const DataType bias_data_type, QuantizationInfo quantization_info) + const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info, const DataLayout data_layout) { - SimpleTensor src{ in_shape, data_type, 1, 0, quantization_info }; - SimpleTensor weights{ weights_shape, data_type, 1, 0, quantization_info }; - SimpleTensor biases{ biases_shape, bias_data_type, 1, 0, quantization_info }; + SimpleTensor src{ in_shape, data_type, 1, 0, quantization_info, data_layout }; + SimpleTensor weights{ weights_shape, data_type, 1, 0, quantization_info, data_layout }; + SimpleTensor biases{ biases_shape, bias_data_type, 1, 0, quantization_info, data_layout }; fill(src, 0); fill(weights, 1); @@ -158,10 +165,10 @@ class DepthwiseConvolutionLayerValidationFixture : public DepthwiseConvolutionLa { public: template - void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type) + void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, DataLayout data_layout) { DepthwiseConvolutionLayerValidationGenericFixture::setup(in_shape, weights_shape, out_shape, pad_stride_info, - data_type, QuantizationInfo()); + data_type, QuantizationInfo(), data_layout); } }; @@ -173,7 +180,7 @@ public: void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info) { DepthwiseConvolutionLayerValidationGenericFixture::setup(in_shape, weights_shape, out_shape, pad_stride_info, - data_type, quantization_info); + data_type, quantization_info, DataLayout::NCHW); } }; } // namespace validation -- cgit v1.2.1