diff options
Diffstat (limited to 'src/armnn/layers/QLstmLayer.cpp')
-rw-r--r-- | src/armnn/layers/QLstmLayer.cpp | 200 |
1 files changed, 148 insertions, 52 deletions
diff --git a/src/armnn/layers/QLstmLayer.cpp b/src/armnn/layers/QLstmLayer.cpp index eeb01db51d..e98deb6a88 100644 --- a/src/armnn/layers/QLstmLayer.cpp +++ b/src/armnn/layers/QLstmLayer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "QLstmLayer.hpp" @@ -152,7 +152,11 @@ QLstmLayer* QLstmLayer::Clone(Graph& graph) const std::vector<TensorShape> QLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const { - ARMNN_ASSERT(inputShapes.size() == 3); + if (inputShapes.size() != 3) + { + throw armnn::Exception("inputShapes' size is \"" + std::to_string(inputShapes.size()) + + "\" - should be \"3\"."); + } // Get input values for validation unsigned int batchSize = inputShapes[0][0]; @@ -182,70 +186,147 @@ void QLstmLayer::ValidateTensorShapesFromInputs() GetInputSlot(2).GetTensorInfo().GetShape() // previousCellStateIn }); - ARMNN_ASSERT(inferredShapes.size() == 3); + if (inferredShapes.size() != 3) + { + throw armnn::LayerValidationException("inferredShapes has " + + std::to_string(inferredShapes.size()) + + " element(s) - should only have 3."); + } // Check if the weights are nullptr for basic params - ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr, - "QLstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr, - "QLstmLayer: m_BasicParameters.m_InputToCellWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr, - "QLstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr, - "QLstmLayer: m_BasicParameters.m_RecurrentToForgetWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr, - "QLstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr, - "QLstmLayer: m_BasicParameters.m_RecurrentToOutputWeights should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr, - "QLstmLayer: m_BasicParameters.m_ForgetGateBias should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr, - "QLstmLayer: m_BasicParameters.m_CellBias should not be null."); - ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr, - "QLstmLayer: m_BasicParameters.m_OutputGateBias should not be null."); + if (!m_BasicParameters.m_InputToForgetWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_BasicParameters.m_InputToForgetWeights should not be null."); + } + + if (!m_BasicParameters.m_InputToCellWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_BasicParameters.m_InputToCellWeights should not be null."); + } + + if (!m_BasicParameters.m_InputToOutputWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_BasicParameters.m_InputToOutputWeights should not be null."); + } + + if (!m_BasicParameters.m_RecurrentToForgetWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_BasicParameters.m_RecurrentToForgetWeights should not be null."); + } + + if (!m_BasicParameters.m_RecurrentToCellWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_BasicParameters.m_RecurrentToCellWeights should not be null."); + } + + if (!m_BasicParameters.m_RecurrentToOutputWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_BasicParameters.m_RecurrentToOutputWeights should not be null."); + } + + if (!m_BasicParameters.m_ForgetGateBias) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_BasicParameters.m_ForgetGateBias should not be null."); + } + + if (!m_BasicParameters.m_CellBias) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_BasicParameters.m_CellBias should not be null."); + } + + if (!m_BasicParameters.m_OutputGateBias) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_BasicParameters.m_OutputGateBias should not be null."); + } if (!m_Param.m_CifgEnabled) { - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr, - "QLstmLayer: m_CifgParameters.m_InputToInputWeights should not be null."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr, - "QLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not be null."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr, - "QLstmLayer: m_CifgParameters.m_InputGateBias should not be null."); + if (!m_CifgParameters.m_InputToInputWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_CifgParameters.m_InputToInputWeights should not be null."); + } + + if (!m_CifgParameters.m_RecurrentToInputWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_CifgParameters.m_RecurrentToInputWeights should not be null."); + } + + if (!m_CifgParameters.m_InputGateBias) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_CifgParameters.m_InputGateBias should not be null."); + } ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QLstmLayer"); } else { - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr, - "QLstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr, - "QLstmLayer: m_CifgParameters.m_RecurrentToInputWeights should " - "not have a value when CIFG is enabled."); - ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr, - "QLstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled."); + if (m_CifgParameters.m_InputToInputWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_CifgParameters.m_InputToInputWeights " + "should not have a value when CIFG is enabled."); + } + + if (m_CifgParameters.m_RecurrentToInputWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_CifgParameters.m_RecurrentToInputWeights " + "should not have a value when CIFG is enabled."); + } + + if (m_CifgParameters.m_InputGateBias) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_CifgParameters.m_InputGateBias " + "should not have a value when CIFG is enabled."); + } ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QLstmLayer"); } if (m_Param.m_ProjectionEnabled) { - ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr, - "QLstmLayer: m_ProjectionParameters.m_ProjectionWeights should not be null."); + if (!m_ProjectionParameters.m_ProjectionWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_ProjectionParameters.m_ProjectionWeights should not be null."); + } } if (m_Param.m_PeepholeEnabled) { if (!m_Param.m_CifgEnabled) { - ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr, - "QLstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null " - "when Peephole is enabled and CIFG is disabled."); + if (!m_PeepholeParameters.m_CellToInputWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_PeepholeParameters.m_CellToInputWeights should not be null " + "when Peephole is enabled and CIFG is disabled."); + } } - ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr, - "QLstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null."); - ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr, - "QLstmLayer: m_PeepholeParameters.m_CellToOutputWeights should not be null."); + if (!m_PeepholeParameters.m_CellToForgetWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_PeepholeParameters.m_CellToForgetWeights should not be null."); + } + + if (!m_PeepholeParameters.m_CellToOutputWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_PeepholeParameters.m_CellToOutputWeights should not be null."); + } } ValidateAndCopyShape( @@ -255,17 +336,32 @@ void QLstmLayer::ValidateTensorShapesFromInputs() if (m_Param.m_LayerNormEnabled) { - if(!m_Param.m_CifgEnabled) + if (!m_Param.m_CifgEnabled) + { + if (!m_LayerNormParameters.m_InputLayerNormWeights) + { + throw armnn::LayerValidationException("QLstmLayer: m_LayerNormParameters.m_InputLayerNormWeights " + "should not be null."); + } + } + + if (!m_LayerNormParameters.m_ForgetLayerNormWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_LayerNormParameters.m_ForgetLayerNormWeights should not be null."); + } + + if (!m_LayerNormParameters.m_CellLayerNormWeights) + { + throw armnn::LayerValidationException("QLstmLayer: " + "m_LayerNormParameters.m_CellLayerNormWeights should not be null."); + } + + if (!m_LayerNormParameters.m_OutputLayerNormWeights) { - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr, - "QLstmLayer: m_LayerNormParameters.m_InputLayerNormWeights should not be null."); + throw armnn::LayerValidationException("QLstmLayer: " + "m_LayerNormParameters.m_UutputLayerNormWeights should not be null."); } - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr, - "QLstmLayer: m_LayerNormParameters.m_ForgetLayerNormWeights should not be null."); - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr, - "QLstmLayer: m_LayerNormParameters.m_CellLayerNormWeights should not be null."); - ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr, - "QLstmLayer: m_LayerNormParameters.m_UutputLayerNormWeights should not be null."); } } |