aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMM.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMM.cpp6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 321ecf85d8..e8bf6732b2 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -140,7 +140,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
if(c != nullptr)
{
- ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0);
ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, c);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1), "The C matrix must have the same number of rows as the matrix A");
@@ -150,7 +150,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
if(output->total_size() != 0)
{
ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != output->dimension(0));
- if(gemm_info.depth_output_gemm3d() != 1)
+ if(gemm_info.depth_output_gemm3d() != 0)
{
if(gemm_info.reinterpret_input_as_3d())
{
@@ -174,7 +174,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
if(!run_optimised)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "NEGEMM cannot reinterpret the input tensor as 3D");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 1, "NEGEMM cannot reinterpret the output tensor as 3D");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 0, "NEGEMM cannot reinterpret the output tensor as 3D");
// Check if the first input tensor is a vector.
const bool run_vector_matrix_multiplication = a->dimension(1) < 2;