aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h23
1 files changed, 15 insertions, 8 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index b46b1b2535..d66c87fb42 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -865,16 +865,23 @@ inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo
const bool reinterpret_output_as_3d = gemm_info.depth_output_gemm3d() != 0;
const int depth_output_gemm3d = reinterpret_output_as_3d ? gemm_info.depth_output_gemm3d() : 1;
- // If the output of GEMM has to be reinterpreted as 3D, the number of input0 rows (M) is obtained collapsing the second and third
- // dimension of the output tensor
- const int batch_size = reinterpret_input_as_3d ? input0.tensor_shape()[3] : input0.tensor_shape()[2];
-
TensorShape output_shape{ input0.tensor_shape() };
- output_shape.set(0, gemm_info.n());
- output_shape.set(1, gemm_info.m() / depth_output_gemm3d);
- output_shape.set(2, reinterpret_output_as_3d ? depth_output_gemm3d : batch_size);
- output_shape.set(3, reinterpret_output_as_3d ? batch_size : 1);
+ if(!reinterpret_input_as_3d && !reinterpret_output_as_3d)
+ {
+ output_shape.set(0, gemm_info.n());
+ output_shape.set(1, gemm_info.m());
+ }
+ else
+ {
+ // If the output of GEMM has to be reinterpreted as 3D, the number of input0 rows (M) is obtained collapsing the second and third
+ // dimension of the output tensor
+ const int batch_size = reinterpret_input_as_3d ? input0.tensor_shape()[3] : input0.tensor_shape()[2];
+ output_shape.set(0, gemm_info.n());
+ output_shape.set(1, gemm_info.m() / depth_output_gemm3d);
+ output_shape.set(2, reinterpret_output_as_3d ? depth_output_gemm3d : batch_size);
+ output_shape.set(3, reinterpret_output_as_3d ? batch_size : 1);
+ }
return output_shape;
}