aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2023-03-13 15:02:23 +0000
committerSiCong Li <sicong.li@arm.com>2023-03-14 15:38:29 +0000
commit4ceb453b00185ded5ddbaf83d40eadeb2ed28ec4 (patch)
tree13d56b417d5c2b186bde627f4f5d0f05b7228a53 /arm_compute/core/utils
parentaaa9da1efa83911c7a67d50811ad669a92a7d12f (diff)
downloadComputeLibrary-4ceb453b00185ded5ddbaf83d40eadeb2ed28ec4.tar.gz
Add CropInfo to BatchToSpace reference and fixture
Partially resolves COMPMID-5918, COMPMID-5865 Signed-off-by: SiCong Li <sicong.li@arm.com> Change-Id: Ib3b01e7dc1c944184a4c038045bf0469fbb9ff45 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9321 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h27
1 files changed, 19 insertions, 8 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 94bd3aca03..6655cc1439 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1072,13 +1072,14 @@ inline TensorShape compute_slice_shape(const TensorShape &input_shape, const Coo
/** Calculate the batch to space output shape of a tensor
*
- * @param[in] input Input tensor info
- * @param[in] block_x Block shape x value
- * @param[in] block_y Block shape y value
+ * @param[in] input Input tensor info
+ * @param[in] block_x Block shape x value
+ * @param[in] block_y Block shape y value
+ * @param[in] crop_info Information about how the output shape is cropped after batch to space is performed
*
* @return the calculated shape
*/
-inline TensorShape compute_batch_to_space_shape(const ITensorInfo *input, const int block_x, const int block_y)
+inline TensorShape compute_batch_to_space_shape(const ITensorInfo *input, const int block_x, const int block_y, const CropInfo &crop_info = CropInfo{})
{
ARM_COMPUTE_ERROR_ON(block_x <= 0 || block_y <= 0);
@@ -1088,8 +1089,18 @@ inline TensorShape compute_batch_to_space_shape(const ITensorInfo *input, const
const int idx_batch = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
TensorShape output_shape{ input->tensor_shape() };
- output_shape.set(idx_width, input->tensor_shape()[idx_width] * block_x);
- output_shape.set(idx_height, input->tensor_shape()[idx_height] * block_y);
+
+ auto new_width = input->tensor_shape()[idx_width] * block_x;
+ auto new_height = input->tensor_shape()[idx_height] * block_y;
+ const auto width_crop = crop_info.left + crop_info.right;
+ const auto height_crop = crop_info.top + crop_info.bottom;
+ ARM_COMPUTE_ERROR_ON(new_width <= width_crop);
+ ARM_COMPUTE_ERROR_ON(new_height <= height_crop);
+ new_width -= width_crop;
+ new_height -= height_crop;
+
+ output_shape.set(idx_width, new_width);
+ output_shape.set(idx_height, new_height);
output_shape.set(idx_batch, input->tensor_shape()[idx_batch] / (block_x * block_y));
return output_shape;
@@ -1537,14 +1548,14 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn
*/
inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)
{
- const auto input_num_dims = input_shape.num_dimensions();
+ const auto input_num_dims = input_shape.num_dimensions();
const auto indices_num_dims = indices_shape.num_dimensions();
ARM_COMPUTE_ERROR_ON(actual_axis >= input_num_dims);
ARM_COMPUTE_ERROR_ON(input_num_dims + indices_num_dims - 1 > Coordinates::num_max_dimensions);
TensorShape output_shape;
- size_t dim_no = 0;
+ size_t dim_no = 0;
for(; dim_no < actual_axis; ++dim_no)
{