aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.cc
diff options
context:
space:
mode:
authorJack Frankland <jack.frankland@arm.com>2023-11-21 17:08:37 +0000
committerJack Frankland <jack.frankland@arm.com>2023-11-22 09:23:52 +0000
commitac40bd12192b6d41afa5d969578766e050c44398 (patch)
tree5546cf4205f3825ed98d40807ce4946a21308731 /reference_model/src/ops/tensor_ops.cc
parent5637a8606bc3caeec3c590350de770c7fcec8dd7 (diff)
downloadreference_model-ac40bd12192b6d41afa5d969578766e050c44398.tar.gz
Correct Fully Connected Validation Logic
The bias operand of the fully connected operator must be a 1D tensor either equal to the output channel size or of size 1. Previously we asserted the former case, we now include the second case. Signed-off-by: Jack Frankland <jack.frankland@arm.com> Change-Id: I07dbc8a3aa1650703e5c50e1e7f36bb9539fd5db
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r--reference_model/src/ops/tensor_ops.cc5
1 files changed, 3 insertions, 2 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 3f0e7b2..b9e2fbe 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1246,9 +1246,10 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- if (weight->getShape()[0] != bias->getShape()[0])
+ if (weight->getShape()[0] != bias->getShape()[0] && bias->getShape()[0] != 1)
{
- printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
+ printNodeValidationError(
+ "OpFullyConnected operator bias.shape[0] should match weight.shape[0] or be equal to 1");
return 1;
}