From aadf8466ffc9597f76d53ce7b87e722cff7e72e6 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Thu, 19 Dec 2019 09:35:40 +0000 Subject: COMPMID-2819 Fix depthwise reference when using ceil Change-Id: I4117320fd6ba11365db9a164e4e44509a9a7ba09 Signed-off-by: Giorgio Arena Reviewed-on: https://review.mlplatform.org/c/2498 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio --- tests/datasets/DepthwiseConvolutionLayerDataset.h | 2 ++ .../reference/DepthwiseConvolutionLayer.cpp | 20 ++++++++------------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/datasets/DepthwiseConvolutionLayerDataset.h b/tests/datasets/DepthwiseConvolutionLayerDataset.h index 2990b135d2..014207e4e9 100644 --- a/tests/datasets/DepthwiseConvolutionLayerDataset.h +++ b/tests/datasets/DepthwiseConvolutionLayerDataset.h @@ -126,6 +126,8 @@ public: // Asymmetric padding add_config(TensorShape(33U, 27U, 7U), Size2D(5U, 7U), PadStrideInfo(3, 2, 1, 1, 2, 0, DimensionRoundingType::FLOOR)); add_config(TensorShape(33U, 27U, 7U), Size2D(5U, 7U), PadStrideInfo(3, 2, 1, 1, 0, 2, DimensionRoundingType::FLOOR)); + // Ceil rounding + add_config(TensorShape(7U, 8U, 5U, 9U), Size2D(8U, 6U), PadStrideInfo(2, 3, 1, 1, 1, 3, DimensionRoundingType::CEIL), Size2D(1U, 2U)); } }; diff --git a/tests/validation/reference/DepthwiseConvolutionLayer.cpp b/tests/validation/reference/DepthwiseConvolutionLayer.cpp index 0c7e92b8d0..4245140373 100644 --- a/tests/validation/reference/DepthwiseConvolutionLayer.cpp +++ b/tests/validation/reference/DepthwiseConvolutionLayer.cpp @@ -67,10 +67,8 @@ SimpleTensor depthwise_convolution_fp(const SimpleTensor &src, const Simpl const int input_depth = src.shape().z(); const int num_batches = src.shape().total_size() / (input_width * input_height * input_depth); - const int pad_left = conv_info.pad_left(); - const int pad_top = conv_info.pad_top(); - const int pad_right = conv_info.pad_right(); - const int pad_bottom = conv_info.pad_bottom(); + const int pad_left = conv_info.pad_left(); + const int pad_top = conv_info.pad_top(); const float patch_width = (filter_width + (dilation.x() - 1) * (filter_width - 1)); const float patch_height = (filter_height + (dilation.y() - 1) * (filter_height - 1)); @@ -83,8 +81,8 @@ SimpleTensor depthwise_convolution_fp(const SimpleTensor &src, const Simpl const int minimum_x = -pad_left + patch_half_width_floor; const int minimum_y = -pad_top + patch_half_height_floor; - const int maximum_x = input_width + pad_left + pad_right - static_cast(patch_width); - const int maximum_y = input_height + pad_top + pad_bottom - static_cast(patch_height); + const int maximum_x = (conv_info.stride().first * (dst_shape[0] - 1)); + const int maximum_y = (conv_info.stride().second * (dst_shape[1] - 1)); const T border_value(0); @@ -162,10 +160,8 @@ SimpleTensor depthwise_convolution_quantized(const SimpleTensor &src, cons const int input_depth = src.shape().z(); const int num_batches = src.shape().total_size() / (input_width * input_height * input_depth); - const int pad_left = conv_info.pad_left(); - const int pad_top = conv_info.pad_top(); - const int pad_right = conv_info.pad_right(); - const int pad_bottom = conv_info.pad_bottom(); + const int pad_left = conv_info.pad_left(); + const int pad_top = conv_info.pad_top(); const float patch_width = (filter_width + (dilation.x() - 1) * (filter_width - 1)); const float patch_height = (filter_height + (dilation.y() - 1) * (filter_height - 1)); @@ -178,8 +174,8 @@ SimpleTensor depthwise_convolution_quantized(const SimpleTensor &src, cons const int minimum_x = -pad_left + patch_half_width_floor; const int minimum_y = -pad_top + patch_half_height_floor; - const int maximum_x = input_width + pad_left + pad_right - static_cast(patch_width); - const int maximum_y = input_height + pad_top + pad_bottom - static_cast(patch_height); + const int maximum_x = (conv_info.stride().first * (dst_shape[0] - 1)); + const int maximum_y = (conv_info.stride().second * (dst_shape[1] - 1)); const bool is_quantized_per_channel = is_data_type_quantized_per_channel(weights.data_type()); -- cgit v1.2.1