aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc27
1 files changed, 22 insertions, 5 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 0002b7b..3597314 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -14,6 +14,7 @@
// limitations under the License.
#include "subgraph_traverser.h"
+#include <unordered_set>
#ifndef SUBGRAPH_ERROR_IF
#define SUBGRAPH_ERROR_IF(COND, fmt, ...) \
@@ -117,6 +118,10 @@ TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::strin
int SubgraphTraverser::initializeGraph()
{
int idx = 0;
+
+ // tensor name set which contains all the name used by operator
+ std::unordered_set<std::string> used_tensor_name_set;
+
for (auto op : block->GetOperators())
{
// translated TosaSerializationOperator to GraphNode
@@ -226,14 +231,22 @@ int SubgraphTraverser::initializeGraph()
SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported operation type or rank.");
}
+ // Elementwise operator might set TOSA_ERROR when registering lambda function when creating the op.
+ // Check graph status after the op being constructed.
+ SUBGRAPH_ERROR_IF(getGraphStatus() == GraphStatus::TOSA_ERROR,
+ "SubgraphTraverser::initializeGraph(): Op %8s triggered ERROR_IF() when constructing the op.",
+ EnumNamesOp()[op->GetOp()]);
+
for (auto& name : op->GetInputTensorNames())
{
node->addInputName(name);
+ used_tensor_name_set.insert(name);
}
for (auto name : op->GetOutputTensorNames())
{
node->addOutputName(name);
+ used_tensor_name_set.insert(name);
}
addNode(node);
@@ -250,13 +263,17 @@ int SubgraphTraverser::initializeGraph()
for (auto ts : block->GetTensors())
{
- // Bail out if any dimension is invalid.
- for (auto& dim : ts->GetShape())
+ // Bail out if tensor is used and any of its dimension is invalid.
+ auto got = used_tensor_name_set.find(ts->GetName());
+ if (got != used_tensor_name_set.end())
{
- if (dim <= 0)
+ for (auto& dim : ts->GetShape())
{
- this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
- return 1;
+ if (dim <= 0)
+ {
+ this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
+ return 1;
+ }
}
}