From 478101bebd3058a1917d9a9d87ca6d030af71c47 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Mon, 4 Oct 2021 10:43:14 -0700 Subject: Couple of reference model fixes - comparison ops could have different type of input/output - add SUBGRAPH_ERROR_IF() when operator doesn't have any output tensor Signed-off-by: Kevin Cheng Change-Id: I10f2c10f92de1c7a979221a421fa8e86b26fcc72 --- reference_model/src/ops/ewise_binary.cc | 23 ++++++++++++----------- reference_model/src/subgraph_traverser.cc | 3 +++ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index a11d855..023158c 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -60,23 +60,24 @@ int BinaryNodeBase::checkTensorAttributes() return 1; } - // Input and output rank must match - // If it's not MUL, type also needs to match as well. - if (nodeType != Op_MUL) + // In some ops, only rank of input and output tensor needs to match + if (nodeType == Op_MUL || nodeType == Op_GREATER || nodeType == Op_EQUAL || nodeType == Op_GREATER_EQUAL) { - if (inputs[0]->matchRankType(*outputs[0])) + if (inputs[0]->matchRank(*outputs[0])) { - printNodeValidationError("Binary operators (except MUL) input and output rank and type must match"); + std::string err = + "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match"; + printNodeValidationError(err.c_str()); return 1; } } - else + // Otherwise both rand/type of input and output must match + else if (inputs[0]->matchRankType(*outputs[0])) { - if (inputs[0]->matchRank(*outputs[0])) - { - printNodeValidationError("MUL operator input and output rank must match"); - return 1; - } + std::string err = + "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank and type must match"; + printNodeValidationError(err.c_str()); + return 1; } a = dynamic_cast*>(inputs[0]); diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 3597314..82de69c 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -186,6 +186,9 @@ int SubgraphTraverser::initializeGraph() weight_rank = weight_tensor->GetShape().size(); } + SUBGRAPH_ERROR_IF(op->GetOutputTensorNames().size() == 0, + "SubgraphTraverser::initializeGraph(): Op=%s must have at least one output tensor.", + EnumNamesOp()[op->GetOp()]); std::string output_name = op->GetOutputTensorNames()[0]; TosaSerializationTensor* output_tensor = block->GetTensorByName(output_name); SUBGRAPH_ERROR_IF( -- cgit v1.2.1