From 089695f0d4b1ebd1bc76ba95e415bce1297808be Mon Sep 17 00:00:00 2001 From: Isabella Gottardi Date: Wed, 17 Oct 2018 18:04:15 +0100 Subject: COMPMID-1659: (Nightly) CLGEMMConvolutionLayer QASYMM8 TensorShape error Change-Id: Ib4ca28b82bd82f0ed4d2c906185d3f4010246616 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/153986 Reviewed-by: Giuseppe Rossini Tested-by: bsgcomp --- arm_compute/core/utils/misc/ShapeCalculator.h | 10 +++++++++- src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 11d20c919f..56f65d0ba8 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -113,7 +113,15 @@ inline TensorShape compute_interleaved_shape(const ITensorInfo &a, int mult_inte const int M = a.dimension(1) * a.dimension(2); const int height = std::ceil(M / static_cast(interleave_width)); shape_interleaved_a.set(1, height); - shape_interleaved_a.remove_dimension(2); + + // When the data format is NHWC and the shapes are Nx1x1 + // the tensor shape num_dimensions is automatically set to 1 instead of 3. + // To avoid failures by removing a dimension that doesn't exist + // check if the number of dimensions is greater than 2. + if(shape_interleaved_a.num_dimensions() > 2) + { + shape_interleaved_a.remove_dimension(2); + } } else { diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp index 509b668bc9..f79fb43073 100644 --- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp @@ -238,6 +238,7 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso // Validate transpose kernel auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width))); + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMTranspose1xWKernel::validate(b, &tmp_b_info, mult_transpose1xW_width)); } // Validate matrix multiply -- cgit v1.2.1