aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Layer.cpp
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2020-07-03 10:12:03 +0100
committerJim Flynn <jim.flynn@arm.com>2020-07-26 15:42:26 +0000
commitf24effa4995ea4c3dd91e33d4a2787e02decf8b4 (patch)
tree56e0f22cab0fd8544693b9240bd8d74426eaa454 /src/armnn/Layer.cpp
parent8398edcfb933b638ddf4b88d84d6e188c49b1e0d (diff)
downloadarmnn-f24effa4995ea4c3dd91e33d4a2787e02decf8b4.tar.gz
IVGCVSW-5155 Update Arm NN API to allow for call to shape inference
Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I0a2babe5b5b09eb81c9900dc3a05071034a0440b
Diffstat (limited to 'src/armnn/Layer.cpp')
-rw-r--r--src/armnn/Layer.cpp28
1 files changed, 6 insertions, 22 deletions
diff --git a/src/armnn/Layer.cpp b/src/armnn/Layer.cpp
index 692ee32acd..dc211b7f2f 100644
--- a/src/armnn/Layer.cpp
+++ b/src/armnn/Layer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "Layer.hpp"
@@ -67,6 +67,10 @@ const TensorInfo& OutputSlot::GetTensorInfo() const
bool OutputSlot::IsTensorInfoSet() const
{
+ if (GetOwningLayer().GetShapeInferenceMethod() == ShapeInferenceMethod::InferAndValidate)
+ {
+ GetOwningLayer().ValidateTensorShapesFromInputs();
+ }
return GetOutputHandler().IsTensorInfoSet();
}
@@ -191,6 +195,7 @@ Layer::Layer(unsigned int numInputSlots,
DataLayout layout,
const char* name)
: m_OutputHandlers(numOutputSlots)
+, m_ShapeInferenceMethod(ShapeInferenceMethod::ValidateOnly)
, m_LayerName(name ? name : "")
, m_Type(type)
, m_BackendId()
@@ -354,18 +359,6 @@ void Layer::VerifyLayerConnections(unsigned int expectedConnections, const Check
% GetNameStr()
% location.AsString()));
}
- if(! GetInputSlot(i).GetConnection()->IsTensorInfoSet())
- {
- throw LayerValidationException(
- boost::str(
- boost::format(
- "TensorInfo of Input connection #%1% must be set on connected OutputSlot for "
- "%2% layer %3% %4%")
- % i
- % GetLayerTypeAsCString(this->GetType())
- % GetNameStr()
- % location.AsString()));
- }
}
}
@@ -448,15 +441,6 @@ void Layer::VerifyShapeInferenceType(const TensorShape& outputShape, ShapeInfere
outputShape.AreAllDimensionsSpecified(),
"Unspecified dimension while using ShapeInferenceMethod::ValidateOnly");
}
- else
- {
- if (outputShape.GetDimensionality() == Dimensionality::Specified)
- {
- ConditionalThrow<LayerValidationException>(
- !outputShape.AreAllDimensionsSpecified(),
- "No unspecified dimension while using ShapeInferenceMethod::InferAndValidate");
- }
- }
}
void Layer::SerializeLayerParameters(ParameterStringifyFunction& fn) const