diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-07-15 16:16:25 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-07-22 18:29:55 +0100 |
commit | 8ed39ae450a077c7e4d672b5f05ff1d68ee67aab (patch) | |
tree | 31a1cf006e50db54f3e7a605825c8e9e3f9d689e /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 15fcc7ed3163c9d4b1856955271854198c3c2696 (diff) | |
download | armnn-8ed39ae450a077c7e4d672b5f05ff1d68ee67aab.tar.gz |
MLCE-530 Add front end support for UnidirectionalSequenceLstm on ArmNN
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I57bcbdec3eb0155f41af0fe7d6abf9bac2ec86eb
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 276 |
1 files changed, 274 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 8c78136185..3fe0823b03 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1959,7 +1959,6 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid"); } - // Inferring batch size, number of outputs and number of cells from the inputs. const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1]; const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0]; @@ -1991,7 +1990,6 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output), descriptorName + " output_3"); - // check that dimensions of inputs/outputs and QueueDescriptor data match with each other if ( m_InputToInputWeights ) { @@ -3741,4 +3739,278 @@ void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); } +void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm + + const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"}; + + // check dimensions of all inputs and outputs + if (workloadInfo.m_InputTensorInfos.size() != 3) + { + throw InvalidArgumentException(descriptorName + ": Invalid number of inputs."); + } + if (workloadInfo.m_OutputTensorInfos.size() != 1) + { + throw InvalidArgumentException(descriptorName + ": Invalid number of outputs."); + } + + std::vector<DataType> supportedTypes = + { + DataType::Float16, + DataType::Float32, + DataType::QAsymmS8 + }; + + // check for supported type of one input and match them with all the other input and output + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName); + + // type matches all other inputs + for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i) + { + ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_InputTensorInfos[i], + descriptorName, + "input_0", + "input_" + std::to_string(i)); + } + // type matches all other outputs + for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i) + { + ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_OutputTensorInfos[i], + "LstmQueueDescriptor", + "input_0", + "output_" + std::to_string(i)); + } + + // Making sure clipping parameters have valid values. + // == 0 means no clipping + // > 0 means clipping + if (m_Parameters.m_ClippingThresCell < 0.0f) + { + throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid"); + } + if (m_Parameters.m_ClippingThresProj < 0.0f) + { + throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid"); + } + + unsigned int batchIndx = 0; + unsigned int inputIndx = 1; + uint32_t timeStep = 1; + unsigned int timeIndx = 1; + inputIndx = 2; + if (m_Parameters.m_TimeMajor) + { + batchIndx = 1; + timeIndx = 0; + + } + timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx]; + + // Inferring batch size, number of outputs and number of cells from the inputs. + const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx]; + const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx]; + ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights"); + const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0]; + ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights"); + const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1]; + + // input tensor + ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input), + descriptorName + " input_0"); + // outputStateInTensor + ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output), + descriptorName + " input_1"); + // outputStateInTensor + ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell), + descriptorName + " input_2"); + + // outputTensor + ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 3, (timeStep * n_batch * n_output), + descriptorName + " output_0"); + + // check that dimensions of inputs/outputs and QueueDescriptor data match with each other + if ( m_InputToInputWeights ) + { + ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2, + (n_cell * n_input), "InputLayerNormWeights"); + } + + ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights"); + ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2, + (n_cell * n_input), "InputToForgetWeights"); + + ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights"); + ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2, + (n_cell * n_input), "InputToCellWeights"); + + if ( m_RecurrentToInputWeights ) + { + ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2, + (n_cell * n_output), "RecurrentToInputWeights"); + } + + ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights"); + ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2, + (n_cell * n_output), "RecurrentToForgetWeights"); + + ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights"); + ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2, + (n_cell * n_output), "RecurrentToCellWeights"); + + // Make sure the input-gate's parameters are either both present (regular + // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly. + bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights && + !m_Parameters.m_CifgEnabled) || + (!m_InputToInputWeights && !m_RecurrentToInputWeights && + m_Parameters.m_CifgEnabled)); + if (!cifg_weights_all_or_none) + { + throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and " + "RecurrentToInputWeights must either both be present (regular LSTM) " + "or both not present (CIFG-LSTM). In addition CifgEnable must be set " + "accordingly."); + } + + if ( m_CellToInputWeights ) + { + ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1, + n_cell, "CellToInputWeights"); + } + if ( m_CellToForgetWeights ) + { + ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1, + n_cell, "CellToForgetWeights"); + } + if ( m_CellToOutputWeights ) + { + ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1, + n_cell, "CellToOutputWeights"); + } + + // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly. + bool peephole_weights_all_or_none = + (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights + && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled) + || ( !m_CellToInputWeights && !m_CellToForgetWeights + && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled)); + if (!peephole_weights_all_or_none) + { + throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters."); + } + + // Make sure the input gate bias is present only when not a CIFG-LSTM. + if (m_Parameters.m_CifgEnabled) + { + if (m_InputGateBias) + { + throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled."); + } + } + else + { + if (!m_InputGateBias) + { + throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias " + "must be present."); + } + ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1, + n_cell, "InputGateBias"); + } + + ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias"); + ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias"); + + ValidatePointer(m_CellBias, "Null pointer check", "CellBias"); + ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias"); + + ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias"); + ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias"); + + if (m_ProjectionWeights) + { + ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2, + (n_cell * n_output), "ProjectionWeights"); + } + if (m_ProjectionBias) + { + ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias"); + } + + // Making sure the projection tensors are consistent: + // 1) If projection weight is not present, then projection bias should not be + // present. + // 2) If projection weight is present, then projection bias is optional. + bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias && + !m_Parameters.m_ProjectionEnabled) + || (m_ProjectionWeights && !m_ProjectionBias && + m_Parameters.m_ProjectionEnabled) + || (m_ProjectionWeights && m_ProjectionBias && + m_Parameters.m_ProjectionEnabled)); + if (!projecton_tensors_consistent) + { + throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent."); + } + + // The four layer normalization weights either all have values or none of them have values. Additionally, if + // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights + // either all have values or none of them have values. Layer normalization is used when the values of all the + // layer normalization weights are present + if (m_InputLayerNormWeights) + { + ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights"); + } + if (m_ForgetLayerNormWeights) + { + ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights"); + } + if (m_CellLayerNormWeights) + { + ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights"); + } + if (m_OutputLayerNormWeights) + { + ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights"); + } + + if (m_Parameters.m_LayerNormEnabled) + { + if (!m_Parameters.m_CifgEnabled) + { + if (!m_InputLayerNormWeights) + { + throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is " + "disabled but InputLayerNormWeights are not present"); + } + ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), + 1, n_cell, "InputLayerNormWeights"); + } + else if (m_InputLayerNormWeights) + { + throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is " + "enabled"); + } + + ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled", + "ForgetLayerNormWeights"); + ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights"); + + ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled", + "OutputLayerNormWeights"); + ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights"); + + ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled", + "CellLayerNormWeights"); + ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights"); + } + else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights) + { + throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer " + "normalisation weights are present."); + } +} + + } // namespace armnn
\ No newline at end of file |