about | summary | refs | log | tree | commit | diff
path: root/arm_compute/core/utils/misc/ShapeCalculator.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- arm_compute/core/utils/misc/ShapeCalculator.h 22
1 file changed, 13 insertions, 9 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 61834b88a9..6ecfdf0323 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017, 2018 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -39,12 +39,14 @@ inline TensorShape compute_permutation_output_shape(const ITensorInfo &input, co
permute(output_shape, perm);
return output_shape;
}
-inline TensorShape compute_interleaved_shape(const ITensorInfo &a)
+inline TensorShape compute_interleaved_shape(const ITensorInfo &a, int mult_interleave4x4_height = 1)
{
- // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ]
+ // The interleaved output matrix will have the following shape: [ a_height * W, ceil(a_width / W) ] where W = 4 * mult_interleave4x4_height
+ ARM_COMPUTE_ERROR_ON(mult_interleave4x4_height < 1);
+ const int interleave_width = 4 * mult_interleave4x4_height;
TensorShape shape_interleaved_a{ a.tensor_shape() };
- shape_interleaved_a.set(0, a.dimension(0) * 4);
- shape_interleaved_a.set(1, std::ceil(a.dimension(1) / 4.f));
+ shape_interleaved_a.set(0, a.dimension(0) * interleave_width);
+ shape_interleaved_a.set(1, std::ceil(a.dimension(1) / static_cast<float>(interleave_width)));
return shape_interleaved_a;
}
@@ -57,12 +59,14 @@ inline TensorShape compute_transpose1xW_shape(const ITensorInfo &b)
return shape_transposed1xW_b;
}
-inline TensorShape compute_transpose1xW_with_element_size_shape(const ITensorInfo &b)
+inline TensorShape compute_transpose1xW_with_element_size_shape(const ITensorInfo &b, int mult_transpose1xW_width = 1)
{
- // The transpose1xW output matrix will have the following shape:
- // [ b_height * (16 / element_size), ceil(b_width / (16.0f / element_size) ]
+ // Note: mult_transpose1xW_width expresses the number of chunks with size 1x(W) we want to store on the same row
+ // The transpose1xW output matrix will have the following shape:
+ // [ b_height * W, ceil(b_width / W) ] where W = (16 / element size of the tensor) * mult_transpose1xW_width
+ ARM_COMPUTE_ERROR_ON(mult_transpose1xW_width < 1);
TensorShape shape_transposed1xW_b{ b.tensor_shape() };
- const size_t transpose_width = 16 / b.element_size();
+ const size_t transpose_width = (16 / b.element_size()) * mult_transpose1xW_width;
shape_transposed1xW_b.set(0, b.dimension(1) * transpose_width);
shape_transposed1xW_b.set(1, static_cast<size_t>(std::ceil(b.dimension(0) / static_cast<float>(transpose_width))));