aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/graph_node.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/graph_node.cc')
-rw-r--r--reference_model/src/graph_node.cc226
1 files changed, 226 insertions, 0 deletions
diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc
new file mode 100644
index 0000000..b57b9dd
--- /dev/null
+++ b/reference_model/src/graph_node.cc
@@ -0,0 +1,226 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "graph_node.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+GraphNode::GraphNode(const Op& nodeType_, const uint64_t id_)
+{
+ nodeType = nodeType_;
+ nodeId = id_;
+ inputs.clear();
+ outputs.clear();
+ inputNames.clear();
+ outputNames.clear();
+ clearNodeMarked();
+ evalCount = 0;
+ clearOnNextNodeList();
+ setRequiredOperands(-1, -1);
+ setRequiredRank(-1);
+}
+
+GraphNode::~GraphNode()
+{}
+
+int GraphNode::addInputName(std::string& name)
+{
+ inputNames.push_back(name);
+ return 0;
+}
+
+int GraphNode::addOutputName(std::string& name)
+{
+ outputNames.push_back(name);
+ return 0;
+}
+
+int GraphNode::addInputTensor(Tensor* tens)
+{
+ ASSERT_MSG(tens, "GraphNode::addInputTensor: no tensor provided");
+ inputs.push_back(tens);
+ return 0;
+}
+
+int GraphNode::addOutputTensor(Tensor* tens)
+{
+ ASSERT_MSG(tens, "GraphNode::addOutputTensor: no tensor provided");
+ outputs.push_back(tens);
+ return 0;
+}
+
+int GraphNode::checkTensorAttributes()
+{
+ // Placeholder
+ return 0;
+}
+
+int GraphNode::eval()
+{
+ // Placeholder evaluation function
+ evalCount++;
+
+ // this should be set by derived op
+ for (auto ct : getOutputs())
+ {
+ ct->setIsValid();
+ }
+
+ return 0;
+}
+
+int GraphNode::hasAllInputsReady() const
+{
+ for (size_t i = 0; i < inputs.size(); i++)
+ {
+ if (!inputs[i]->getIsValid())
+ return false;
+ }
+
+ return true;
+}
+
+int GraphNode::hasAllOutputsReady() const
+{
+ for (size_t i = 0; i < outputs.size(); i++)
+ {
+ if (!outputs[i]->getIsValid())
+ return false;
+ }
+
+ return true;
+}
+
+int GraphNode::dumpNode(FILE* out)
+{
+ int i;
+ fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Is marked: %d\n", EnumNamesOp()[nodeType],
+ nodeId, evalCount, onNextNodeList, isMarked);
+
+ i = 0;
+ for (Tensor* ins : inputs)
+ {
+ fprintf(out, " Input[%d] ", i++);
+ ins->dumpTensorParams(out);
+ }
+
+ i = 0;
+ for (Tensor* outs : outputs)
+ {
+ fprintf(out, " Output[%d] ", i++);
+ outs->dumpTensorParams(out);
+ }
+
+ return 0;
+}
+
+int GraphNode::dumpNode(std::ostream& out)
+{
+ int i;
+
+ out << "Node type: " << EnumNamesOp()[nodeType] << " ID: " << nodeId << " Eval count: " << evalCount
+ << " On next node list: " << onNextNodeList << " Is marked: " << isMarked << std::endl;
+
+ out << " Inputs:";
+ for (std::string& name : inputNames)
+ {
+ out << " " << name;
+ }
+ out << std::endl;
+
+ i = 0;
+ for (Tensor* ins : inputs)
+ {
+ out << " Input[" << i++ << "]: ";
+ ins->dumpTensorParams(out);
+ }
+
+ out << " Outputs:";
+ for (std::string& name : outputNames)
+ {
+ out << " " << name;
+ }
+ out << std::endl;
+
+ i = 0;
+ for (Tensor* outs : outputs)
+ {
+ out << " Output[" << i++ << "]: ";
+ outs->dumpTensorParams(out);
+ }
+ return 0;
+}
+
+int GraphNode::printNodeValidationError(const std::string& msg)
+{
+ std::cout << "Operator validation error: " << msg << std::endl;
+ ;
+ dumpNode(std::cout);
+
+ return 0;
+}
+
+int GraphNode::validateRequiredOperands()
+{
+ if (requiredInputCount >= 0 && inputs.size() != (size_t)requiredInputCount)
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator must have " +
+ std::to_string(requiredInputCount) + " input(s)");
+ return 1;
+ }
+
+ if (requiredOutputCount >= 0 && outputs.size() != (size_t)requiredOutputCount)
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator output must have exactly " +
+ std::to_string(requiredOutputCount) + " output(s)");
+ return 1;
+ }
+
+ return 0;
+}
+
+int GraphNode::validateRequiredRank(const Tensor* t)
+{
+ if (requiredRankMin >= 0 && requiredRankMax >= 0)
+ {
+ if (t->checkRequiredRank(requiredRankMin, requiredRankMax))
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) +
+ " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" +
+ std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) +
+ "]. tensorName: " + t->getName());
+ return 1;
+ }
+ else
+ {
+ return 0;
+ }
+ }
+
+ if (requiredRankMin >= 0)
+ {
+ if (t->checkRequiredRank(requiredRankMin))
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) +
+ " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " +
+ std::to_string(requiredRankMin) + ". tensorName: " + t->getName());
+ return 1;
+ }
+ }
+
+ return 0;
+}