aboutsummaryrefslogtreecommitdiff
path: root/src/core/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/Utils.cpp')
-rw-r--r--src/core/Utils.cpp32
1 files changed, 21 insertions, 11 deletions
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index 499a6c8b29..5d32750f0d 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -333,17 +333,27 @@ std::string arm_compute::lower_string(const std::string &val)
PadStrideInfo arm_compute::calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info, DataLayout data_layout, const Size2D &dilation)
{
- const unsigned int width_idx = arm_compute::get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const unsigned int height_idx = arm_compute::get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
- const auto &strides = conv_info.stride();
- const int out_width = std::ceil(float(input_shape[width_idx]) / float(strides.first));
- const int out_height = std::ceil(float(input_shape[height_idx]) / float(strides.second));
- const int pad_width = (out_width - 1) * strides.first + (weights_shape[width_idx] + (dilation.x() - 1) * (weights_shape[width_idx] - 1) - input_shape[width_idx]);
- const int pad_height = (out_height - 1) * strides.second + (weights_shape[height_idx] + (dilation.y() - 1) * (weights_shape[height_idx] - 1) - input_shape[height_idx]);
- const int same_pad_left = pad_width / 2;
- const int same_pad_top = pad_height / 2;
- const int same_pad_right = pad_width - same_pad_left;
- const int same_pad_bottom = pad_height - same_pad_top;
+ const unsigned int width_idx = arm_compute::get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int height_idx = arm_compute::get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const auto &strides = conv_info.stride();
+
+ // Calculate output dimensions
+ const int out_width = (input_shape[width_idx] + strides.first - 1) / strides.first;
+ const int out_height = (input_shape[height_idx] + strides.second - 1) / strides.second;
+
+ // Calculate effective weights sizes
+ const int real_weight_width = (weights_shape[width_idx] - 1) * dilation.x() + 1;
+ const int real_weight_height = (weights_shape[height_idx] - 1) * dilation.y() + 1;
+
+ // Calculate total pad
+ const int pad_width = (out_width - 1) * strides.first + real_weight_width - input_shape[width_idx];
+ const int pad_height = (out_height - 1) * strides.second + real_weight_height - input_shape[height_idx];
+
+ // Calculate individual paddings
+ const int same_pad_left = pad_width / 2;
+ const int same_pad_top = pad_height / 2;
+ const int same_pad_right = pad_width - same_pad_left;
+ const int same_pad_bottom = pad_height - same_pad_top;
return { static_cast<unsigned int>(strides.first),
static_cast<unsigned int>(strides.second),