diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2021-04-08 12:02:58 +0100 |
---|---|---|
committer | Michalis Spyrou <michalis.spyrou@arm.com> | 2021-04-19 13:45:08 +0000 |
commit | 60c3b0e6821a80d78ffca5be30e05d062d071cd2 (patch) | |
tree | 3e263a45aa9617cfd7704b2b33ea4337f1582321 /arm_compute/core | |
parent | 4f1650f0c9919f0bac5024b8e31c0f754d25aec3 (diff) | |
download | ComputeLibrary-60c3b0e6821a80d78ffca5be30e05d062d071cd2.tar.gz |
Port DepthwiseConvolution to new API
Resolves: COMPMID-4185
Change-Id: Ib5f22356356a022d567bb18d44ea272b62d10ebf
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5424
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core')
-rw-r--r-- | arm_compute/core/Types.h | 13 | ||||
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 15 |
2 files changed, 19 insertions, 9 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 53333ff608..b1f340d18e 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1861,6 +1861,19 @@ private: bool _broadcast_bias; }; +struct ConvolutionInfo +{ + ConvolutionInfo() = default; + ConvolutionInfo(const PadStrideInfo &pad_stride_info, unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation) + : pad_stride_info(pad_stride_info), depth_multiplier(depth_multiplier), act_info(act_info), dilation(dilation) + { + } + PadStrideInfo pad_stride_info{}; /**< Convolution info (Pads, strides,...) */ + unsigned int depth_multiplier{ 1 }; /**< Multiplier to apply to input's depth to retrieve the output depth. Defaults to 1 */ + ActivationLayerInfo act_info{}; /**< Fused activation to apply after convolution. */ + Size2D dilation{ Size2D(1, 1) }; /**< Dilation, in elements, across x and y. Defaults to (1, 1). */ +}; + struct DepthwiseConvolutionReshapeInfo { unsigned int c0{ 1 }; /**< Number of channels processed by the depth-wise convolution */ diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 56038dd853..ba37f9a61e 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -435,16 +435,13 @@ inline TensorShape compute_transposed_shape(const ITensorInfo &input) /** Calculate the depthwise convolution output shape of a tensor * - * @param[in] input Input tensor info - * @param[in] weights Weights tensor info - * @param[in] conv_info Padding and stride information to use for the convolution. - * @param[in] depth_multiplier Multiplier to apply to the input's depth in order to retrieve the output's depth. - * @param[in] dilation Dilation, in elements, across x and y. Defaults to (1, 1). + * @param[in] input Input tensor info + * @param[in] weights Weights tensor info + * @param[in] info Convolution info * * @return the calculated shape */ -inline TensorShape compute_depthwise_convolution_shape(const ITensorInfo &input, const ITensorInfo &weights, PadStrideInfo conv_info, unsigned int depth_multiplier, const Size2D &dilation = Size2D(1U, - 1U)) +inline TensorShape compute_depthwise_convolution_shape(const ITensorInfo &input, const ITensorInfo &weights, const ConvolutionInfo &info) { const TensorShape input_shape{ input.tensor_shape() }; const TensorShape weights_shape{ weights.tensor_shape() }; @@ -462,12 +459,12 @@ inline TensorShape compute_depthwise_convolution_shape(const ITensorInfo &input, unsigned int output_height = 0; std::tie(output_width, output_height) = scaled_dimensions(input_shape[width_idx], input_shape[height_idx], weights_shape[weights_width_idx], weights_shape[weights_height_idx], - conv_info, dilation); + info.pad_stride_info, info.dilation); TensorShape output_shape{ input_shape }; output_shape.set(width_idx, output_width); output_shape.set(height_idx, output_height); - output_shape.set(channel_idx, input_shape[channel_idx] * depth_multiplier); + output_shape.set(channel_idx, input_shape[channel_idx] * info.depth_multiplier); return output_shape; } |