aboutsummaryrefslogtreecommitdiff
path: root/src/core/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/Utils.cpp')
-rw-r--r--src/core/Utils.cpp51
1 files changed, 29 insertions, 22 deletions
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index 5d32750f0d..d0bffdf660 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -331,37 +331,44 @@ std::string arm_compute::lower_string(const std::string &val)
return res;
}
-PadStrideInfo arm_compute::calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info, DataLayout data_layout, const Size2D &dilation)
+PadStrideInfo arm_compute::calculate_same_pad(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info, DataLayout data_layout, const Size2D &dilation,
+ const DimensionRoundingType &rounding_type)
{
- const unsigned int width_idx = arm_compute::get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const unsigned int height_idx = arm_compute::get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
- const auto &strides = conv_info.stride();
+ const unsigned int width_idx = arm_compute::get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int height_idx = arm_compute::get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const unsigned int in_width = input_shape[width_idx];
+ const unsigned int in_height = input_shape[height_idx];
+ const unsigned int kernel_width = weights_shape[width_idx];
+ const unsigned int kernel_height = weights_shape[height_idx];
+ const auto &strides = conv_info.stride();
// Calculate output dimensions
- const int out_width = (input_shape[width_idx] + strides.first - 1) / strides.first;
- const int out_height = (input_shape[height_idx] + strides.second - 1) / strides.second;
+ const auto is_ceil = static_cast<unsigned int>(rounding_type == DimensionRoundingType::CEIL);
+ const unsigned int out_width = ((in_width - is_ceil) + strides.first - 1) / strides.first + is_ceil;
+ const unsigned int out_height = ((in_height - is_ceil) + strides.second - 1) / strides.second + is_ceil;
// Calculate effective weights sizes
- const int real_weight_width = (weights_shape[width_idx] - 1) * dilation.x() + 1;
- const int real_weight_height = (weights_shape[height_idx] - 1) * dilation.y() + 1;
+ const int real_weight_width = (kernel_width - 1) * dilation.x() + 1;
+ const int real_weight_height = (kernel_height - 1) * dilation.y() + 1;
// Calculate total pad
- const int pad_width = (out_width - 1) * strides.first + real_weight_width - input_shape[width_idx];
- const int pad_height = (out_height - 1) * strides.second + real_weight_height - input_shape[height_idx];
+ const int pad_width = std::max(0, static_cast<int>((out_width - 1) * strides.first + real_weight_width - in_width));
+ const int pad_height = std::max(0, static_cast<int>((out_height - 1) * strides.second + real_weight_height - in_height));
// Calculate individual paddings
- const int same_pad_left = pad_width / 2;
- const int same_pad_top = pad_height / 2;
- const int same_pad_right = pad_width - same_pad_left;
- const int same_pad_bottom = pad_height - same_pad_top;
-
- return { static_cast<unsigned int>(strides.first),
- static_cast<unsigned int>(strides.second),
- static_cast<unsigned int>(same_pad_left),
- static_cast<unsigned int>(same_pad_right),
- static_cast<unsigned int>(same_pad_top),
- static_cast<unsigned int>(same_pad_bottom),
- DimensionRoundingType::CEIL };
+ const unsigned int pad_left = pad_width / 2;
+ const unsigned int pad_top = pad_height / 2;
+ const unsigned int pad_right = pad_width - pad_left;
+ const unsigned int pad_bottom = pad_height - pad_top;
+
+ PadStrideInfo same_info(strides.first, strides.second, pad_left, pad_right, pad_top, pad_bottom, rounding_type);
+
+ // Check for correctness of predicted output shape against the one calculated using the generated info
+ const auto out_dims = scaled_dimensions(in_width, in_height, kernel_width, kernel_height, same_info, dilation);
+ ARM_COMPUTE_ERROR_ON(out_dims.first != out_width || out_dims.second != out_height);
+ ARM_COMPUTE_UNUSED(out_dims);
+
+ return same_info;
}
std::pair<unsigned int, unsigned int> arm_compute::deconvolution_output_dimensions(