diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 271 |
1 files changed, 265 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index e7915dd40b..3766f5f7ca 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -114,6 +114,30 @@ void ValidateTensorNumDimensions(const TensorInfo& tensor, } //--------------------------------------------------------------- +void ValidateTensorNumElements(const TensorInfo& tensor, + std::string const& descName, + unsigned int numElements, + std::string const& tensorName) +{ + if (tensor.GetNumElements() != numElements) + { + throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " + + to_string(tensor.GetNumDimensions()) + " elements for " + + tensorName + " tensor."); + } +} + +//--------------------------------------------------------------- +void ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo, + unsigned int numDimension, + unsigned int numElements, + std::string const& tensorName) +{ + ValidateTensorNumDimensions(tensorInfo, "ValidateTensorNumDimNumElem: NumDimensionCheck", numDimension, tensorName); + ValidateTensorNumElements(tensorInfo, "ValidateTensorNumDimNumElem: NumElementsCheck", numElements, tensorName); +} + +//--------------------------------------------------------------- void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType, const std::string& descName, std::string const& tensorName) { @@ -1238,22 +1262,257 @@ void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { - ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "LstmQueueDescriptor", 2, "input"); - ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "LstmQueueDescriptor", 2, "output"); - std::vector<DataType> supportedTypes = { DataType::Float16, DataType::Float32, DataType::QuantisedSymm16 }; + // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions() + // check for supported type of one input and match them with all the other input and output ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, "LstmQueueDescriptor"); + // type matches all other inputs + for (uint32_t i = 1; i < workloadInfo.m_InputTensorInfos.size(); ++i) + { + ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_InputTensorInfos[i], + "LstmQueueDescriptor", + "InputTensor[0]", + "InputTensor[" + std::to_string(i) + "]"); + } + // type matches all other outputs + for (uint32_t i = 0; i < workloadInfo.m_OutputTensorInfos.size(); ++i) + { + ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_OutputTensorInfos[i], + "LstmQueueDescriptor", + "InputTensor[0]", + "OutputTensor[" + std::to_string(i) + "]"); + } - ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], - supportedTypes, - "LstmQueueDescriptor"); + // TODO: check clipping parameter is valid + + // Inferring batch size, number of outputs and number of cells from the inputs. + // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo + const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1]; + const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0]; + ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights"); + const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0]; + ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights"); + const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1]; + + // check dimensions of all inputs and outputs + if (workloadInfo.m_InputTensorInfos.size() != 3) + { + throw InvalidArgumentException("Invalid number of inputs."); + } + if (workloadInfo.m_OutputTensorInfos.size() != 4) + { + throw InvalidArgumentException("Invalid number of outputs."); + } + // input tensor + ValidateTensorNumDimNumElem( workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input), + "LstmQueueDescriptor input[0]"); + // outputStateInTensor + ValidateTensorNumDimNumElem( workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output), + "LstmQueueDescriptor input[1]"); + // outputStateInTensor + ValidateTensorNumDimNumElem( workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell), + "LstmQueueDescriptor input[2]"); + // scratchBufferTensor + unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4; + ValidateTensorNumDimNumElem( workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize), + "LstmQueueDescriptor output[0]"); + // outputStateOutTensor + ValidateTensorNumDimNumElem( workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output), + "LstmQueueDescriptor output[1]"); + // cellStateOutTensor + ValidateTensorNumDimNumElem( workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell), + "LstmQueueDescriptor output[2]"); + // outputTensor + ValidateTensorNumDimNumElem( workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output), + "LstmQueueDescriptor output[3]"); + + + // check that dimensions of inputs/outputs and QueueDescriptor data match with each other + if ( m_InputToInputWeights ) + { + ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2, + (n_cell * n_input), "InputLayerNormWeights"); + } + + ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights"); + ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2, + (n_cell * n_input), "InputToForgetWeights"); + + ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights"); + ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2, + (n_cell * n_input), "InputToCellWeights"); + + if ( m_RecurrentToInputWeights ) + { + ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2, + (n_cell * n_output), "RecurrentToInputWeights"); + } + + ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights"); + ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2, + (n_cell * n_output), "RecurrentToForgetWeights"); + + ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights"); + ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2, + (n_cell * n_output), "RecurrentToCellWeights"); + + // Make sure the input-gate's parameters are either both present (regular + // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly. + bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights && + !m_Parameters.m_CifgEnabled) || + (!m_InputToInputWeights && !m_RecurrentToInputWeights && + m_Parameters.m_CifgEnabled)); + if (!cifg_weights_all_or_none) + { + throw InvalidArgumentException("Input-Gate's parameters InputToInputWeights and RecurrentToInputWeights must " + "either both be present (regular LSTM) or both not present (CIFG-LSTM). In " + "addition CifgEnable must be set accordingly"); + } + + if ( m_CellToInputWeights ) + { + ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1, + n_cell, "CellToInputWeights"); + } + if ( m_CellToForgetWeights ) + { + ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1, + n_cell, "CellToForgetWeights"); + } + if ( m_CellToOutputWeights ) + { + ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1, + n_cell, "CellToOutputWeights"); + } + + // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly. + bool peephole_weights_all_or_none = + (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights + && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled) + || ( !m_CellToInputWeights && !m_CellToForgetWeights + && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled)); + if (!peephole_weights_all_or_none) + { + throw InvalidArgumentException("Invalid combination of peephole parameters"); + } + + // Make sure the input gate bias is present only when not a CIFG-LSTM. + if (m_Parameters.m_CifgEnabled) + { + if (m_InputGateBias) + { + throw InvalidArgumentException("InputGateBias is present and CIFG-LSTM is enabled"); + } + } + else + { + if (!m_InputGateBias) + { + throw InvalidArgumentException("If CIFG-LSTM is disabled InputGateBias must be present."); + } + ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1, + n_cell, "InputGateBias"); + } + + ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias"); + ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias"); + + ValidatePointer(m_CellBias, "Null pointer check", "CellBias"); + ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias"); + + ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias"); + ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias"); + + if (m_ProjectionWeights) + { + ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2, + (n_cell * n_output), "ProjectionWeights"); + } + if (m_ProjectionBias) + { + ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias"); + } + + // Making sure the projection tensors are consistent: + // 1) If projection weight is not present, then projection bias should not be + // present. + // 2) If projection weight is present, then projection bias is optional. + bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias && + !m_Parameters.m_ProjectionEnabled) + || (m_ProjectionWeights && !m_ProjectionBias && + m_Parameters.m_ProjectionEnabled) + || (m_ProjectionWeights && m_ProjectionBias && + m_Parameters.m_ProjectionEnabled)); + if (!projecton_tensors_consistent) + { + throw InvalidArgumentException("Projection tensors are inconsistent."); + } + + // The four layer normalization weights either all have values or none of them have values. Additionally, if + // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights + // either all have values or none of them have values. Layer normalization is used when the values of all the + // layer normalization weights are present + if (m_InputLayerNormWeights) + { + ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights"); + } + if (m_ForgetLayerNormWeights) + { + ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights"); + } + if (m_CellLayerNormWeights) + { + ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights"); + } + if (m_OutputLayerNormWeights) + { + ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights"); + } + + + if (m_Parameters.m_LayerNormEnabled) + { + if (!m_Parameters.m_CifgEnabled) + { + if (!m_InputLayerNormWeights) + { + throw InvalidArgumentException("Layer normalisation is enabled and CIFG-LSTM is disabled but " + "InputLayerNormWeights are not present"); + } + ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), + 1, n_cell, "InputLayerNormWeights"); + } + else if (m_InputLayerNormWeights) + { + throw InvalidArgumentException("InputLayerNormWeights are present while CIFG is enabled"); + } + + ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled", + "ForgetLayerNormWeights"); + ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights"); + + ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled", + "OutputLayerNormWeights"); + ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights"); + + ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled", + "CellLayerNormWeights"); + ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights"); + } + else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights) + { + throw InvalidArgumentException("Layer normalisation is disabled but one or more layer normalisation weights " + "are present."); + } } void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const |