aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h48
1 files changed, 43 insertions, 5 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index df907c106e..9f9f53ed8b 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1494,15 +1494,53 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn
return output_shape;
}
+/** Calculate the gather output shape of a tensor
+ *
+ * @param[in] input_shape Input tensor shape
+ * @param[in] indices_shape Indices tensor shape. Only supports for 2d and 3d indices
+ * @param[in] actual_axis Axis to be used in the computation
+ *
+ * @note Let input_shape be (X,Y,Z) and indices shape (W,O,P) and axis 1
+ * the new shape is computed by replacing the axis in the input shape with
+ * the indice shape so the output shape will be (X,W,O,P,Z)
+ *
+ * @return the calculated shape
+ */
inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)
{
- ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 1);
ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4);
ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions());
-
- TensorShape output_shape = input_shape;
- output_shape[actual_axis] = indices_shape[0];
-
+ ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 3);
+ TensorShape output_shape = input_shape;
+ if(indices_shape.num_dimensions() == 1u)
+ {
+ output_shape[actual_axis] = indices_shape[0];
+ }
+ else
+ {
+ const auto ind_num_dims
+ {
+ indices_shape.num_dimensions()
+ };
+ output_shape.shift_right(ind_num_dims - 1);
+ switch(actual_axis)
+ {
+ case 1:
+ {
+ output_shape[0] = input_shape[0];
+ for(size_t idx = 0; idx < ind_num_dims; ++idx)
+ {
+ output_shape.set(actual_axis + idx, indices_shape[idx], false);
+ }
+ break;
+ }
+ default:
+ {
+ // 2d and 3d indices are only supported for axis == 1
+ ARM_COMPUTE_ERROR_ON(actual_axis != 1 && indices_shape.num_dimensions() > 1);
+ }
+ }
+ }
return output_shape;
}
} // namespace shape_calculator