about | summary | refs | log | tree | commit | diff
path: root/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp')
-rw-r--r-- src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp 12
1 files changed, 5 insertions, 7 deletions
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
index 56f318d6a8..99e184050e 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
@@ -57,6 +57,7 @@ using ElementsProcessed = Steps;
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
@@ -87,7 +88,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0);
const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1);
- const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height, reshape_info.reinterpret_input_as_3d()));
+ const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
@@ -124,11 +125,8 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
reinterpret_output_as_3d = false;
}
- GEMMReshapeInfo reshape_info_to_use = GEMMReshapeInfo(reshape_info.m(), reshape_info.n(), reshape_info.k(), reshape_info.mult_transpose1xW_width(), reshape_info.mult_interleave4x4_height(),
- reinterpret_output_as_3d ? reshape_info.depth_output_gemm3d() : 1, reinterpret_input_as_3d);
-
// Output tensor auto inizialitation if not yet initialized
- auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info_to_use)));
+ auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info)));
TensorInfo tmp_info(*output);
@@ -145,7 +143,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
if(is_interleaved_transposed)
{
// reinterpret_input_as_3d is not supported if is_interleaved_transposed is set
- ARM_COMPUTE_ERROR_ON(reinterpret_input_as_3d);
+ ARM_COMPUTE_ERROR_ON(reshape_info.reinterpret_input_as_3d());
// Configure kernel window
num_elems_processed_per_iteration_x = 4;
@@ -198,7 +196,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
Coordinates coord;
coord.set_num_dimensions(output->num_dimensions());
- output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
+ output_access.set_valid_region(win_out, ValidRegion(coord, output->tensor_shape()));
}
// Collapse along the Z direction