aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-09-29 16:57:55 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-09-30 23:11:26 +0100
commitc72b59cc5c1d9251c7794edbeae8fc6b7f30f783 (patch)
tree537e43be22c7b7f12258315b26e1f14c67028503
parent903763c07f1c8a77783735b05a6a9d722bee1639 (diff)
downloadreference_model-c72b59cc5c1d9251c7794edbeae8fc6b7f30f783.tar.gz
Fixes to pass NEGATE op test.
- Elementwise unary op input/output type should match. - TOSA_UNPREDICTABLE should ONLY be sent when a tensor with negative dimension is read/written Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I689518933a2b56cd62793e3f28ea66a6e57b057c
-rw-r--r--reference_model/src/ops/ewise_unary.cc4
-rw-r--r--reference_model/src/subgraph_traverser.cc27
2 files changed, 24 insertions, 7 deletions
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 041bbdb..13e517b 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -48,9 +48,9 @@ int UnaryNode<Rank, Dtype>::checkTensorAttributes()
}
// output and input must be the same types
- if (inputs[0]->matchRankSize(*outputs[0]))
+ if (inputs[0]->matchRankTypeShape(*outputs[0]))
{
- printNodeValidationError("UnaryNode: input and output rank must match");
+ printNodeValidationError("UnaryNode: input and output rank/type/shape must match");
return 1;
}
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 0002b7b..3597314 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -14,6 +14,7 @@
// limitations under the License.
#include "subgraph_traverser.h"
+#include <unordered_set>
#ifndef SUBGRAPH_ERROR_IF
#define SUBGRAPH_ERROR_IF(COND, fmt, ...) \
@@ -117,6 +118,10 @@ TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::strin
int SubgraphTraverser::initializeGraph()
{
int idx = 0;
+
+ // tensor name set which contains all the name used by operator
+ std::unordered_set<std::string> used_tensor_name_set;
+
for (auto op : block->GetOperators())
{
// translated TosaSerializationOperator to GraphNode
@@ -226,14 +231,22 @@ int SubgraphTraverser::initializeGraph()
SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported operation type or rank.");
}
+ // Elementwise operator might set TOSA_ERROR when registering lambda function when creating the op.
+ // Check graph status after the op being constructed.
+ SUBGRAPH_ERROR_IF(getGraphStatus() == GraphStatus::TOSA_ERROR,
+ "SubgraphTraverser::initializeGraph(): Op %8s triggered ERROR_IF() when constructing the op.",
+ EnumNamesOp()[op->GetOp()]);
+
for (auto& name : op->GetInputTensorNames())
{
node->addInputName(name);
+ used_tensor_name_set.insert(name);
}
for (auto name : op->GetOutputTensorNames())
{
node->addOutputName(name);
+ used_tensor_name_set.insert(name);
}
addNode(node);
@@ -250,13 +263,17 @@ int SubgraphTraverser::initializeGraph()
for (auto ts : block->GetTensors())
{
- // Bail out if any dimension is invalid.
- for (auto& dim : ts->GetShape())
+ // Bail out if tensor is used and any of its dimension is invalid.
+ auto got = used_tensor_name_set.find(ts->GetName());
+ if (got != used_tensor_name_set.end())
{
- if (dim <= 0)
+ for (auto& dim : ts->GetShape())
{
- this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
- return 1;
+ if (dim <= 0)
+ {
+ this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
+ return 1;
+ }
}
}