diff options
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMM.cpp | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index a06d94c1f5..172facfa78 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -66,17 +66,21 @@ Status validate_arguments(const ITensorInfo *a, const ITensorInfo *b, const ICLT ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b); ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported"); if(c != nullptr) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, c->info()); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->info()->dimension(1), "The C matrix must have the same number of rows as the matrix A"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->info()->dimension(0), "The C matrix must have the same number of columns as the matrix B"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(c->info()->dimension(0) != output->dimension(0), "The C matrix must have the same number of rows as the output matrix"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(c->info()->dimension(1) != output->dimension(1), "The C matrix must have the same number of columns as the output matrix"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->info()->dimension(1), "The matrix C must have the same number of rows as the matrix A"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->info()->dimension(0), "The matrix C must have the same number of columns as the matrix B"); + } + + if(output->total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != output->dimension(0), "The output matrix must have the same number of columns as the matrix B"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != output->dimension(1), "The output matrix must have the same number of rows as the matrix A"); } ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B"); |