aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r--src/armnn/Graph.hpp13
1 files changed, 10 insertions, 3 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 9673df49a0..87e0da826f 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -93,7 +93,11 @@ public:
const Graph& m_Graph;
};
- Graph() : m_LayersInOrder(true) {}
+ Graph(bool shapeInferenceMethod = false)
+ : m_LayersInOrder(true)
+ , m_ShapeInferenceMethod(shapeInferenceMethod ? ShapeInferenceMethod::InferAndValidate :
+ ShapeInferenceMethod::ValidateOnly)
+ {}
Graph(const Graph& other);
@@ -200,7 +204,7 @@ public:
void SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substituteLayer);
void SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph);
- void InferTensorInfos(ShapeInferenceMethod shapeInferenceMethod = ShapeInferenceMethod::ValidateOnly);
+ void InferTensorInfos();
void AttachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) {
m_Views[notifyOnEvent].emplace_back(observable);
@@ -260,6 +264,7 @@ private:
mutable bool m_LayersInOrder;
std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views;
+ ShapeInferenceMethod m_ShapeInferenceMethod;
};
/// Common base class for layers in the graph.
@@ -401,6 +406,8 @@ inline LayerT* Graph::AddLayer(Args&&... args)
((LayerEnumOf<LayerT>() == LayerType::Input) || (LayerEnumOf<LayerT>() == LayerType::Output));
LayerT* const layer = new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...);
+ layer->SetShapeInferenceMethod(m_ShapeInferenceMethod);
+
NotifyObservables(GraphEvent::LayerAdded, layer);
return layer;