aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLGEMM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r--src/runtime/CL/functions/CLGEMM.cpp14
1 files changed, 9 insertions, 5 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index a06d94c1f5..172facfa78 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -66,17 +66,21 @@ Status validate_arguments(const ITensorInfo *a, const ITensorInfo *b, const ICLT
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
if(c != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, c->info());
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->info()->dimension(1), "The C matrix must have the same number of rows as the matrix A");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->info()->dimension(0), "The C matrix must have the same number of columns as the matrix B");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(c->info()->dimension(0) != output->dimension(0), "The C matrix must have the same number of rows as the output matrix");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(c->info()->dimension(1) != output->dimension(1), "The C matrix must have the same number of columns as the output matrix");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->info()->dimension(1), "The matrix C must have the same number of rows as the matrix A");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->info()->dimension(0), "The matrix C must have the same number of columns as the matrix B");
+ }
+
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != output->dimension(0), "The output matrix must have the same number of columns as the matrix B");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != output->dimension(1), "The output matrix must have the same number of rows as the matrix A");
}
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");