diff options
author | Cathal Corbett <catcor01@e127348.nice.arm.com> | 2021-10-07 11:46:40 +0100 |
---|---|---|
committer | Cathal Corbett <cathal.corbett@arm.com> | 2021-10-08 11:28:35 +0000 |
commit | 521032fd424cf86681eb125afbf5eaee47d8c585 (patch) | |
tree | 65162778f203638f1c039097b8240422f99dad76 /src/armnn/Graph.cpp | |
parent | 723bc3b5d8a911a369eee658631d9f107ea09896 (diff) | |
download | armnn-521032fd424cf86681eb125afbf5eaee47d8c585.tar.gz |
IVGCVSW-6417: Catch AddFullyConnected API error when weights TensorInfo isn't set
* Updated code in Graph.cpp InferTensorInfos() to be more descriptive.
* Added method VerifyConstantLayerSetTensorInfo() in Graph.cpp/hpp
to error when ConstantLayer TensorInfo is not set.
* Updated Optimize() in Network.cpp to call VerifyConstantLayerSetTensorInfo().
* Added unit test with ConstantLayer TensorInfo not
set to catch error in VerifyConstantLayerSetTensorInfo().
* Added comments around method VerifyConstantLayerSetTensorInfo().
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I366596243f7c5823676222e2d0cce1335bc8c325
Diffstat (limited to 'src/armnn/Graph.cpp')
-rw-r--r-- | src/armnn/Graph.cpp | 47 |
1 files changed, 41 insertions, 6 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp index 7b6f56f8b8..60bf328c9c 100644 --- a/src/armnn/Graph.cpp +++ b/src/armnn/Graph.cpp @@ -526,6 +526,33 @@ void Graph::EraseSubgraphLayers(SubgraphView &subgraph) subgraph.Clear(); } +/// For each ConstantLayer in Graph, ensures TensorInfo is set on all output slots. +/// LayerValidationException thrown if no TensorInfo is set. +/// +/// @throws LayerValidationException +void Graph::VerifyConstantLayerSetTensorInfo() const +{ + for (auto&& layer : TopologicalSort()) + { + if(layer->GetType() == armnn::LayerType::Constant) + { + for (auto&& output: layer->GetOutputSlots()) + { + if (!output.IsTensorInfoSet()) + { + std::ostringstream message; + message << "Output slot TensorInfo not set on " + << GetLayerTypeAsCString(layer->GetType()) + << " layer \"" + << layer->GetName() + << "\""; + throw LayerValidationException(message.str()); + } + } + } + } +} + void Graph::InferTensorInfos() { for (auto&& layer : TopologicalSort()) @@ -536,7 +563,9 @@ void Graph::InferTensorInfos() if (source == NULL) { std::ostringstream message; - message << "Input not connected on " + message << "Input slot " + << input.GetSlotIndex() + << " not connected to an output slot on " << GetLayerTypeAsCString(layer->GetType()) << " layer \"" << layer->GetName() @@ -546,13 +575,19 @@ void Graph::InferTensorInfos() if (!source->IsTensorInfoSet()) { - throw LayerValidationException("All inputs must have the TensorInfo set at this point."); + std::ostringstream message; + message << "Output slot TensorInfo not set on " + << GetLayerTypeAsCString(layer->GetType()) + << " layer \"" + << layer->GetName() + << "\""; + throw LayerValidationException(message.str()); } + } - if (layer->m_ShapeInferenceMethod == ShapeInferenceMethod::ValidateOnly) - { - layer->ValidateTensorShapesFromInputs(); - } + if (layer->m_ShapeInferenceMethod == ShapeInferenceMethod::ValidateOnly) + { + layer->ValidateTensorShapesFromInputs(); } } } |