diff options
Diffstat (limited to 'src/core/Utils.cpp')
-rw-r--r-- | src/core/Utils.cpp | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp index d5ce1ea027..0a35e07430 100644 --- a/src/core/Utils.cpp +++ b/src/core/Utils.cpp @@ -288,37 +288,39 @@ const std::pair<unsigned int, unsigned int> arm_compute::scaled_dimensions(unsig unsigned int kernel_width, unsigned int kernel_height, const PadStrideInfo &pad_stride_info) { - const unsigned int pad_x = pad_stride_info.pad().first; - const unsigned int pad_y = pad_stride_info.pad().second; - const unsigned int stride_x = pad_stride_info.stride().first; - const unsigned int stride_y = pad_stride_info.stride().second; - unsigned int w = 0; - unsigned int h = 0; + const unsigned int pad_left = pad_stride_info.pad_left(); + const unsigned int pad_top = pad_stride_info.pad_top(); + const unsigned int pad_right = pad_stride_info.pad_right(); + const unsigned int pad_bottom = pad_stride_info.pad_bottom(); + const unsigned int stride_x = pad_stride_info.stride().first; + const unsigned int stride_y = pad_stride_info.stride().second; + unsigned int w = 0; + unsigned int h = 0; switch(pad_stride_info.round()) { case DimensionRoundingType::FLOOR: - w = static_cast<unsigned int>(std::floor((static_cast<float>(width + 2 * pad_x - kernel_width) / stride_x) + 1)); - h = static_cast<unsigned int>(std::floor((static_cast<float>(height + 2 * pad_y - kernel_height) / stride_y) + 1)); + w = static_cast<unsigned int>(std::floor((static_cast<float>(width + pad_left + pad_right - kernel_width) / stride_x) + 1)); + h = static_cast<unsigned int>(std::floor((static_cast<float>(height + pad_top + pad_bottom - kernel_height) / stride_y) + 1)); break; case DimensionRoundingType::CEIL: - w = static_cast<unsigned int>(std::ceil((static_cast<float>(width + 2 * pad_x - kernel_width) / stride_x) + 1)); - h = static_cast<unsigned int>(std::ceil((static_cast<float>(height + 2 * pad_y - kernel_height) / stride_y) + 1)); + w = static_cast<unsigned int>(std::ceil((static_cast<float>(width + pad_left + pad_right - kernel_width) / stride_x) + 1)); + h = static_cast<unsigned int>(std::ceil((static_cast<float>(height + pad_top + pad_bottom - kernel_height) / stride_y) + 1)); break; default: ARM_COMPUTE_ERROR("Unsupported rounding type"); } // Make sure that border operations will start from inside the input and not the padded area - if(((w - 1) * stride_x) >= (width + pad_x)) + if(((w - 1) * stride_x) >= (width + pad_left)) { --w; } - if(((h - 1) * stride_y) >= (height + pad_y)) + if(((h - 1) * stride_y) >= (height + pad_top)) { --h; } - ARM_COMPUTE_ERROR_ON(((w - 1) * stride_x) >= (width + pad_x)); - ARM_COMPUTE_ERROR_ON(((h - 1) * stride_y) >= (height + pad_y)); + ARM_COMPUTE_ERROR_ON(((w - 1) * stride_x) >= (width + pad_left)); + ARM_COMPUTE_ERROR_ON(((h - 1) * stride_y) >= (height + pad_top)); return std::make_pair(w, h); } |