aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2018-10-01 12:26:28 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:55:19 +0000
commitf02e52796c3e2bd4a88b696cbe8415cd36884c12 (patch)
tree42c29e2d49ba364f03d429108eb5bde9ee085b9c /src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
parenteddff5a6f43638a27b60564037324f040339faf5 (diff)
downloadComputeLibrary-f02e52796c3e2bd4a88b696cbe8415cd36884c12.tar.gz
COMPMID-1607 - (Nightly) CLGEMMLowpMatrixMultiplyCore errors and mismatches
Change-Id: I5f2e6843526cb154176a5b113627d4f36c3a8edd Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/150967 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: bsgcomp <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp18
1 files changed, 8 insertions, 10 deletions
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
index ee364e5932..56f318d6a8 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
@@ -59,6 +59,8 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 2 && reshape_info.reinterpret_input_as_3d(), "The input1 tensor cannot have more than 2 dimensions if input0 has to be reinterpreted as 3D");
@@ -85,7 +87,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0);
const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1);
- const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
+ const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height, reshape_info.reinterpret_input_as_3d()));
const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
@@ -122,8 +124,11 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
reinterpret_output_as_3d = false;
}
+ GEMMReshapeInfo reshape_info_to_use = GEMMReshapeInfo(reshape_info.m(), reshape_info.n(), reshape_info.k(), reshape_info.mult_transpose1xW_width(), reshape_info.mult_interleave4x4_height(),
+ reinterpret_output_as_3d ? reshape_info.depth_output_gemm3d() : 1, reinterpret_input_as_3d);
+
// Output tensor auto inizialitation if not yet initialized
- auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info)));
+ auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info_to_use)));
TensorInfo tmp_info(*output);
@@ -140,7 +145,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
if(is_interleaved_transposed)
{
// reinterpret_input_as_3d is not supported if is_interleaved_transposed is set
- ARM_COMPUTE_ERROR_ON(reshape_info.reinterpret_input_as_3d());
+ ARM_COMPUTE_ERROR_ON(reinterpret_input_as_3d);
// Configure kernel window
num_elems_processed_per_iteration_x = 4;
@@ -216,13 +221,6 @@ void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const IC
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
- // Output tensor auto inizialitation if not yet initialized
- TensorShape tensor_shape{ input0->info()->tensor_shape() };
- tensor_shape.set(0, is_interleaved_transposed ? reshape_info.n() : input1->info()->dimension(0));
- tensor_shape.set(1, is_interleaved_transposed ? reshape_info.m() : input0->info()->dimension(1));
-
- auto_init_if_empty(*output->info(), tensor_shape, 1, DataType::S32, QuantizationInfo());
-
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
_input0 = input0;