// Copyright (c) 2020-2021, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "data_nodes.h" using namespace TosaReference; using namespace Eigen; using namespace tosa; OpConst::OpConst(uint64_t id_) : GraphNode(Op_CONST, id_) { setRequiredOperands(0, 1); } OpConst::~OpConst() {} int OpConst::checkTensorAttributes() { if (validateRequiredOperands()) return 1; return 0; } int OpConst::eval() { // Evaluation is trivial for constants return GraphNode::eval(); } OpPlaceholder::OpPlaceholder(uint64_t id_) : GraphNode(Op_PLACEHOLDER, id_) { setRequiredOperands(0, 1); } OpPlaceholder::~OpPlaceholder() {} int OpPlaceholder::checkTensorAttributes() { if (validateRequiredOperands()) return 1; return 0; } int OpPlaceholder::eval() { // Evaluation is trivial for placeholders return GraphNode::eval(); } template OpIdentity::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_IDENTITY, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); } template OpIdentity::~OpIdentity() {} template int OpIdentity::checkTensorAttributes() { if (inputs.size() != outputs.size()) { printNodeValidationError("Input and output tensor list lengths are not equal"); return 1; } in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); if (in->matchRankTypeShape(*out)) { printNodeValidationError("Input and output tensor rank, type, or shape do not match"); return 1; } return 0; } template int OpIdentity::eval() { out->getTensor() = in->getTensor(); return GraphNode::eval(); } template OpIdentityN::OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_IDENTITYN, id_) { setRequiredRank(0, 6); } template OpIdentityN::~OpIdentityN() {} template int OpIdentityN::checkTensorAttributes() { if (inputs.size() != outputs.size()) { printNodeValidationError("Input and output tensor list lengths are not equal"); return 1; } for (size_t i = 0; i < inputs.size(); i++) { ins.push_back(dynamic_cast*>(inputs[i])); outs.push_back(dynamic_cast*>(outputs[i])); if (ins[i]->matchRankTypeShape(*outs[i])) { printNodeValidationError("Input and output tensor rank, type, or shape do not match"); return 1; } } return 0; } template int OpIdentityN::eval() { for (size_t i = 0; i < ins.size(); i++) { outs[i]->getTensor() = ins[i]->getTensor(); } return GraphNode::eval(); } // template explicit instantiation // note OpConst and OpPlaceholder are not templated DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);