diff options
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 61834b88a9..6ecfdf0323 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017, 2018 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -39,12 +39,14 @@ inline TensorShape compute_permutation_output_shape(const ITensorInfo &input, co permute(output_shape, perm); return output_shape; } -inline TensorShape compute_interleaved_shape(const ITensorInfo &a) +inline TensorShape compute_interleaved_shape(const ITensorInfo &a, int mult_interleave4x4_height = 1) { - // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ] + // The interleaved output matrix will have the following shape: [ a_height * W, ceil(a_width / W) ] where W = 4 * mult_interleave4x4_height + ARM_COMPUTE_ERROR_ON(mult_interleave4x4_height < 1); + const int interleave_width = 4 * mult_interleave4x4_height; TensorShape shape_interleaved_a{ a.tensor_shape() }; - shape_interleaved_a.set(0, a.dimension(0) * 4); - shape_interleaved_a.set(1, std::ceil(a.dimension(1) / 4.f)); + shape_interleaved_a.set(0, a.dimension(0) * interleave_width); + shape_interleaved_a.set(1, std::ceil(a.dimension(1) / static_cast<float>(interleave_width))); return shape_interleaved_a; } @@ -57,12 +59,14 @@ inline TensorShape compute_transpose1xW_shape(const ITensorInfo &b) return shape_transposed1xW_b; } -inline TensorShape compute_transpose1xW_with_element_size_shape(const ITensorInfo &b) +inline TensorShape compute_transpose1xW_with_element_size_shape(const ITensorInfo &b, int mult_transpose1xW_width = 1) { - // The transpose1xW output matrix will have the following shape: - // [ b_height * (16 / element_size), ceil(b_width / (16.0f / element_size) ] + // Note: mult_transpose1xW_width expresses the number of chunks with size 1x(W) we want to store on the same row + // The transpose1xW output matrix will have the following shape: + // [ b_height * W, ceil(b_width / W) ] where W = (16 / element size of the tensor) * mult_transpose1xW_width + ARM_COMPUTE_ERROR_ON(mult_transpose1xW_width < 1); TensorShape shape_transposed1xW_b{ b.tensor_shape() }; - const size_t transpose_width = 16 / b.element_size(); + const size_t transpose_width = (16 / b.element_size()) * mult_transpose1xW_width; shape_transposed1xW_b.set(0, b.dimension(1) * transpose_width); shape_transposed1xW_b.set(1, static_cast<size_t>(std::ceil(b.dimension(0) / static_cast<float>(transpose_width)))); |