diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMM.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMM.cpp | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp index 321ecf85d8..e8bf6732b2 100644 --- a/src/runtime/NEON/functions/NEGEMM.cpp +++ b/src/runtime/NEON/functions/NEGEMM.cpp @@ -140,7 +140,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso if(c != nullptr) { - ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 1); + ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0); ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d()); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, c); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1), "The C matrix must have the same number of rows as the matrix A"); @@ -150,7 +150,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso if(output->total_size() != 0) { ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != output->dimension(0)); - if(gemm_info.depth_output_gemm3d() != 1) + if(gemm_info.depth_output_gemm3d() != 0) { if(gemm_info.reinterpret_input_as_3d()) { @@ -174,7 +174,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso if(!run_optimised) { ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "NEGEMM cannot reinterpret the input tensor as 3D"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 1, "NEGEMM cannot reinterpret the output tensor as 3D"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 0, "NEGEMM cannot reinterpret the output tensor as 3D"); // Check if the first input tensor is a vector. const bool run_vector_matrix_multiplication = a->dimension(1) < 2; |