diff options
Diffstat (limited to 'src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp')
-rw-r--r-- | src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp | 214 |
1 files changed, 150 insertions, 64 deletions
diff --git a/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp b/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp index 75f027e32d..68a0d8e2c2 100644 --- a/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp +++ b/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2021-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "UnidirectionalSequenceLstmLayer.hpp" @@ -150,7 +150,9 @@ UnidirectionalSequenceLstmLayer* UnidirectionalSequenceLstmLayer::Clone(Graph& g std::vector<TensorShape> UnidirectionalSequenceLstmLayer::InferOutputShapes( const std::vector<TensorShape>& inputShapes) const { - ARMNN_ASSERT(inputShapes.size() == 3); + ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(inputShapes.size() == 3, + "inputShapes' size is \"" + std::to_string(inputShapes.size()) + + "\" - should be \"3\"."); // Get input values for validation unsigned int outputSize = inputShapes[1][1]; @@ -181,94 +183,178 @@ void UnidirectionalSequenceLstmLayer::ValidateTensorShapesFromInputs() GetInputSlot(2).GetTensorInfo().GetShape() }); - ARMNN_ASSERT(inferredShapes.size() == 1); + if (inferredShapes.size() != 1) + { + throw armnn::LayerValidationException("inferredShapes has " + + std::to_string(inferredShapes.size()) + + " elements - should only have 1."); + } // Check if the weights are nullptr - ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToCellWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToForgetWeights " - "should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_RecurrentToOutputWeights " - "should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr, - "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_ForgetGateBias should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr, - "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_CellBias should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr, - "UnidirectionalSequenceLstmLayer: m_BasicParameters.m_OutputGateBias should not be null."); + if (!m_BasicParameters.m_InputToForgetWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_BasicParameters.m_InputToForgetWeights should not be null."); + } + + if (!m_BasicParameters.m_InputToCellWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_BasicParameters.m_InputToCellWeights should not be null."); + } + + if (!m_BasicParameters.m_InputToOutputWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_BasicParameters.m_InputToOutputWeights should not be null."); + } + + if (!m_BasicParameters.m_RecurrentToForgetWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_BasicParameters.m_RecurrentToForgetWeights should not be null."); + } + + if (!m_BasicParameters.m_RecurrentToCellWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_BasicParameters.m_RecurrentToCellWeights should not be null."); + } + + if (!m_BasicParameters.m_RecurrentToOutputWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_BasicParameters.m_RecurrentToOutputWeights should not be null."); + } + + if (!m_BasicParameters.m_ForgetGateBias) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_BasicParameters.m_ForgetGateBias should not be null."); + } + + if (!m_BasicParameters.m_CellBias) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_BasicParameters.m_CellBias should not be null."); + } + + if (!m_BasicParameters.m_OutputGateBias) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_BasicParameters.m_OutputGateBias should not be null."); + } if (!m_Param.m_CifgEnabled) { - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputToInputWeights should not be null."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_RecurrentToInputWeights " - "should not be null."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr, - "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputGateBias should not be null."); + if (!m_CifgParameters.m_InputToInputWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_CifgParameters.m_InputToInputWeights should not be null."); + } + + if (!m_CifgParameters.m_RecurrentToInputWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_CifgParameters.m_RecurrentToInputWeights should not be null."); + } + + if (!m_CifgParameters.m_InputGateBias) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_CifgParameters.m_InputGateBias should not be null."); + } } else { - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr, - "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value " - "when CIFG is enabled."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr, - "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value " - "when CIFG is enabled."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr, - "UnidirectionalSequenceLstmLayer: m_CifgParameters.m_InputGateBias should not have a value " - "when CIFG is enabled."); + if (m_CifgParameters.m_InputToInputWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_CifgParameters.m_InputToInputWeights should not have a value " + "when CIFG is enabled."); + } + + if (m_CifgParameters.m_RecurrentToInputWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_CifgParameters.m_RecurrentToInputWeights should not have a value " + "when CIFG is enabled."); + } + + if (m_CifgParameters.m_InputGateBias) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_CifgParameters.m_InputGateBias should not have a value " + "when CIFG is enabled."); + } } if (m_Param.m_ProjectionEnabled) { - ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_ProjectionParameters.m_ProjectionWeights " - "should not be null."); + if (!m_ProjectionParameters.m_ProjectionWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_ProjectionParameters.m_ProjectionWeights should not be null."); + } } if (m_Param.m_PeepholeEnabled) { if (!m_Param.m_CifgEnabled) { - ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToInputWeights " - "should not be null " - "when Peephole is enabled and CIFG is disabled."); + if (!m_PeepholeParameters.m_CellToInputWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_PeepholeParameters.m_CellToInputWeights should not be null " + "when Peephole is enabled and CIFG is disabled."); + } + } + + if (!m_PeepholeParameters.m_CellToForgetWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_PeepholeParameters.m_CellToForgetWeights should not be null."); + } + + if (!m_PeepholeParameters.m_CellToOutputWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_PeepholeParameters.m_CellToOutputWeights should not be null."); } - ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToForgetWeights " - "should not be null."); - ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_PeepholeParameters.m_CellToOutputWeights " - "should not be null."); } if (m_Param.m_LayerNormEnabled) { if(!m_Param.m_CifgEnabled) { - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_inputLayerNormWeights " - "should not be null."); + if (!m_LayerNormParameters.m_InputLayerNormWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_LayerNormParameters.m_inputLayerNormWeights " + "should not be null."); + } + } + + if (!m_LayerNormParameters.m_ForgetLayerNormWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_LayerNormParameters.m_forgetLayerNormWeights " + "should not be null."); + } + + if (!m_LayerNormParameters.m_CellLayerNormWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_LayerNormParameters.m_cellLayerNormWeights " + "should not be null."); + } + + if (!m_LayerNormParameters.m_OutputLayerNormWeights) + { + throw armnn::LayerValidationException("UnidirectionalSequenceLstmLayer: " + "m_LayerNormParameters.m_outputLayerNormWeights " + "should not be null."); } - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_forgetLayerNormWeights " - "should not be null."); - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_cellLayerNormWeights " - "should not be null."); - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr, - "UnidirectionalSequenceLstmLayer: m_LayerNormParameters.m_outputLayerNormWeights " - "should not be null."); } ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "UnidirectionalSequenceLstmLayer"); |