diff options
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r-- | src/armnn/Graph.hpp | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 9673df49a0..87e0da826f 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -93,7 +93,11 @@ public: const Graph& m_Graph; }; - Graph() : m_LayersInOrder(true) {} + Graph(bool shapeInferenceMethod = false) + : m_LayersInOrder(true) + , m_ShapeInferenceMethod(shapeInferenceMethod ? ShapeInferenceMethod::InferAndValidate : + ShapeInferenceMethod::ValidateOnly) + {} Graph(const Graph& other); @@ -200,7 +204,7 @@ public: void SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substituteLayer); void SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph); - void InferTensorInfos(ShapeInferenceMethod shapeInferenceMethod = ShapeInferenceMethod::ValidateOnly); + void InferTensorInfos(); void AttachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) { m_Views[notifyOnEvent].emplace_back(observable); @@ -260,6 +264,7 @@ private: mutable bool m_LayersInOrder; std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views; + ShapeInferenceMethod m_ShapeInferenceMethod; }; /// Common base class for layers in the graph. @@ -401,6 +406,8 @@ inline LayerT* Graph::AddLayer(Args&&... args) ((LayerEnumOf<LayerT>() == LayerType::Input) || (LayerEnumOf<LayerT>() == LayerType::Output)); LayerT* const layer = new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...); + layer->SetShapeInferenceMethod(m_ShapeInferenceMethod); + NotifyObservables(GraphEvent::LayerAdded, layer); return layer; |