aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--reference_model/src/ops/tensor_ops.cc5
1 files changed, 3 insertions, 2 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 3f0e7b2..b9e2fbe 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1246,9 +1246,10 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- if (weight->getShape()[0] != bias->getShape()[0])
+ if (weight->getShape()[0] != bias->getShape()[0] && bias->getShape()[0] != 1)
{
- printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
+ printNodeValidationError(
+ "OpFullyConnected operator bias.shape[0] should match weight.shape[0] or be equal to 1");
return 1;
}