diff options
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 4756ff4f97..ba0d8e254d 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -874,6 +874,18 @@ inline TensorShape compute_stack_shape(const ITensorInfo &a, unsigned int axis, } return shape_out; } + +inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis) +{ + ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 1); + ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4); + ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions()); + + TensorShape output_shape = input_shape; + output_shape[actual_axis] = indices_shape[0]; + + return output_shape; +} } // namespace shape_calculator } // namespace misc } // namespace arm_compute |