From e73686ac797be2d19cd9bed26d690e1431e3d848 Mon Sep 17 00:00:00 2001 From: Usama Arif Date: Mon, 8 Apr 2019 17:30:48 +0100 Subject: COMPMID-2047: Add support for dilation in CLDepthwiseConvolution. Change-Id: I3106aa34bd168985a56791613d95072756be6e9b Signed-off-by: Usama Arif Reviewed-on: https://review.mlplatform.org/c/958 Comments-Addressed: Arm Jenkins Reviewed-by: Pablo Marquez Tested-by: Arm Jenkins --- .../fixtures/DepthwiseConvolutionLayerFixture.h | 29 ++++++++++++---------- 1 file changed, 16 insertions(+), 13 deletions(-) (limited to 'tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h') diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h index 5428154a2b..dd8bf232be 100644 --- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -55,7 +55,8 @@ public: public: template - void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, unsigned int depth_multiplier, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout) + void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation, unsigned int depth_multiplier, DataType data_type, QuantizationInfo quantization_info, + DataLayout data_layout) { _quantization_info = quantization_info; _data_type = data_type; @@ -65,13 +66,13 @@ public: const TensorInfo in_info(in_shape, 1, data_type); const TensorInfo we_info(weights_shape, 1, data_type); - TensorShape out_shape = compute_depthwise_convolution_shape(in_info, we_info, pad_stride_info, depth_multiplier); + const TensorShape out_shape = compute_depthwise_convolution_shape(in_info, we_info, pad_stride_info, depth_multiplier, dilation); weights_shape.set(2, out_shape.z()); const TensorShape biases_shape(weights_shape[2]); - _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, depth_multiplier, data_type, bias_data_type, quantization_info, data_layout); - _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, depth_multiplier, data_type, bias_data_type, quantization_info); + _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, dilation, depth_multiplier, data_type, bias_data_type, quantization_info, data_layout); + _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, dilation, depth_multiplier, data_type, bias_data_type, quantization_info); } protected: @@ -104,7 +105,8 @@ protected: } } - TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape output_shape, PadStrideInfo &pad_stride_info, unsigned int depth_multiplier, + TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape output_shape, PadStrideInfo &pad_stride_info, Size2D dilation, + unsigned int depth_multiplier, const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info, const DataLayout data_layout) { if(data_layout == DataLayout::NHWC) @@ -122,7 +124,7 @@ protected: // Create Depthwise Convolution configure function FunctionType dwc; - dwc.configure(&src, &weights, &biases, &dst, pad_stride_info, depth_multiplier); + dwc.configure(&src, &weights, &biases, &dst, pad_stride_info, depth_multiplier, ActivationLayerInfo(), dilation); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -152,7 +154,7 @@ protected: } SimpleTensor compute_reference(const TensorShape &in_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &out_shape, const PadStrideInfo &pad_stride_info, - unsigned int depth_multiplier, + const Size2D &dilation, unsigned int depth_multiplier, const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info) { SimpleTensor src{ in_shape, data_type, 1, quantization_info }; @@ -163,7 +165,7 @@ protected: fill(weights, 1); fill(biases, 2); - return reference::depthwise_convolution(src, weights, biases, out_shape, pad_stride_info, depth_multiplier); + return reference::depthwise_convolution(src, weights, biases, out_shape, pad_stride_info, depth_multiplier, dilation); } TensorType _target{}; @@ -177,9 +179,9 @@ class DepthwiseConvolutionLayerValidationFixture : public DepthwiseConvolutionLa { public: template - void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, unsigned int depth_multiplier, DataType data_type, DataLayout data_layout) + void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation, unsigned int depth_multiplier, DataType data_type, DataLayout data_layout) { - DepthwiseConvolutionLayerValidationGenericFixture::setup(in_shape, kernel_size, pad_stride_info, depth_multiplier, + DepthwiseConvolutionLayerValidationGenericFixture::setup(in_shape, kernel_size, pad_stride_info, dilation, depth_multiplier, data_type, QuantizationInfo(), data_layout); } }; @@ -189,9 +191,10 @@ class DepthwiseConvolutionLayerValidationQuantizedFixture : public DepthwiseConv { public: template - void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, unsigned int depth_multiplier, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout) + void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation, unsigned int depth_multiplier, DataType data_type, QuantizationInfo quantization_info, + DataLayout data_layout) { - DepthwiseConvolutionLayerValidationGenericFixture::setup(in_shape, kernel_size, pad_stride_info, depth_multiplier, + DepthwiseConvolutionLayerValidationGenericFixture::setup(in_shape, kernel_size, pad_stride_info, dilation, depth_multiplier, data_type, quantization_info, data_layout); } }; -- cgit v1.2.1