aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-10-04 10:43:14 -0700
committerEric Kunze <eric.kunze@arm.com>2021-10-05 20:33:41 +0000
commit478101bebd3058a1917d9a9d87ca6d030af71c47 (patch)
treeda9babc14408efd5416e772e3aceb07c3a95fd9b
parent848efb46db2d407a9bb4fba1940d06e143a5dbad (diff)
downloadreference_model-478101bebd3058a1917d9a9d87ca6d030af71c47.tar.gz
Couple of reference model fixes
- comparison ops could have different type of input/output - add SUBGRAPH_ERROR_IF() when operator doesn't have any output tensor Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I10f2c10f92de1c7a979221a421fa8e86b26fcc72
-rw-r--r--reference_model/src/ops/ewise_binary.cc23
-rw-r--r--reference_model/src/subgraph_traverser.cc3
2 files changed, 15 insertions, 11 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index a11d855..023158c 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -60,23 +60,24 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- // Input and output rank must match
- // If it's not MUL, type also needs to match as well.
- if (nodeType != Op_MUL)
+ // In some ops, only rank of input and output tensor needs to match
+ if (nodeType == Op_MUL || nodeType == Op_GREATER || nodeType == Op_EQUAL || nodeType == Op_GREATER_EQUAL)
{
- if (inputs[0]->matchRankType(*outputs[0]))
+ if (inputs[0]->matchRank(*outputs[0]))
{
- printNodeValidationError("Binary operators (except MUL) input and output rank and type must match");
+ std::string err =
+ "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
+ printNodeValidationError(err.c_str());
return 1;
}
}
- else
+ // Otherwise both rand/type of input and output must match
+ else if (inputs[0]->matchRankType(*outputs[0]))
{
- if (inputs[0]->matchRank(*outputs[0]))
- {
- printNodeValidationError("MUL operator input and output rank must match");
- return 1;
- }
+ std::string err =
+ "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank and type must match";
+ printNodeValidationError(err.c_str());
+ return 1;
}
a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 3597314..82de69c 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -186,6 +186,9 @@ int SubgraphTraverser::initializeGraph()
weight_rank = weight_tensor->GetShape().size();
}
+ SUBGRAPH_ERROR_IF(op->GetOutputTensorNames().size() == 0,
+ "SubgraphTraverser::initializeGraph(): Op=%s must have at least one output tensor.",
+ EnumNamesOp()[op->GetOp()]);
std::string output_name = op->GetOutputTensorNames()[0];
TosaSerializationTensor* output_tensor = block->GetTensorByName(output_name);
SUBGRAPH_ERROR_IF(