aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2021-04-08 12:02:58 +0100
committerMichalis Spyrou <michalis.spyrou@arm.com>2021-04-19 13:45:08 +0000
commit60c3b0e6821a80d78ffca5be30e05d062d071cd2 (patch)
tree3e263a45aa9617cfd7704b2b33ea4337f1582321 /arm_compute/core
parent4f1650f0c9919f0bac5024b8e31c0f754d25aec3 (diff)
downloadComputeLibrary-60c3b0e6821a80d78ffca5be30e05d062d071cd2.tar.gz
Port DepthwiseConvolution to new API
Resolves: COMPMID-4185 Change-Id: Ib5f22356356a022d567bb18d44ea272b62d10ebf Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5424 Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core')
-rw-r--r--arm_compute/core/Types.h13
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h15
2 files changed, 19 insertions, 9 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 53333ff608..b1f340d18e 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1861,6 +1861,19 @@ private:
bool _broadcast_bias;
};
+struct ConvolutionInfo
+{
+ ConvolutionInfo() = default;
+ ConvolutionInfo(const PadStrideInfo &pad_stride_info, unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation)
+ : pad_stride_info(pad_stride_info), depth_multiplier(depth_multiplier), act_info(act_info), dilation(dilation)
+ {
+ }
+ PadStrideInfo pad_stride_info{}; /**< Convolution info (Pads, strides,...) */
+ unsigned int depth_multiplier{ 1 }; /**< Multiplier to apply to input's depth to retrieve the output depth. Defaults to 1 */
+ ActivationLayerInfo act_info{}; /**< Fused activation to apply after convolution. */
+ Size2D dilation{ Size2D(1, 1) }; /**< Dilation, in elements, across x and y. Defaults to (1, 1). */
+};
+
struct DepthwiseConvolutionReshapeInfo
{
unsigned int c0{ 1 }; /**< Number of channels processed by the depth-wise convolution */
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 56038dd853..ba37f9a61e 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -435,16 +435,13 @@ inline TensorShape compute_transposed_shape(const ITensorInfo &input)
/** Calculate the depthwise convolution output shape of a tensor
*
- * @param[in] input Input tensor info
- * @param[in] weights Weights tensor info
- * @param[in] conv_info Padding and stride information to use for the convolution.
- * @param[in] depth_multiplier Multiplier to apply to the input's depth in order to retrieve the output's depth.
- * @param[in] dilation Dilation, in elements, across x and y. Defaults to (1, 1).
+ * @param[in] input Input tensor info
+ * @param[in] weights Weights tensor info
+ * @param[in] info Convolution info
*
* @return the calculated shape
*/
-inline TensorShape compute_depthwise_convolution_shape(const ITensorInfo &input, const ITensorInfo &weights, PadStrideInfo conv_info, unsigned int depth_multiplier, const Size2D &dilation = Size2D(1U,
- 1U))
+inline TensorShape compute_depthwise_convolution_shape(const ITensorInfo &input, const ITensorInfo &weights, const ConvolutionInfo &info)
{
const TensorShape input_shape{ input.tensor_shape() };
const TensorShape weights_shape{ weights.tensor_shape() };
@@ -462,12 +459,12 @@ inline TensorShape compute_depthwise_convolution_shape(const ITensorInfo &input,
unsigned int output_height = 0;
std::tie(output_width, output_height) = scaled_dimensions(input_shape[width_idx], input_shape[height_idx],
weights_shape[weights_width_idx], weights_shape[weights_height_idx],
- conv_info, dilation);
+ info.pad_stride_info, info.dilation);
TensorShape output_shape{ input_shape };
output_shape.set(width_idx, output_width);
output_shape.set(height_idx, output_height);
- output_shape.set(channel_idx, input_shape[channel_idx] * depth_multiplier);
+ output_shape.set(channel_idx, input_shape[channel_idx] * info.depth_multiplier);
return output_shape;
}