aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
authorPablo Marquez Tello <pablo.tello@arm.com>2022-04-27 11:46:31 +0100
committerPablo Marquez Tello <pablo.tello@arm.com>2022-05-10 09:48:59 +0000
commit920f2b6c2070f6328891e26538e8bcad63e2a79c (patch)
treedfa769580dd15083c6690b7b4019ad23948f8f36 /arm_compute/core/utils
parent06adbc56e9c4a7947e6bc843da6687b3ff357de4 (diff)
downloadComputeLibrary-920f2b6c2070f6328891e26538e8bcad63e2a79c.tar.gz
Add support for 2d and 3d indices for axis 0
* Partially resolves COMPMID-5055 Change-Id: Id05374b8c69e6b9ab4c2790a4de93d7172063b71 Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com> Change-Id: Ic6e2c2d1d34abbf6222c8d56859514e267447266 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7488 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h19
1 files changed, 15 insertions, 4 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index df907c106e..aa51ad209a 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1496,13 +1496,24 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn
inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)
{
- ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 1);
ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4);
ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions());
- TensorShape output_shape = input_shape;
- output_shape[actual_axis] = indices_shape[0];
-
+ TensorShape output_shape = input_shape;
+ if(indices_shape.num_dimensions() == 1u)
+ {
+ output_shape[actual_axis] = indices_shape[0];
+ }
+ else
+ {
+ const auto inddims{ indices_shape.num_dimensions() };
+ output_shape.shift_right(indices_shape.num_dimensions() - 1);
+ output_shape[0] = input_shape[0];
+ for(size_t idx(1); (idx - 1) < inddims; ++idx)
+ {
+ output_shape.set(actual_axis + idx, indices_shape[idx - 1], false);
+ }
+ }
return output_shape;
}
} // namespace shape_calculator