aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils/misc/ShapeCalculator.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h55
1 files changed, 22 insertions, 33 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 384bd460a0..f5058b35fb 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1173,6 +1173,11 @@ inline TensorShape extract_shape(const TensorShape *data)
return *data;
}
+inline TensorShape extract_shape(TensorShape *data)
+{
+ return *data;
+}
+
/** Calculate the unstack shape of a tensor
*
* @param[in] input_shape Input tensor shape
@@ -1187,37 +1192,6 @@ inline TensorShape calculate_unstack_shape(TensorShape input_shape, unsigned int
return input_shape;
}
-/** Calculate the depth concatenate output shape of a vector of tensors
- *
- * @param[in] inputs_vector Vector containing the shapes of the inputs
- *
- * @return the calculated shape
- */
-template <typename T>
-inline TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inputs_vector)
-{
- TensorShape out_shape = extract_shape(inputs_vector[0]);
-
- size_t max_x = 0;
- size_t max_y = 0;
- size_t depth = 0;
-
- for(const auto &tensor : inputs_vector)
- {
- ARM_COMPUTE_ERROR_ON(tensor == nullptr);
- const TensorShape shape = extract_shape(tensor);
- max_x = std::max(shape.x(), max_x);
- max_y = std::max(shape.y(), max_y);
- depth += shape.z();
- }
-
- out_shape.set(0, max_x);
- out_shape.set(1, max_y);
- out_shape.set(2, depth);
-
- return out_shape;
-}
-
/** Calculate the concatenate output shape of the concatenate operation along a single axis
*
* @param[in] input Vector containing the shapes of the inputs
@@ -1230,12 +1204,27 @@ inline TensorShape calculate_concatenate_shape(const std::vector<T *> &input, si
{
TensorShape out_shape = extract_shape(input[0]);
+ // All dimensions must match except the axis one
+ for(unsigned int i = 0; i < MAX_DIMS; ++i)
+ {
+ if(i == axis)
+ {
+ continue;
+ }
+
+ for(const auto &tensor : input)
+ {
+ ARM_COMPUTE_ERROR_ON(tensor == nullptr);
+ const TensorShape shape = extract_shape(tensor);
+ ARM_COMPUTE_ERROR_ON(out_shape[i] != shape[i]);
+ }
+ }
+
+ // Calculate output shape
size_t new_size = 0;
for(const auto &tensor : input)
{
- ARM_COMPUTE_ERROR_ON(tensor == nullptr);
const TensorShape shape = extract_shape(tensor);
- ARM_COMPUTE_ERROR_ON(axis >= shape.num_dimensions());
new_size += shape[axis];
}