From e5e2676409a936431f87d31fb74d825257b20804 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 13 Oct 2020 16:11:07 -0700 Subject: Initial checkin of TOSA reference_model and tests Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6 Signed-off-by: Eric Kunze --- reference_model/src/graph_node.cc | 226 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 reference_model/src/graph_node.cc (limited to 'reference_model/src/graph_node.cc') diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc new file mode 100644 index 0000000..b57b9dd --- /dev/null +++ b/reference_model/src/graph_node.cc @@ -0,0 +1,226 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "graph_node.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +GraphNode::GraphNode(const Op& nodeType_, const uint64_t id_) +{ + nodeType = nodeType_; + nodeId = id_; + inputs.clear(); + outputs.clear(); + inputNames.clear(); + outputNames.clear(); + clearNodeMarked(); + evalCount = 0; + clearOnNextNodeList(); + setRequiredOperands(-1, -1); + setRequiredRank(-1); +} + +GraphNode::~GraphNode() +{} + +int GraphNode::addInputName(std::string& name) +{ + inputNames.push_back(name); + return 0; +} + +int GraphNode::addOutputName(std::string& name) +{ + outputNames.push_back(name); + return 0; +} + +int GraphNode::addInputTensor(Tensor* tens) +{ + ASSERT_MSG(tens, "GraphNode::addInputTensor: no tensor provided"); + inputs.push_back(tens); + return 0; +} + +int GraphNode::addOutputTensor(Tensor* tens) +{ + ASSERT_MSG(tens, "GraphNode::addOutputTensor: no tensor provided"); + outputs.push_back(tens); + return 0; +} + +int GraphNode::checkTensorAttributes() +{ + // Placeholder + return 0; +} + +int GraphNode::eval() +{ + // Placeholder evaluation function + evalCount++; + + // this should be set by derived op + for (auto ct : getOutputs()) + { + ct->setIsValid(); + } + + return 0; +} + +int GraphNode::hasAllInputsReady() const +{ + for (size_t i = 0; i < inputs.size(); i++) + { + if (!inputs[i]->getIsValid()) + return false; + } + + return true; +} + +int GraphNode::hasAllOutputsReady() const +{ + for (size_t i = 0; i < outputs.size(); i++) + { + if (!outputs[i]->getIsValid()) + return false; + } + + return true; +} + +int GraphNode::dumpNode(FILE* out) +{ + int i; + fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Is marked: %d\n", EnumNamesOp()[nodeType], + nodeId, evalCount, onNextNodeList, isMarked); + + i = 0; + for (Tensor* ins : inputs) + { + fprintf(out, " Input[%d] ", i++); + ins->dumpTensorParams(out); + } + + i = 0; + for (Tensor* outs : outputs) + { + fprintf(out, " Output[%d] ", i++); + outs->dumpTensorParams(out); + } + + return 0; +} + +int GraphNode::dumpNode(std::ostream& out) +{ + int i; + + out << "Node type: " << EnumNamesOp()[nodeType] << " ID: " << nodeId << " Eval count: " << evalCount + << " On next node list: " << onNextNodeList << " Is marked: " << isMarked << std::endl; + + out << " Inputs:"; + for (std::string& name : inputNames) + { + out << " " << name; + } + out << std::endl; + + i = 0; + for (Tensor* ins : inputs) + { + out << " Input[" << i++ << "]: "; + ins->dumpTensorParams(out); + } + + out << " Outputs:"; + for (std::string& name : outputNames) + { + out << " " << name; + } + out << std::endl; + + i = 0; + for (Tensor* outs : outputs) + { + out << " Output[" << i++ << "]: "; + outs->dumpTensorParams(out); + } + return 0; +} + +int GraphNode::printNodeValidationError(const std::string& msg) +{ + std::cout << "Operator validation error: " << msg << std::endl; + ; + dumpNode(std::cout); + + return 0; +} + +int GraphNode::validateRequiredOperands() +{ + if (requiredInputCount >= 0 && inputs.size() != (size_t)requiredInputCount) + { + printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator must have " + + std::to_string(requiredInputCount) + " input(s)"); + return 1; + } + + if (requiredOutputCount >= 0 && outputs.size() != (size_t)requiredOutputCount) + { + printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator output must have exactly " + + std::to_string(requiredOutputCount) + " output(s)"); + return 1; + } + + return 0; +} + +int GraphNode::validateRequiredRank(const Tensor* t) +{ + if (requiredRankMin >= 0 && requiredRankMax >= 0) + { + if (t->checkRequiredRank(requiredRankMin, requiredRankMax)) + { + printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + + " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" + + std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) + + "]. tensorName: " + t->getName()); + return 1; + } + else + { + return 0; + } + } + + if (requiredRankMin >= 0) + { + if (t->checkRequiredRank(requiredRankMin)) + { + printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + + " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " + + std::to_string(requiredRankMin) + ". tensorName: " + t->getName()); + return 1; + } + } + + return 0; +} -- cgit v1.2.1