aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h8
1 files changed, 8 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 9543d989b8..30d3f9bb62 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -358,6 +358,14 @@ inline TensorShape compute_rnn_shape(const ITensorInfo *input, const unsigned in
return output_shape;
}
+inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo &input1, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+{
+ TensorShape tensor_shape{ input0.tensor_shape() };
+ tensor_shape.set(0, is_interleaved_transposed ? reshape_info.n() : input1.dimension(0));
+ tensor_shape.set(1, is_interleaved_transposed ? reshape_info.m() : input0.dimension(1));
+
+ return tensor_shape;
+}
} // namespace shape_calculator
} // namespace misc
} // namespace arm_compute