diff options
Diffstat (limited to 'src/core/Utils.cpp')
-rw-r--r-- | src/core/Utils.cpp | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp index f930804ce3..59dc3652fa 100644 --- a/src/core/Utils.cpp +++ b/src/core/Utils.cpp @@ -399,34 +399,38 @@ std::pair<unsigned int, unsigned int> arm_compute::deconvolution_output_dimensio return std::make_pair<unsigned int, unsigned int>(w, h); } -std::pair<unsigned int, unsigned int> arm_compute::scaled_dimensions(unsigned int width, unsigned int height, - unsigned int kernel_width, unsigned int kernel_height, +std::pair<unsigned int, unsigned int> arm_compute::scaled_dimensions(int width, int height, + int kernel_width, int kernel_height, const PadStrideInfo &pad_stride_info, const Size2D &dilation) { - const unsigned int pad_left = pad_stride_info.pad_left(); - const unsigned int pad_top = pad_stride_info.pad_top(); - const unsigned int pad_right = pad_stride_info.pad_right(); - const unsigned int pad_bottom = pad_stride_info.pad_bottom(); - const unsigned int stride_x = pad_stride_info.stride().first; - const unsigned int stride_y = pad_stride_info.stride().second; - unsigned int w = 0; - unsigned int h = 0; + const int dilation_x = dilation.x(); + const int dilation_y = dilation.y(); + const int pad_left = pad_stride_info.pad_left(); + const int pad_top = pad_stride_info.pad_top(); + const int pad_right = pad_stride_info.pad_right(); + const int pad_bottom = pad_stride_info.pad_bottom(); + const int stride_x = pad_stride_info.stride().first; + const int stride_y = pad_stride_info.stride().second; + int w = 0; + int h = 0; switch(pad_stride_info.round()) { case DimensionRoundingType::FLOOR: - w = static_cast<unsigned int>(std::floor((static_cast<float>(width + pad_left + pad_right - (dilation.x() * (kernel_width - 1) + 1)) / stride_x) + 1)); - h = static_cast<unsigned int>(std::floor((static_cast<float>(height + pad_top + pad_bottom - (dilation.y() * (kernel_height - 1) + 1)) / stride_y) + 1)); + w = static_cast<int>(std::floor((static_cast<float>(width + pad_left + pad_right - (dilation_x * (kernel_width - 1) + 1)) / stride_x) + 1)); + h = static_cast<int>(std::floor((static_cast<float>(height + pad_top + pad_bottom - (dilation_y * (kernel_height - 1) + 1)) / stride_y) + 1)); break; case DimensionRoundingType::CEIL: - w = static_cast<unsigned int>(std::ceil((static_cast<float>(width + pad_left + pad_right - (dilation.x() * (kernel_width - 1) + 1)) / stride_x) + 1)); - h = static_cast<unsigned int>(std::ceil((static_cast<float>(height + pad_top + pad_bottom - (dilation.y() * (kernel_height - 1) + 1)) / stride_y) + 1)); + w = static_cast<int>(std::ceil((static_cast<float>(width + pad_left + pad_right - (dilation_x * (kernel_width - 1) + 1)) / stride_x) + 1)); + h = static_cast<int>(std::ceil((static_cast<float>(height + pad_top + pad_bottom - (dilation_y * (kernel_height - 1) + 1)) / stride_y) + 1)); break; default: ARM_COMPUTE_ERROR("Unsupported rounding type"); } - return std::make_pair(w, h); + w = std::max(1, w); + h = std::max(1, h); + return std::make_pair(static_cast<unsigned int>(w), static_cast<unsigned int>(h)); } bool arm_compute::needs_serialized_reduction(ReductionOperation op, DataType dt, unsigned int axis) |