aboutsummaryrefslogtreecommitdiff
path: root/src/core/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/Utils.cpp')
-rw-r--r--src/core/Utils.cpp34
1 files changed, 19 insertions, 15 deletions
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index f930804ce3..59dc3652fa 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -399,34 +399,38 @@ std::pair<unsigned int, unsigned int> arm_compute::deconvolution_output_dimensio
return std::make_pair<unsigned int, unsigned int>(w, h);
}
-std::pair<unsigned int, unsigned int> arm_compute::scaled_dimensions(unsigned int width, unsigned int height,
- unsigned int kernel_width, unsigned int kernel_height,
+std::pair<unsigned int, unsigned int> arm_compute::scaled_dimensions(int width, int height,
+ int kernel_width, int kernel_height,
const PadStrideInfo &pad_stride_info,
const Size2D &dilation)
{
- const unsigned int pad_left = pad_stride_info.pad_left();
- const unsigned int pad_top = pad_stride_info.pad_top();
- const unsigned int pad_right = pad_stride_info.pad_right();
- const unsigned int pad_bottom = pad_stride_info.pad_bottom();
- const unsigned int stride_x = pad_stride_info.stride().first;
- const unsigned int stride_y = pad_stride_info.stride().second;
- unsigned int w = 0;
- unsigned int h = 0;
+ const int dilation_x = dilation.x();
+ const int dilation_y = dilation.y();
+ const int pad_left = pad_stride_info.pad_left();
+ const int pad_top = pad_stride_info.pad_top();
+ const int pad_right = pad_stride_info.pad_right();
+ const int pad_bottom = pad_stride_info.pad_bottom();
+ const int stride_x = pad_stride_info.stride().first;
+ const int stride_y = pad_stride_info.stride().second;
+ int w = 0;
+ int h = 0;
switch(pad_stride_info.round())
{
case DimensionRoundingType::FLOOR:
- w = static_cast<unsigned int>(std::floor((static_cast<float>(width + pad_left + pad_right - (dilation.x() * (kernel_width - 1) + 1)) / stride_x) + 1));
- h = static_cast<unsigned int>(std::floor((static_cast<float>(height + pad_top + pad_bottom - (dilation.y() * (kernel_height - 1) + 1)) / stride_y) + 1));
+ w = static_cast<int>(std::floor((static_cast<float>(width + pad_left + pad_right - (dilation_x * (kernel_width - 1) + 1)) / stride_x) + 1));
+ h = static_cast<int>(std::floor((static_cast<float>(height + pad_top + pad_bottom - (dilation_y * (kernel_height - 1) + 1)) / stride_y) + 1));
break;
case DimensionRoundingType::CEIL:
- w = static_cast<unsigned int>(std::ceil((static_cast<float>(width + pad_left + pad_right - (dilation.x() * (kernel_width - 1) + 1)) / stride_x) + 1));
- h = static_cast<unsigned int>(std::ceil((static_cast<float>(height + pad_top + pad_bottom - (dilation.y() * (kernel_height - 1) + 1)) / stride_y) + 1));
+ w = static_cast<int>(std::ceil((static_cast<float>(width + pad_left + pad_right - (dilation_x * (kernel_width - 1) + 1)) / stride_x) + 1));
+ h = static_cast<int>(std::ceil((static_cast<float>(height + pad_top + pad_bottom - (dilation_y * (kernel_height - 1) + 1)) / stride_y) + 1));
break;
default:
ARM_COMPUTE_ERROR("Unsupported rounding type");
}
- return std::make_pair(w, h);
+ w = std::max(1, w);
+ h = std::max(1, h);
+ return std::make_pair(static_cast<unsigned int>(w), static_cast<unsigned int>(h));
}
bool arm_compute::needs_serialized_reduction(ReductionOperation op, DataType dt, unsigned int axis)