diff options
Diffstat (limited to 'src/armnn/Layer.cpp')
-rw-r--r-- | src/armnn/Layer.cpp | 28 |
1 files changed, 6 insertions, 22 deletions
diff --git a/src/armnn/Layer.cpp b/src/armnn/Layer.cpp index 692ee32acd..dc211b7f2f 100644 --- a/src/armnn/Layer.cpp +++ b/src/armnn/Layer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "Layer.hpp" @@ -67,6 +67,10 @@ const TensorInfo& OutputSlot::GetTensorInfo() const bool OutputSlot::IsTensorInfoSet() const { + if (GetOwningLayer().GetShapeInferenceMethod() == ShapeInferenceMethod::InferAndValidate) + { + GetOwningLayer().ValidateTensorShapesFromInputs(); + } return GetOutputHandler().IsTensorInfoSet(); } @@ -191,6 +195,7 @@ Layer::Layer(unsigned int numInputSlots, DataLayout layout, const char* name) : m_OutputHandlers(numOutputSlots) +, m_ShapeInferenceMethod(ShapeInferenceMethod::ValidateOnly) , m_LayerName(name ? name : "") , m_Type(type) , m_BackendId() @@ -354,18 +359,6 @@ void Layer::VerifyLayerConnections(unsigned int expectedConnections, const Check % GetNameStr() % location.AsString())); } - if(! GetInputSlot(i).GetConnection()->IsTensorInfoSet()) - { - throw LayerValidationException( - boost::str( - boost::format( - "TensorInfo of Input connection #%1% must be set on connected OutputSlot for " - "%2% layer %3% %4%") - % i - % GetLayerTypeAsCString(this->GetType()) - % GetNameStr() - % location.AsString())); - } } } @@ -448,15 +441,6 @@ void Layer::VerifyShapeInferenceType(const TensorShape& outputShape, ShapeInfere outputShape.AreAllDimensionsSpecified(), "Unspecified dimension while using ShapeInferenceMethod::ValidateOnly"); } - else - { - if (outputShape.GetDimensionality() == Dimensionality::Specified) - { - ConditionalThrow<LayerValidationException>( - !outputShape.AreAllDimensionsSpecified(), - "No unspecified dimension while using ShapeInferenceMethod::InferAndValidate"); - } - } } void Layer::SerializeLayerParameters(ParameterStringifyFunction& fn) const |