aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/LstmLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/LstmLayer.cpp')
-rw-r--r--src/armnn/layers/LstmLayer.cpp40
1 files changed, 15 insertions, 25 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp
index af708e4e06..44f5d1f40b 100644
--- a/src/armnn/layers/LstmLayer.cpp
+++ b/src/armnn/layers/LstmLayer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "LstmLayer.hpp"
@@ -165,15 +165,17 @@ std::vector<TensorShape> LstmLayer::InferOutputShapes(const std::vector<TensorSh
void LstmLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInferenceMethod)
{
- IgnoreUnused(shapeInferenceMethod);
-
VerifyLayerConnections(3, CHECK_LOCATION());
+ const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
+
+ VerifyShapeInferenceType(outputShape, shapeInferenceMethod);
+
auto inferredShapes = InferOutputShapes( {
GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
- GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()}
- );
+ GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()
+ });
ARMNN_ASSERT(inferredShapes.size() == 4);
@@ -206,10 +208,7 @@ void LstmLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInferen
ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr,
"LstmLayer: m_CifgParameters.m_InputGateBias should not be null.");
- ConditionalThrowIfNotEqual<LayerValidationException>(
- "LstmLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
- GetOutputSlot(0).GetTensorInfo().GetShape(),
- inferredShapes[0]);
+ ValidateAndCopyShape(outputShape, inferredShapes[0], shapeInferenceMethod, "LstmLayer");
}
else
{
@@ -220,10 +219,7 @@ void LstmLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInferen
ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
"LstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled.");
- ConditionalThrowIfNotEqual<LayerValidationException>(
- "LstmLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
- GetOutputSlot(0).GetTensorInfo().GetShape(),
- inferredShapes[0]);
+ ValidateAndCopyShape(outputShape, inferredShapes[0], shapeInferenceMethod, "LstmLayer");
}
if (m_Param.m_ProjectionEnabled)
@@ -246,18 +242,12 @@ void LstmLayer::ValidateTensorShapesFromInputs(ShapeInferenceMethod shapeInferen
"LstmLayer: m_PeepholeParameters.m_CellToOutputWeights should not be null.");
}
- ConditionalThrowIfNotEqual<LayerValidationException>(
- "LstmLayer: TensorShape set on OutputSlot[1] does not match the inferred shape.",
- GetOutputSlot(1).GetTensorInfo().GetShape(),
- inferredShapes[1]);
- ConditionalThrowIfNotEqual<LayerValidationException>(
- "LstmLayer: TensorShape set on OutputSlot[2] does not match the inferred shape.",
- GetOutputSlot(2).GetTensorInfo().GetShape(),
- inferredShapes[2]);
- ConditionalThrowIfNotEqual<LayerValidationException>(
- "LstmLayer: TensorShape set on OutputSlot[3] does not match the inferred shape.",
- GetOutputSlot(3).GetTensorInfo().GetShape(),
- inferredShapes[3]);
+ ValidateAndCopyShape(
+ GetOutputSlot(1).GetTensorInfo().GetShape(), inferredShapes[1], shapeInferenceMethod, "LstmLayer", 1);
+ ValidateAndCopyShape(
+ GetOutputSlot(2).GetTensorInfo().GetShape(), inferredShapes[2], shapeInferenceMethod, "LstmLayer", 2);
+ ValidateAndCopyShape(
+ GetOutputSlot(3).GetTensorInfo().GetShape(), inferredShapes[3], shapeInferenceMethod, "LstmLayer", 3);
if (m_Param.m_LayerNormEnabled)
{