aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.h
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2020-10-13 16:11:07 -0700
committerKevin Cheng <kevin.cheng@arm.com>2020-10-14 11:11:43 -0700
commite5e2676409a936431f87d31fb74d825257b20804 (patch)
tree304d93d993ef6417b02a515025f9030367682774 /reference_model/src/subgraph_traverser.h
parent88b7860f180f91b5b66764c61cfd97d8bc53cece (diff)
downloadreference_model-e5e2676409a936431f87d31fb74d825257b20804.tar.gz
Initial checkin of TOSA reference_model and tests
Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6 Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Diffstat (limited to 'reference_model/src/subgraph_traverser.h')
-rw-r--r--reference_model/src/subgraph_traverser.h90
1 files changed, 90 insertions, 0 deletions
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
new file mode 100644
index 0000000..3f4eecf
--- /dev/null
+++ b/reference_model/src/subgraph_traverser.h
@@ -0,0 +1,90 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SUBGRAPH_TRAVERSER_H
+#define SUBGRAPH_TRAVERSER_H
+
+#include "model_common.h"
+
+#include "graph_node.h"
+#include "ops/op_factory.h"
+#include "tosa_serialization_handler.h"
+
+namespace TosaReference
+{
+
+class SubgraphTraverser
+{
+public:
+ SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh);
+ ~SubgraphTraverser();
+
+ int initializeGraph();
+ int isFullyEvaluated() const;
+ int evaluateNextNode();
+ int evaluateAll();
+
+ int linkTensorsAndNodes();
+ int validateGraph();
+
+ int dumpGraph(FILE* out) const;
+ int dumpNextNodeList(FILE* out) const;
+ int clearAllNodeMarkings();
+
+ int getNumInputTensors() const;
+ Tensor* getInputTensor(const unsigned int idx) const;
+ Tensor* getInputTensorByName(const std::string name) const;
+ int getNumOutputTensors() const;
+ Tensor* getOutputTensor(const unsigned int idx) const;
+ Tensor* getOutputTensorByName(const std::string name) const;
+ int addToNextNodeList(GraphNode*);
+
+private:
+ int addTensor(Tensor* ct);
+ int addNode(GraphNode* cn);
+
+ Tensor* findTensorByName(const std::string& name) const;
+
+ GraphNode* getNextNode();
+
+ // pointer to serialization library and corresponding basic block
+ TosaSerializationBasicBlock* block;
+ TosaSerializationHandler* tsh;
+
+ // The definitive list of all tensors
+ std::vector<Tensor*> tensors;
+
+ // The subset of tensors that are also input tensors
+ std::vector<Tensor*> inputTensors;
+
+ // The subset of tensors that are also output tensors
+ std::vector<Tensor*> outputTensors;
+
+ // The definitive list of all nodes in the graph
+ std::vector<GraphNode*> nodes;
+
+ // The subset of node that have all of their input tensors ready, but
+ // have not yet been evaluated to produce their output tensors.
+ // With control flow, a node may appear on this list more than once during its
+ // lifetime, although the list itself should only contain unique nodes.
+ std::list<GraphNode*> nextNodeList;
+
+ // Maximum number of times to evalute a node before
+ // warning.
+ const int MAX_EVAL_COUNT = 10000;
+};
+}; // namespace TosaReference
+
+#endif