aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2019-04-05 17:18:36 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-04-16 16:02:13 +0000
commita9c4472188abef421adb589e2a6fef52727d465f (patch)
treef8f6540b05049074030c32332b5427e826cc58ea /arm_compute/core/utils
parent2ec6c1eb6ee77b79e8ab6b97b8cd70bcc4c5589d (diff)
downloadComputeLibrary-a9c4472188abef421adb589e2a6fef52727d465f.tar.gz
COMPMID-2051 Refactor shape_calculator::calculate_concatenate_shape
Change-Id: Ibf316718d11fa975d75f226925747b21c4efd127 Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com> Reviewed-on: https://review.mlplatform.org/c/974 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h55
1 files changed, 22 insertions, 33 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 384bd460a0..f5058b35fb 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1173,6 +1173,11 @@ inline TensorShape extract_shape(const TensorShape *data)
return *data;
}
+inline TensorShape extract_shape(TensorShape *data)
+{
+ return *data;
+}
+
/** Calculate the unstack shape of a tensor
*
* @param[in] input_shape Input tensor shape
@@ -1187,37 +1192,6 @@ inline TensorShape calculate_unstack_shape(TensorShape input_shape, unsigned int
return input_shape;
}
-/** Calculate the depth concatenate output shape of a vector of tensors
- *
- * @param[in] inputs_vector Vector containing the shapes of the inputs
- *
- * @return the calculated shape
- */
-template <typename T>
-inline TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inputs_vector)
-{
- TensorShape out_shape = extract_shape(inputs_vector[0]);
-
- size_t max_x = 0;
- size_t max_y = 0;
- size_t depth = 0;
-
- for(const auto &tensor : inputs_vector)
- {
- ARM_COMPUTE_ERROR_ON(tensor == nullptr);
- const TensorShape shape = extract_shape(tensor);
- max_x = std::max(shape.x(), max_x);
- max_y = std::max(shape.y(), max_y);
- depth += shape.z();
- }
-
- out_shape.set(0, max_x);
- out_shape.set(1, max_y);
- out_shape.set(2, depth);
-
- return out_shape;
-}
-
/** Calculate the concatenate output shape of the concatenate operation along a single axis
*
* @param[in] input Vector containing the shapes of the inputs
@@ -1230,12 +1204,27 @@ inline TensorShape calculate_concatenate_shape(const std::vector<T *> &input, si
{
TensorShape out_shape = extract_shape(input[0]);
+ // All dimensions must match except the axis one
+ for(unsigned int i = 0; i < MAX_DIMS; ++i)
+ {
+ if(i == axis)
+ {
+ continue;
+ }
+
+ for(const auto &tensor : input)
+ {
+ ARM_COMPUTE_ERROR_ON(tensor == nullptr);
+ const TensorShape shape = extract_shape(tensor);
+ ARM_COMPUTE_ERROR_ON(out_shape[i] != shape[i]);
+ }
+ }
+
+ // Calculate output shape
size_t new_size = 0;
for(const auto &tensor : input)
{
- ARM_COMPUTE_ERROR_ON(tensor == nullptr);
const TensorShape shape = extract_shape(tensor);
- ARM_COMPUTE_ERROR_ON(axis >= shape.num_dimensions());
new_size += shape[axis];
}