aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Graph.cpp
diff options
context:
space:
mode:
authorCathal Corbett <catcor01@e127348.nice.arm.com>2021-10-07 11:46:40 +0100
committerCathal Corbett <cathal.corbett@arm.com>2021-10-08 11:28:35 +0000
commit521032fd424cf86681eb125afbf5eaee47d8c585 (patch)
tree65162778f203638f1c039097b8240422f99dad76 /src/armnn/Graph.cpp
parent723bc3b5d8a911a369eee658631d9f107ea09896 (diff)
downloadarmnn-521032fd424cf86681eb125afbf5eaee47d8c585.tar.gz
IVGCVSW-6417: Catch AddFullyConnected API error when weights TensorInfo isn't set
* Updated code in Graph.cpp InferTensorInfos() to be more descriptive. * Added method VerifyConstantLayerSetTensorInfo() in Graph.cpp/hpp to error when ConstantLayer TensorInfo is not set. * Updated Optimize() in Network.cpp to call VerifyConstantLayerSetTensorInfo(). * Added unit test with ConstantLayer TensorInfo not set to catch error in VerifyConstantLayerSetTensorInfo(). * Added comments around method VerifyConstantLayerSetTensorInfo(). Signed-off-by: Cathal Corbett <cathal.corbett@arm.com> Change-Id: I366596243f7c5823676222e2d0cce1335bc8c325
Diffstat (limited to 'src/armnn/Graph.cpp')
-rw-r--r--src/armnn/Graph.cpp47
1 files changed, 41 insertions, 6 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp
index 7b6f56f8b8..60bf328c9c 100644
--- a/src/armnn/Graph.cpp
+++ b/src/armnn/Graph.cpp
@@ -526,6 +526,33 @@ void Graph::EraseSubgraphLayers(SubgraphView &subgraph)
subgraph.Clear();
}
+/// For each ConstantLayer in Graph, ensures TensorInfo is set on all output slots.
+/// LayerValidationException thrown if no TensorInfo is set.
+///
+/// @throws LayerValidationException
+void Graph::VerifyConstantLayerSetTensorInfo() const
+{
+ for (auto&& layer : TopologicalSort())
+ {
+ if(layer->GetType() == armnn::LayerType::Constant)
+ {
+ for (auto&& output: layer->GetOutputSlots())
+ {
+ if (!output.IsTensorInfoSet())
+ {
+ std::ostringstream message;
+ message << "Output slot TensorInfo not set on "
+ << GetLayerTypeAsCString(layer->GetType())
+ << " layer \""
+ << layer->GetName()
+ << "\"";
+ throw LayerValidationException(message.str());
+ }
+ }
+ }
+ }
+}
+
void Graph::InferTensorInfos()
{
for (auto&& layer : TopologicalSort())
@@ -536,7 +563,9 @@ void Graph::InferTensorInfos()
if (source == NULL)
{
std::ostringstream message;
- message << "Input not connected on "
+ message << "Input slot "
+ << input.GetSlotIndex()
+ << " not connected to an output slot on "
<< GetLayerTypeAsCString(layer->GetType())
<< " layer \""
<< layer->GetName()
@@ -546,13 +575,19 @@ void Graph::InferTensorInfos()
if (!source->IsTensorInfoSet())
{
- throw LayerValidationException("All inputs must have the TensorInfo set at this point.");
+ std::ostringstream message;
+ message << "Output slot TensorInfo not set on "
+ << GetLayerTypeAsCString(layer->GetType())
+ << " layer \""
+ << layer->GetName()
+ << "\"";
+ throw LayerValidationException(message.str());
}
+ }
- if (layer->m_ShapeInferenceMethod == ShapeInferenceMethod::ValidateOnly)
- {
- layer->ValidateTensorShapesFromInputs();
- }
+ if (layer->m_ShapeInferenceMethod == ShapeInferenceMethod::ValidateOnly)
+ {
+ layer->ValidateTensorShapesFromInputs();
}
}
}