diff options
author | Finn Williams <Finn.Williams@arm.com> | 2020-07-03 10:12:03 +0100 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-07-26 15:42:26 +0000 |
commit | f24effa4995ea4c3dd91e33d4a2787e02decf8b4 (patch) | |
tree | 56e0f22cab0fd8544693b9240bd8d74426eaa454 /src/armnn/Graph.hpp | |
parent | 8398edcfb933b638ddf4b88d84d6e188c49b1e0d (diff) | |
download | armnn-f24effa4995ea4c3dd91e33d4a2787e02decf8b4.tar.gz |
IVGCVSW-5155 Update Arm NN API to allow for call to shape inference
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I0a2babe5b5b09eb81c9900dc3a05071034a0440b
Diffstat (limited to 'src/armnn/Graph.hpp')
-rw-r--r-- | src/armnn/Graph.hpp | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp index 9673df49a0..87e0da826f 100644 --- a/src/armnn/Graph.hpp +++ b/src/armnn/Graph.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -93,7 +93,11 @@ public: const Graph& m_Graph; }; - Graph() : m_LayersInOrder(true) {} + Graph(bool shapeInferenceMethod = false) + : m_LayersInOrder(true) + , m_ShapeInferenceMethod(shapeInferenceMethod ? ShapeInferenceMethod::InferAndValidate : + ShapeInferenceMethod::ValidateOnly) + {} Graph(const Graph& other); @@ -200,7 +204,7 @@ public: void SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substituteLayer); void SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph); - void InferTensorInfos(ShapeInferenceMethod shapeInferenceMethod = ShapeInferenceMethod::ValidateOnly); + void InferTensorInfos(); void AttachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) { m_Views[notifyOnEvent].emplace_back(observable); @@ -260,6 +264,7 @@ private: mutable bool m_LayersInOrder; std::map<const GraphEvent, std::list<IGraphObservable*>> m_Views; + ShapeInferenceMethod m_ShapeInferenceMethod; }; /// Common base class for layers in the graph. @@ -401,6 +406,8 @@ inline LayerT* Graph::AddLayer(Args&&... args) ((LayerEnumOf<LayerT>() == LayerType::Input) || (LayerEnumOf<LayerT>() == LayerType::Output)); LayerT* const layer = new LayerInGraph<LayerT>(*this, std::forward<Args>(args)...); + layer->SetShapeInferenceMethod(m_ShapeInferenceMethod); + NotifyObservables(GraphEvent::LayerAdded, layer); return layer; |