diff options
Diffstat (limited to 'reference_model/src/ops/data_nodes.cc')
-rw-r--r-- | reference_model/src/ops/data_nodes.cc | 172 |
1 files changed, 172 insertions, 0 deletions
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc new file mode 100644 index 0000000..2ee4935 --- /dev/null +++ b/reference_model/src/ops/data_nodes.cc @@ -0,0 +1,172 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "data_nodes.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +OpConst::OpConst(uint64_t id_) + : GraphNode(Op_CONST, id_) +{ + setRequiredOperands(0, 1); +} + +OpConst::~OpConst() +{} + +int OpConst::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + return 0; +} + +int OpConst::eval() +{ + // Evaluation is trivial for constants + return GraphNode::eval(); +} + +OpPlaceholder::OpPlaceholder(uint64_t id_) + : GraphNode(Op_PLACEHOLDER, id_) +{ + setRequiredOperands(0, 1); +} + +OpPlaceholder::~OpPlaceholder() +{} + +int OpPlaceholder::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + return 0; +} + +int OpPlaceholder::eval() +{ + // Evaluation is trivial for placeholders + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +OpIdentity<Rank, Dtype>::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_IDENTITY, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); +} + +template <int Rank, DType Dtype> +OpIdentity<Rank, Dtype>::~OpIdentity() +{} + +template <int Rank, DType Dtype> +int OpIdentity<Rank, Dtype>::checkTensorAttributes() +{ + + if (inputs.size() != outputs.size()) + { + printNodeValidationError("Input and output tensor list lengths are not equal"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + if (in->matchRankTypeShape(*out)) + { + printNodeValidationError("Input and output tensor rank, type, or shape do not match"); + return 1; + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpIdentity<Rank, Dtype>::eval() +{ + out->getTensor() = in->getTensor(); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +OpIdentityN<Rank, Dtype>::OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_IDENTITYN, id_) +{ + setRequiredRank(0, 6); +} + +template <int Rank, DType Dtype> +OpIdentityN<Rank, Dtype>::~OpIdentityN() +{} + +template <int Rank, DType Dtype> +int OpIdentityN<Rank, Dtype>::checkTensorAttributes() +{ + + if (inputs.size() != outputs.size()) + { + printNodeValidationError("Input and output tensor list lengths are not equal"); + return 1; + } + + for (size_t i = 0; i < inputs.size(); i++) + { + ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i])); + outs.push_back(dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[i])); + + if (ins[i]->matchRankTypeShape(*outs[i])) + { + printNodeValidationError("Input and output tensor rank, type, or shape do not match"); + return 1; + } + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpIdentityN<Rank, Dtype>::eval() +{ + for (size_t i = 0; i < ins.size(); i++) + { + outs[i]->getTensor() = ins[i]->getTensor(); + } + + return GraphNode::eval(); +} + +// template explicit instantiation +// note OpConst and OpPlaceholder are not templated + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL); |