aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/subgraph_traverser.h')
-rw-r--r--reference_model/src/subgraph_traverser.h9
1 files changed, 8 insertions, 1 deletions
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index ef6ea42..d6b0e8d 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -74,10 +74,14 @@ public:
int getNumOutputTensors() const;
Tensor* getOutputTensor(const unsigned int idx) const;
Tensor* getOutputTensorByName(const std::string name) const;
+ int getNumVariableTensors() const;
+ Tensor* getVariableTensor(const unsigned int idx) const;
+ Tensor* getVariableTensorByName(const std::string name) const;
+ int registerVariableTensor(Tensor* tensor);
int addToNextNodeList(GraphNode*);
private:
- int addTensor(Tensor* ct);
+ int addTensor(const TosaSerializationTensor* ts);
int addNode(GraphNode* cn);
Tensor* findTensorByName(const std::string& name) const;
@@ -103,6 +107,9 @@ private:
// The subset of tensors that are also output tensors
std::vector<Tensor*> outputTensors;
+ // The subset of tensors that are also variable tensors
+ std::vector<Tensor*> variableTensors;
+
// The definitive list of all nodes in the graph
std::vector<GraphNode*> nodes;