diff options
Diffstat (limited to 'src/armnn/layers/LstmLayer.cpp')
-rw-r--r-- | src/armnn/layers/LstmLayer.cpp | 198 |
1 files changed, 148 insertions, 50 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index 0e6f3d882b..d87ad6461e 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "LstmLayer.hpp" @@ -149,7 +149,11 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const std::vector<TensorShape> LstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const { - ARMNN_ASSERT(inputShapes.size() == 3); + if (inputShapes.size() != 3) + { + throw armnn::Exception("inputShapes' size is \"" + std::to_string(inputShapes.size()) + + "\" - should be \"3\"."); + } // Get input values for validation unsigned int batchSize = inputShapes[0][0]; @@ -179,69 +183,148 @@ void LstmLayer::ValidateTensorShapesFromInputs() GetInputSlot(2).GetTensorInfo().GetShape() }); - ARMNN_ASSERT(inferredShapes.size() == 4); + if (inferredShapes.size() != 4) + { + throw armnn::Exception("inferredShapes has " + + std::to_string(inferredShapes.size()) + + " element(s) - should only have 4."); + } // Check if the weights are nullptr - ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr, - "LstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr, - "LstmLayer: m_BasicParameters.m_InputToCellWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr, - "LstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr, - "LstmLayer: m_BasicParameters.m_RecurrentToForgetWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr, - "LstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr, - "LstmLayer: m_BasicParameters.m_RecurrentToOutputWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr, - "LstmLayer: m_BasicParameters.m_ForgetGateBias should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr, - "LstmLayer: m_BasicParameters.m_CellBias should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr, - "LstmLayer: m_BasicParameters.m_OutputGateBias should not be null."); + if (!m_BasicParameters.m_InputToForgetWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_BasicParameters.m_InputToForgetWeights should not be null."); + } + + if (!m_BasicParameters.m_InputToCellWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_BasicParameters.m_InputToCellWeights should not be null."); + } + + if (!m_BasicParameters.m_InputToOutputWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_BasicParameters.m_InputToOutputWeights should not be null."); + } + + if (!m_BasicParameters.m_RecurrentToForgetWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_BasicParameters.m_RecurrentToForgetWeights should not be null."); + } + + if (!m_BasicParameters.m_RecurrentToCellWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_BasicParameters.m_RecurrentToCellWeights should not be null."); + } + + if (!m_BasicParameters.m_RecurrentToOutputWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_BasicParameters.m_RecurrentToOutputWeights should not be null."); + } + + if (!m_BasicParameters.m_ForgetGateBias) + { + throw armnn::NullPointerException("LstmLayer: " + "m_BasicParameters.m_ForgetGateBias should not be null."); + } + + if (!m_BasicParameters.m_CellBias) + { + throw armnn::NullPointerException("LstmLayer: " + "m_BasicParameters.m_CellBias should not be null."); + } + + if (!m_BasicParameters.m_OutputGateBias) + { + throw armnn::NullPointerException("LstmLayer: " + "m_BasicParameters.m_OutputGateBias should not be null."); + } if (!m_Param.m_CifgEnabled) { - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr, - "LstmLayer: m_CifgParameters.m_InputToInputWeights should not be null."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr, - "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not be null."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr, - "LstmLayer: m_CifgParameters.m_InputGateBias should not be null."); + if (!m_CifgParameters.m_InputToInputWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_CifgParameters.m_InputToInputWeights should not be null."); + } + + if (!m_CifgParameters.m_RecurrentToInputWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_CifgParameters.m_RecurrentToInputWeights should not be null."); + } + + if (!m_CifgParameters.m_InputGateBias) + { + throw armnn::NullPointerException("LstmLayer: " + "m_CifgParameters.m_InputGateBias should not be null."); + } ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "LstmLayer"); } else { - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr, - "LstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr, - "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value when CIFG is enabled."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr, - "LstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled."); + if (m_CifgParameters.m_InputToInputWeights) + { + throw armnn::Exception("LstmLayer: " + "m_CifgParameters.m_InputToInputWeights should not have a value " + "when CIFG is enabled."); + } + + if (m_CifgParameters.m_RecurrentToInputWeights) + { + throw armnn::Exception("LstmLayer: " + "m_CifgParameters.m_RecurrentToInputWeights should not have a value " + "when CIFG is enabled."); + } + + if (m_CifgParameters.m_InputGateBias) + { + throw armnn::Exception("LstmLayer: " + "m_CifgParameters.m_InputGateBias should not have a value " + "when CIFG is enabled."); + } ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "LstmLayer"); } if (m_Param.m_ProjectionEnabled) { - ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr, - "LstmLayer: m_ProjectionParameters.m_ProjectionWeights should not be null."); + if (!m_ProjectionParameters.m_ProjectionWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_ProjectionParameters.m_ProjectionWeights should not be null."); + } } if (m_Param.m_PeepholeEnabled) { if (!m_Param.m_CifgEnabled) { - ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr, - "LstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null " - "when Peephole is enabled and CIFG is disabled."); + if (!m_PeepholeParameters.m_CellToInputWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_PeepholeParameters.m_CellToInputWeights should not be null " + "when Peephole is enabled and CIFG is disabled."); + } + } + + if (!m_PeepholeParameters.m_CellToForgetWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_PeepholeParameters.m_CellToForgetWeights should not be null."); + } + + if (!m_PeepholeParameters.m_CellToOutputWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_PeepholeParameters.m_CellToOutputWeights should not be null."); } - ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr, - "LstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null."); - ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr, - "LstmLayer: m_PeepholeParameters.m_CellToOutputWeights should not be null."); } ValidateAndCopyShape( @@ -255,15 +338,30 @@ void LstmLayer::ValidateTensorShapesFromInputs() { if(!m_Param.m_CifgEnabled) { - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr, - "LstmLayer: m_LayerNormParameters.m_inputLayerNormWeights should not be null."); + if (!m_LayerNormParameters.m_InputLayerNormWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_LayerNormParameters.m_inputLayerNormWeights should not be null."); + } + } + + if (!m_LayerNormParameters.m_ForgetLayerNormWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_LayerNormParameters.m_forgetLayerNormWeights should not be null."); + } + + if (!m_LayerNormParameters.m_CellLayerNormWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_LayerNormParameters.m_cellLayerNormWeights should not be null."); + } + + if (!m_LayerNormParameters.m_OutputLayerNormWeights) + { + throw armnn::NullPointerException("LstmLayer: " + "m_LayerNormParameters.m_outputLayerNormWeights should not be null."); } - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr, - "LstmLayer: m_LayerNormParameters.m_forgetLayerNormWeights should not be null."); - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr, - "LstmLayer: m_LayerNormParameters.m_cellLayerNormWeights should not be null."); - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr, - "LstmLayer: m_LayerNormParameters.m_outputLayerNormWeights should not be null."); } } |