aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils/misc/ShapeCalculator.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h32
1 files changed, 16 insertions, 16 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index a895b58aba..916da1bd9d 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1100,28 +1100,28 @@ inline TensorShape compute_slice_shape(const TensorShape &input_shape, const Coo
/** Calculate the batch to space output shape of a tensor
*
- * @param[in] input Input tensor info
- * @param[in] block_x Block shape x value
- * @param[in] block_y Block shape y value
- * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed
+ * @param[in] data_layout Data layout
+ * @param[in] input Input tensor shape
+ * @param[in] block_x Block shape x value
+ * @param[in] block_y Block shape y value
+ * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed
*
* @return the calculated shape
*/
-inline TensorShape compute_batch_to_space_shape(const ITensorInfo *input, const int block_x, const int block_y, const CropInfo &crop_info = CropInfo{})
+inline TensorShape compute_batch_to_space_shape(DataLayout data_layout, const TensorShape &input, int block_x, int block_y, const CropInfo &crop_info = CropInfo{})
{
- ARM_COMPUTE_ERROR_ON(block_x <= 0 || block_y <= 0);
+ ARM_COMPUTE_ERROR_ON(block_x < 1 || block_y < 1);
- const DataLayout data_layout = input->data_layout();
- const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
- const int idx_batch = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int idx_batch = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
- TensorShape output_shape{ input->tensor_shape() };
+ TensorShape output_shape{ input };
- auto new_width = input->tensor_shape()[idx_width] * block_x;
- auto new_height = input->tensor_shape()[idx_height] * block_y;
- const auto width_crop = crop_info.left + crop_info.right;
- const auto height_crop = crop_info.top + crop_info.bottom;
+ unsigned int new_width = input[idx_width] * static_cast<unsigned int>(block_x);
+ unsigned int new_height = input[idx_height] * static_cast<unsigned int>(block_y);
+ const unsigned int width_crop = crop_info.left + crop_info.right;
+ const unsigned int height_crop = crop_info.top + crop_info.bottom;
ARM_COMPUTE_ERROR_ON(new_width <= width_crop);
ARM_COMPUTE_ERROR_ON(new_height <= height_crop);
new_width -= width_crop;
@@ -1129,7 +1129,7 @@ inline TensorShape compute_batch_to_space_shape(const ITensorInfo *input, const
output_shape.set(idx_width, new_width);
output_shape.set(idx_height, new_height);
- output_shape.set(idx_batch, input->tensor_shape()[idx_batch] / (block_x * block_y));
+ output_shape.set(idx_batch, input[idx_batch] / (block_x * block_y));
return output_shape;
}