From c72b59cc5c1d9251c7794edbeae8fc6b7f30f783 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Wed, 29 Sep 2021 16:57:55 -0700 Subject: Fixes to pass NEGATE op test. - Elementwise unary op input/output type should match. - TOSA_UNPREDICTABLE should ONLY be sent when a tensor with negative dimension is read/written Signed-off-by: Kevin Cheng Change-Id: I689518933a2b56cd62793e3f28ea66a6e57b057c --- reference_model/src/ops/ewise_unary.cc | 4 ++-- reference_model/src/subgraph_traverser.cc | 27 ++++++++++++++++++++++----- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 041bbdb..13e517b 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -48,9 +48,9 @@ int UnaryNode::checkTensorAttributes() } // output and input must be the same types - if (inputs[0]->matchRankSize(*outputs[0])) + if (inputs[0]->matchRankTypeShape(*outputs[0])) { - printNodeValidationError("UnaryNode: input and output rank must match"); + printNodeValidationError("UnaryNode: input and output rank/type/shape must match"); return 1; } diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 0002b7b..3597314 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -14,6 +14,7 @@ // limitations under the License. #include "subgraph_traverser.h" +#include #ifndef SUBGRAPH_ERROR_IF #define SUBGRAPH_ERROR_IF(COND, fmt, ...) \ @@ -117,6 +118,10 @@ TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::strin int SubgraphTraverser::initializeGraph() { int idx = 0; + + // tensor name set which contains all the name used by operator + std::unordered_set used_tensor_name_set; + for (auto op : block->GetOperators()) { // translated TosaSerializationOperator to GraphNode @@ -226,14 +231,22 @@ int SubgraphTraverser::initializeGraph() SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported operation type or rank."); } + // Elementwise operator might set TOSA_ERROR when registering lambda function when creating the op. + // Check graph status after the op being constructed. + SUBGRAPH_ERROR_IF(getGraphStatus() == GraphStatus::TOSA_ERROR, + "SubgraphTraverser::initializeGraph(): Op %8s triggered ERROR_IF() when constructing the op.", + EnumNamesOp()[op->GetOp()]); + for (auto& name : op->GetInputTensorNames()) { node->addInputName(name); + used_tensor_name_set.insert(name); } for (auto name : op->GetOutputTensorNames()) { node->addOutputName(name); + used_tensor_name_set.insert(name); } addNode(node); @@ -250,13 +263,17 @@ int SubgraphTraverser::initializeGraph() for (auto ts : block->GetTensors()) { - // Bail out if any dimension is invalid. - for (auto& dim : ts->GetShape()) + // Bail out if tensor is used and any of its dimension is invalid. + auto got = used_tensor_name_set.find(ts->GetName()); + if (got != used_tensor_name_set.end()) { - if (dim <= 0) + for (auto& dim : ts->GetShape()) { - this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); - return 1; + if (dim <= 0) + { + this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); + return 1; + } } } -- cgit v1.2.1