aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2019-01-30 15:45:42 +0000
committerIsabella Gottardi <isabella.gottardi@arm.com>2019-02-01 17:02:45 +0000
commit6a91440273444b14d158056b5c199e59ca0051e5 (patch)
tree95f82ddf4909d192f1cb70560a30bd4c11346281
parentb7c308a1ad32af4198dcc7eaa73f44fef27dc8fc (diff)
downloadComputeLibrary-6a91440273444b14d158056b5c199e59ca0051e5.tar.gz
COMPMID-1710: Allow NHWC datalayout in SAME pad calculation
Change-Id: Id3788772ad62c7e2d962bc2cb8812b9503e2e836 Signed-off-by: Isabella Gottardi <isabella.gottardi@arm.com> Reviewed-on: https://review.mlplatform.org/603 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: VidhyaSudhan Loganathan <vidhyasudhan.loganathan@arm.com>
-rw-r--r--arm_compute/core/Utils.h3
-rw-r--r--src/core/Utils.cpp25
2 files changed, 16 insertions, 12 deletions
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index 816999db51..d112259d39 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -827,10 +827,11 @@ inline void permute_strides(Dimensions<T> &dimensions, const PermutationVector &
* @param[in] input_shape Input shape
* @param[in] weights_shape Weights shape
* @param[in] conv_info Convolution information (containing strides)
+ * @param[in] data_layout (Optional) Data layout of the input and weights tensor
*
* @return PadStrideInfo for SAME padding
*/
-PadStrideInfo calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info);
+PadStrideInfo calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info, DataLayout data_layout = DataLayout::NCHW);
/** Returns expected width and height of the deconvolution's output tensor.
*
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index 2df5f81c61..73eaf64228 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -22,8 +22,9 @@
* SOFTWARE.
*/
-#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/Utils.h"
#include "support/ToolchainSupport.h"
#include <algorithm>
@@ -325,17 +326,19 @@ std::string arm_compute::lower_string(const std::string &val)
return res;
}
-PadStrideInfo arm_compute::calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info)
+PadStrideInfo arm_compute::calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info, DataLayout data_layout)
{
- const auto &strides = conv_info.stride();
- const int out_width = std::ceil(float(input_shape.x()) / float(strides.first));
- const int out_height = std::ceil(float(input_shape.y()) / float(strides.second));
- const int pad_width = ((out_width - 1) * strides.first + weights_shape.x() - input_shape.x());
- const int pad_height = ((out_height - 1) * strides.second + weights_shape.y() - input_shape.y());
- const int same_pad_left = pad_width / 2;
- const int same_pad_top = pad_height / 2;
- const int same_pad_right = pad_width - same_pad_left;
- const int same_pad_bottom = pad_height - same_pad_top;
+ const unsigned int width_idx = arm_compute::get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int height_idx = arm_compute::get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const auto &strides = conv_info.stride();
+ const int out_width = std::ceil(float(input_shape[width_idx]) / float(strides.first));
+ const int out_height = std::ceil(float(input_shape[height_idx]) / float(strides.second));
+ const int pad_width = ((out_width - 1) * strides.first + weights_shape[width_idx] - input_shape[width_idx]);
+ const int pad_height = ((out_height - 1) * strides.second + weights_shape[height_idx] - input_shape[height_idx]);
+ const int same_pad_left = pad_width / 2;
+ const int same_pad_top = pad_height / 2;
+ const int same_pad_right = pad_width - same_pad_left;
+ const int same_pad_bottom = pad_height - same_pad_top;
return PadStrideInfo(strides.first, strides.second, same_pad_left, same_pad_right, same_pad_top, same_pad_bottom, DimensionRoundingType::CEIL);
}