diff options
author | Finn Williams <Finn.Williams@arm.com> | 2020-07-03 10:12:03 +0100 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-07-26 15:42:26 +0000 |
commit | f24effa4995ea4c3dd91e33d4a2787e02decf8b4 (patch) | |
tree | 56e0f22cab0fd8544693b9240bd8d74426eaa454 /src/armnn/Layer.cpp | |
parent | 8398edcfb933b638ddf4b88d84d6e188c49b1e0d (diff) | |
download | armnn-f24effa4995ea4c3dd91e33d4a2787e02decf8b4.tar.gz |
IVGCVSW-5155 Update Arm NN API to allow for call to shape inference
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I0a2babe5b5b09eb81c9900dc3a05071034a0440b
Diffstat (limited to 'src/armnn/Layer.cpp')
-rw-r--r-- | src/armnn/Layer.cpp | 28 |
1 files changed, 6 insertions, 22 deletions
diff --git a/src/armnn/Layer.cpp b/src/armnn/Layer.cpp index 692ee32acd..dc211b7f2f 100644 --- a/src/armnn/Layer.cpp +++ b/src/armnn/Layer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "Layer.hpp" @@ -67,6 +67,10 @@ const TensorInfo& OutputSlot::GetTensorInfo() const bool OutputSlot::IsTensorInfoSet() const { + if (GetOwningLayer().GetShapeInferenceMethod() == ShapeInferenceMethod::InferAndValidate) + { + GetOwningLayer().ValidateTensorShapesFromInputs(); + } return GetOutputHandler().IsTensorInfoSet(); } @@ -191,6 +195,7 @@ Layer::Layer(unsigned int numInputSlots, DataLayout layout, const char* name) : m_OutputHandlers(numOutputSlots) +, m_ShapeInferenceMethod(ShapeInferenceMethod::ValidateOnly) , m_LayerName(name ? name : "") , m_Type(type) , m_BackendId() @@ -354,18 +359,6 @@ void Layer::VerifyLayerConnections(unsigned int expectedConnections, const Check % GetNameStr() % location.AsString())); } - if(! GetInputSlot(i).GetConnection()->IsTensorInfoSet()) - { - throw LayerValidationException( - boost::str( - boost::format( - "TensorInfo of Input connection #%1% must be set on connected OutputSlot for " - "%2% layer %3% %4%") - % i - % GetLayerTypeAsCString(this->GetType()) - % GetNameStr() - % location.AsString())); - } } } @@ -448,15 +441,6 @@ void Layer::VerifyShapeInferenceType(const TensorShape& outputShape, ShapeInfere outputShape.AreAllDimensionsSpecified(), "Unspecified dimension while using ShapeInferenceMethod::ValidateOnly"); } - else - { - if (outputShape.GetDimensionality() == Dimensionality::Specified) - { - ConditionalThrow<LayerValidationException>( - !outputShape.AreAllDimensionsSpecified(), - "No unspecified dimension while using ShapeInferenceMethod::InferAndValidate"); - } - } } void Layer::SerializeLayerParameters(ParameterStringifyFunction& fn) const |