aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVidhya Sudhan Loganathan <vidhyasudhan.loganathan@arm.com>2019-05-03 09:13:55 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2019-05-09 08:48:31 +0000
commitae1a89ed670956b9722fe57c2dc36c75e5f948ec (patch)
tree0fe640503957937dc753c81bac284f8378b2dcdf
parent976f11fc7964ea302997d7b04c4d5fb4765e1414 (diff)
downloadComputeLibrary-ae1a89ed670956b9722fe57c2dc36c75e5f948ec.tar.gz
COMPMID-2118 : (Nightly) : CLGroupedGEMMConvolutionLayer validation issues
Change-Id: I8cf3cf60302d9b1e0ffe37e9f441fb7e7fb0655c Signed-off-by: Vidhya Sudhan Loganathan <vidhyasudhan.loganathan@arm.com> Reviewed-on: https://review.mlplatform.org/c/1077 Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h23
1 files changed, 15 insertions, 8 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index b46b1b2535..d66c87fb42 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -865,16 +865,23 @@ inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo
const bool reinterpret_output_as_3d = gemm_info.depth_output_gemm3d() != 0;
const int depth_output_gemm3d = reinterpret_output_as_3d ? gemm_info.depth_output_gemm3d() : 1;
- // If the output of GEMM has to be reinterpreted as 3D, the number of input0 rows (M) is obtained collapsing the second and third
- // dimension of the output tensor
- const int batch_size = reinterpret_input_as_3d ? input0.tensor_shape()[3] : input0.tensor_shape()[2];
-
TensorShape output_shape{ input0.tensor_shape() };
- output_shape.set(0, gemm_info.n());
- output_shape.set(1, gemm_info.m() / depth_output_gemm3d);
- output_shape.set(2, reinterpret_output_as_3d ? depth_output_gemm3d : batch_size);
- output_shape.set(3, reinterpret_output_as_3d ? batch_size : 1);
+ if(!reinterpret_input_as_3d && !reinterpret_output_as_3d)
+ {
+ output_shape.set(0, gemm_info.n());
+ output_shape.set(1, gemm_info.m());
+ }
+ else
+ {
+ // If the output of GEMM has to be reinterpreted as 3D, the number of input0 rows (M) is obtained collapsing the second and third
+ // dimension of the output tensor
+ const int batch_size = reinterpret_input_as_3d ? input0.tensor_shape()[3] : input0.tensor_shape()[2];
+ output_shape.set(0, gemm_info.n());
+ output_shape.set(1, gemm_info.m() / depth_output_gemm3d);
+ output_shape.set(2, reinterpret_output_as_3d ? depth_output_gemm3d : batch_size);
+ output_shape.set(3, reinterpret_output_as_3d ? batch_size : 1);
+ }
return output_shape;
}