// Copyright (c) 2020, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "graph_node.h" using namespace TosaReference; using namespace Eigen; using namespace tosa; GraphNode::GraphNode(const Op& nodeType_, const uint64_t id_) { nodeType = nodeType_; nodeId = id_; inputs.clear(); outputs.clear(); inputNames.clear(); outputNames.clear(); clearNodeMarked(); evalCount = 0; clearOnNextNodeList(); setRequiredOperands(-1, -1); setRequiredRank(-1); } GraphNode::~GraphNode() {} int GraphNode::addInputName(std::string& name) { inputNames.push_back(name); return 0; } int GraphNode::addOutputName(std::string& name) { outputNames.push_back(name); return 0; } int GraphNode::addInputTensor(Tensor* tens) { ASSERT_MSG(tens, "GraphNode::addInputTensor: no tensor provided"); inputs.push_back(tens); return 0; } int GraphNode::addOutputTensor(Tensor* tens) { ASSERT_MSG(tens, "GraphNode::addOutputTensor: no tensor provided"); outputs.push_back(tens); return 0; } int GraphNode::checkTensorAttributes() { // Placeholder return 0; } int GraphNode::eval() { // Placeholder evaluation function evalCount++; // this should be set by derived op for (auto ct : getOutputs()) { ct->setIsValid(); } return 0; } int GraphNode::hasAllInputsReady() const { for (size_t i = 0; i < inputs.size(); i++) { if (!inputs[i]->getIsValid()) return false; } return true; } int GraphNode::hasAllOutputsReady() const { for (size_t i = 0; i < outputs.size(); i++) { if (!outputs[i]->getIsValid()) return false; } return true; } int GraphNode::dumpNode(FILE* out) { int i; fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Is marked: %d\n", EnumNamesOp()[nodeType], nodeId, evalCount, onNextNodeList, isMarked); i = 0; for (Tensor* ins : inputs) { fprintf(out, " Input[%d] ", i++); ins->dumpTensorParams(out); } i = 0; for (Tensor* outs : outputs) { fprintf(out, " Output[%d] ", i++); outs->dumpTensorParams(out); } return 0; } int GraphNode::dumpNode(std::ostream& out) { int i; out << "Node type: " << EnumNamesOp()[nodeType] << " ID: " << nodeId << " Eval count: " << evalCount << " On next node list: " << onNextNodeList << " Is marked: " << isMarked << std::endl; out << " Inputs:"; for (std::string& name : inputNames) { out << " " << name; } out << std::endl; i = 0; for (Tensor* ins : inputs) { out << " Input[" << i++ << "]: "; ins->dumpTensorParams(out); } out << " Outputs:"; for (std::string& name : outputNames) { out << " " << name; } out << std::endl; i = 0; for (Tensor* outs : outputs) { out << " Output[" << i++ << "]: "; outs->dumpTensorParams(out); } return 0; } int GraphNode::printNodeValidationError(const std::string& msg) { std::cout << "Operator validation error: " << msg << std::endl; ; dumpNode(std::cout); return 0; } int GraphNode::validateRequiredOperands() { if (requiredInputCount >= 0 && inputs.size() != (size_t)requiredInputCount) { printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator must have " + std::to_string(requiredInputCount) + " input(s)"); return 1; } if (requiredOutputCount >= 0 && outputs.size() != (size_t)requiredOutputCount) { printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator output must have exactly " + std::to_string(requiredOutputCount) + " output(s)"); return 1; } return 0; } int GraphNode::validateRequiredRank(const Tensor* t) { if (requiredRankMin >= 0 && requiredRankMax >= 0) { if (t->checkRequiredRank(requiredRankMin, requiredRankMax)) { printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" + std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) + "]. tensorName: " + t->getName()); return 1; } else { return 0; } } if (requiredRankMin >= 0) { if (t->checkRequiredRank(requiredRankMin)) { printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " + std::to_string(requiredRankMin) + ". tensorName: " + t->getName()); return 1; } } return 0; }