diff options
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 3f0e7b2..b9e2fbe 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -1246,9 +1246,10 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() return 1; } - if (weight->getShape()[0] != bias->getShape()[0]) + if (weight->getShape()[0] != bias->getShape()[0] && bias->getShape()[0] != 1) { - printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]"); + printNodeValidationError( + "OpFullyConnected operator bias.shape[0] should match weight.shape[0] or be equal to 1"); return 1; } |