diff options
author | Jack Frankland <jack.frankland@arm.com> | 2023-11-21 17:08:37 +0000 |
---|---|---|
committer | Jack Frankland <jack.frankland@arm.com> | 2023-11-22 09:23:52 +0000 |
commit | ac40bd12192b6d41afa5d969578766e050c44398 (patch) | |
tree | 5546cf4205f3825ed98d40807ce4946a21308731 | |
parent | 5637a8606bc3caeec3c590350de770c7fcec8dd7 (diff) | |
download | reference_model-ac40bd12192b6d41afa5d969578766e050c44398.tar.gz |
Correct Fully Connected Validation Logic
The bias operand of the fully connected operator must be a 1D tensor
either equal to the output channel size or of size 1. Previously we
asserted the former case, we now include the second case.
Signed-off-by: Jack Frankland <jack.frankland@arm.com>
Change-Id: I07dbc8a3aa1650703e5c50e1e7f36bb9539fd5db
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 3f0e7b2..b9e2fbe 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -1246,9 +1246,10 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() return 1; } - if (weight->getShape()[0] != bias->getShape()[0]) + if (weight->getShape()[0] != bias->getShape()[0] && bias->getShape()[0] != 1) { - printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]"); + printNodeValidationError( + "OpFullyConnected operator bias.shape[0] should match weight.shape[0] or be equal to 1"); return 1; } |