diff options
Diffstat (limited to 'src/armnn/layers/FullyConnectedLayer.cpp')
-rw-r--r-- | src/armnn/layers/FullyConnectedLayer.cpp | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/src/armnn/layers/FullyConnectedLayer.cpp b/src/armnn/layers/FullyConnectedLayer.cpp index bd947b7678..174459b565 100644 --- a/src/armnn/layers/FullyConnectedLayer.cpp +++ b/src/armnn/layers/FullyConnectedLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "FullyConnectedLayer.hpp" @@ -65,21 +65,20 @@ void FullyConnectedLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod sh { IgnoreUnused(shapeInferenceMethod); - VerifyLayerConnections(1, CHECK_LOCATION()); + const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); + + VerifyShapeInferenceType(outputShape, shapeInferenceMethod); // check if we m_Weight data is not nullptr ARMNN_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null."); - auto inferredShapes = InferOutputShapes({ - GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), - m_Weight->GetTensorInfo().GetShape() }); + auto inferredShapes = InferOutputShapes({GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + m_Weight->GetTensorInfo().GetShape() }); ARMNN_ASSERT(inferredShapes.size() == 1); + ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified); - ConditionalThrowIfNotEqual<LayerValidationException>( - "FullyConnectedLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", - GetOutputSlot(0).GetTensorInfo().GetShape(), - inferredShapes[0]); + ValidateAndCopyShape(outputShape, inferredShapes[0], shapeInferenceMethod, "FullyConnectedLayer"); } Layer::ConstantTensors FullyConnectedLayer::GetConstantTensorsByRef() |