From e5e2676409a936431f87d31fb74d825257b20804 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 13 Oct 2020 16:11:07 -0700 Subject: Initial checkin of TOSA reference_model and tests Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6 Signed-off-by: Eric Kunze --- reference_model/src/graph_node.h | 354 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 reference_model/src/graph_node.h (limited to 'reference_model/src/graph_node.h') diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h new file mode 100644 index 0000000..5b4a767 --- /dev/null +++ b/reference_model/src/graph_node.h @@ -0,0 +1,354 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRAPH_NODE_H +#define GRAPH_NODE_H + +#include "attribute.h" +#include "quant_info.h" +#include "tensor.h" +#include "tosa_generated.h" +#include + +#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP; + +#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ + template class TosaReference::OP; + +#define DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \ + template class TosaReference::OP; + +#define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ + template class TosaReference::OP; + +#define DEF_INSTANTIATE_ONE_RANK_0_6(OP) \ + template class TosaReference::OP<0>; \ + template class TosaReference::OP<1>; \ + template class TosaReference::OP<2>; \ + template class TosaReference::OP<3>; \ + template class TosaReference::OP<4>; \ + template class TosaReference::OP<5>; \ + template class TosaReference::OP<6>; + +#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP; + +#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP; + +#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) + +#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) + +#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2) + +#define DEF_INSTANTIATE_RESHAPE(OP, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 6, DTYPE) + +#define DEF_INSTANTIATE_GATHER(OP, DTYPE) \ + /* gather op takes input and index rank as template argument */ \ + /* note output rank = input rank - 1 + index rank */ \ + /* and max rank allowed in tosa_reference is 6 */ \ + /* so only specific input and index pair is instantiated */ \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) + +#define INIT_ATTRIBUTE(ATTRIBUTE_NAME) \ + if (auto p = dynamic_cast(attribute_)) \ + { \ + attribute = new Tosa##ATTRIBUTE_NAME##Attribute(p); \ + ASSERT_MEM(attribute); \ + } \ + else \ + { \ + FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute"); \ + } + +#define INIT_QINFO(QINFO_NAME) \ + if (auto p = dynamic_cast(qinfo_)) \ + { \ + qinfo = new Tosa##QINFO_NAME##QuantInfo(p); \ + ASSERT_MEM(qinfo); \ + } \ + else \ + { \ + qinfo = nullptr; \ + } + +namespace TosaReference +{ + +// Nodes in the graph (e.g., tosa operators) are defined with this base +// class. +class GraphNode +{ +public: + GraphNode(const tosa::Op& nodeType, const uint64_t id_); + virtual ~GraphNode(); + + int addInputName(std::string& name); + int addOutputName(std::string& name); + + int addInputTensor(Tensor* tens); + int addOutputTensor(Tensor* tens); + + // Validate that the input tensors match properly + // in their types, attributes, rank, etc well enough to be + // processed. + // + // This function should be pure virtual (eventually) in order to force + // derivative operators to implement the check, but we'll initially + // provide a default function so that GraphNode can be instantiated + // directly for testing purposes. + virtual int checkTensorAttributes(); + + // Evalute the node/operator + virtual int eval(); + + int hasAllInputsReady() const; + int hasAllOutputsReady() const; + + int dumpNode(FILE* out); + int dumpNode(std::ostream& out); + + int setNodeMarked() + { + isMarked = true; + return 0; + } + + int getNodeMarked() const + { + return isMarked; + } + + int clearNodeMarked() + { + isMarked = false; + return 0; + } + + int getEvalCount() const + { + return evalCount; + } + + uint64_t getID() const + { + return nodeId; + } + + std::vector& getInputNames() + { + return inputNames; + } + + std::vector& getOutputNames() + { + return outputNames; + } + + std::vector& getOutputs() + { + return outputs; + } + + std::vector& getInputs() + { + return inputs; + } + + int getOnNextNodeList() const + { + return onNextNodeList; + } + + int setOnNextNodeList() + { + onNextNodeList = true; + return 0; + } + + int clearOnNextNodeList() + { + onNextNodeList = false; + return 0; + } + + tosa::Op getOp() const + { + return nodeType; + } + +protected: + // Print out a node validation error + int printNodeValidationError(const std::string& msg); + + int setRequiredOperands(const int in, const int out) + { + requiredInputCount = in; + requiredOutputCount = out; + return 0; + } + + int setRequiredRank(const int min, const int max = -1) + { + if (max == -1) + { + requiredRankMin = requiredRankMax = min; + } + else + { + requiredRankMin = min; + requiredRankMax = max; + } + + ASSERT_MSG(requiredRankMin <= requiredRankMax, + "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin, + requiredRankMax); + + return 0; + } + + int validateRequiredOperands(); + int validateRequiredRank(const Tensor* t); + + // Description of the node type (e.g., CONST, CONV2D, etc...) + tosa::Op nodeType; + + // A list of input tensor names + std::vector inputNames; + + // A list of the output tensor names + std::vector outputNames; + + // A list of the input tensors (after names have been matched up) + std::vector inputs; + + // A list of the output tensors (after names have been matched up) + std::vector outputs; + + // Unique node ID for debugging + uint64_t nodeId; + + // Flag used for graph analysis + int isMarked; + + // Number of times eval() has been called for this node + int evalCount; + + // Flag indicating that this node is ready and is on the + // next-node list. + int onNextNodeList; + + // Required input/output tensor counts for node validation + // -1 means any number is allowed + int requiredInputCount; + int requiredOutputCount; + + // Required rank ranges for input/output tensors + // -1 means n/a + int requiredRankMin; + int requiredRankMax; +}; + +}; // namespace TosaReference + +#endif -- cgit v1.2.1