diff options
author | James Conroy <james.conroy@arm.com> | 2020-04-29 20:01:10 +0100 |
---|---|---|
committer | James Conroy <james.conroy@arm.com> | 2020-05-02 16:44:33 +0000 |
commit | 4f1f899da140bb0490cf7e404daeaf1206f4db8b (patch) | |
tree | dc6d1215440e0efa677d47a4b944882d72e12cc9 /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 56e1a5f68213c9134826ad14c6e1fb4c0d41fb46 (diff) | |
download | armnn-4f1f899da140bb0490cf7e404daeaf1206f4db8b.tar.gz |
IVGCVSW-4449 Add QLstm ref implementation
* Adds ref implementation for new HAL 1.3
operator, QLstm.
* Adds Layer and CreateWorkload unit tests.
* Adds WorkloadData validate for QLstm.
Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I8a721f07ff06105e6495a1a0561b9503aa8146dc
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 286 |
1 files changed, 286 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index d1249a492f..5796fc7c77 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -2844,6 +2844,292 @@ void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); } +void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + const std::string descriptorName{"QLstmQueueDescriptor"}; + + // Validate number of inputs/outputs + ValidateNumInputs(workloadInfo, descriptorName, 3); + ValidateNumOutputs(workloadInfo, descriptorName, 3); + + // Input/output tensor info + auto inputInfo = workloadInfo.m_InputTensorInfos[0]; + auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1]; + auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2]; + + auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0]; + auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1]; + auto outputInfo = workloadInfo.m_OutputTensorInfos[2]; + + // Supported types for various tensors in QLSTM + std::vector<DataType> inputOutputSupportedTypes = + { + DataType::QAsymmS8 + }; + + std::vector<DataType> cellStateSupportedTypes = + { + DataType::QSymmS16 + }; + + std::vector<DataType> weightsSupportedTypes = + { + DataType::QSymmS8 + }; + + std::vector<DataType> layerNormPeepholeWeightsSupportedTypes = + { + DataType::QSymmS16 + }; + + std::vector<DataType> biasSupportedTypes = + { + DataType::Signed32 + }; + + // Validate types of input/output tensors + ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName); + ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName); + ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName); + + ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName); + 
ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName); + ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName); + + // Validate matching types of input/output tensors + ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn"); + ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName, + "outputStateIn", "outputStateOut"); + ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut"); + + // Infer number of batches, number of units, input size and output size from tensor dimensions + const uint32_t numBatches = inputInfo.GetShape()[0]; + const uint32_t inputSize = inputInfo.GetShape()[1]; + const uint32_t outputSize = outputStateInInfo.GetShape()[1]; + const uint32_t numUnits = cellStateInInfo.GetShape()[1]; + + // Validate number of dimensions and number of elements for input/output tensors + ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input"); + ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn"); + ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn"); + + ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut"); + ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut"); + ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output"); + + // Validate number of dimensions and number of elements for MANDATORY weight tensors + ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights"); + auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights"); + + 
ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights"); + auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights"); + + ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights"); + auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights"); + + ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights"); + auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize), + " RecurrentToForgetWeights"); + + ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights"); + auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights"); + + ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights"); + auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights"); + + // Validate data types for MANDATORY weights tensors (all should match each other) + ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName); + + ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName, + "inputToForgetWeights", "inputToCellWeights"); + ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName, + "inputToForgetWeights", "inputToOutputWeights"); + + ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName, + "inputToForgetWeights", 
"recurrentToForgeteights"); + ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName, + "inputToForgetWeights", "recurrentToCellWeights"); + ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName, + "inputToForgetWeights", "recurrentToOutputWeights"); + + // Validate number of dimensions and number of elements for MANDATORY bias tensors + ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias"); + auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo(); + ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias"); + + ValidatePointer(m_CellBias, descriptorName, "CellBias"); + auto cellBiasInfo = m_CellBias->GetTensorInfo(); + ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias"); + + ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias"); + auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo(); + ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias"); + + // Validate data types for MANDATORY bias tensors + ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName); + + ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName, + "forgetGateBias", "cellBias"); + ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName, + "forgetGateBias", "outputGateBias"); + + // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias) + const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias && + !m_Parameters.m_CifgEnabled) || + (!m_InputToInputWeights && !m_RecurrentToInputWeights && + !m_InputGateBias && m_Parameters.m_CifgEnabled)); + + if (!allCifgParamsPresentOrNot) + { + throw InvalidArgumentException(descriptorName + + ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present " + "(CIFG disabled) or not be 
present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be " + "set appropriately."); + } + + if (!m_Parameters.m_CifgEnabled) + { + // Validate number of dimensions and number of elements + auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights"); + + auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize), + " RecurrentToInputWeights"); + + auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo(); + ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias"); + + // Validate data types + ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName, + "inputToForgetWeights", "inputToInputWeights"); + ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName, + "inputToForgetWeights", "recurrentToInputWeights"); + ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName, + "forgetGateBias", "inputGateBias"); + } + + // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights) + bool allPeepholeWeightsPresentOrNot = + (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights + && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled) + || (!m_CellToInputWeights && !m_CellToForgetWeights + && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled)); + + if (!allPeepholeWeightsPresentOrNot) + { + throw InvalidArgumentException(descriptorName + + ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole " + "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present " + "when Peephole is enabled and CIFG is disabled. 
m_Parameters.m_PeepholeEnabled should be set " + "appropriately."); + } + + if (m_Parameters.m_PeepholeEnabled) + { + auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights"); + ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName); + + auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights"); + ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName, + "cellToForgetWeight", "cellToOutputWeights"); + + if (!m_Parameters.m_CifgEnabled) + { + auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights"); + ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName, + "cellToForgetWeights", "cellToInputWeights"); + } + } + + // Validate OPTIONAL params: Layer Norm Weights + bool allLayerNormWeightsPresentOrNot = + (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights + && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled) + || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights + && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled)); + + if (!allLayerNormWeightsPresentOrNot) + { + throw InvalidArgumentException(descriptorName + + ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights " + "and CellLayerNormWeights should all be present (Layer Norm enabled) or not " + "be present at all (Layer Norm disabled). InputLayerNormWeights should " + "only be present when Layer Norm is enabled and CIFG is disabled. 
" + "m_Parameters.m_LayerNormEnabled should be set appropriately."); + } + + if (m_Parameters.m_LayerNormEnabled) + { + auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights"); + ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName); + + auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights"); + ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName, + "forgetLayerNormWeights", "cellLayerNormWeights"); + + auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights"); + ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName, + "forgetLayerNormWeights", "outputLayerNormWeights"); + + if (!m_Parameters.m_CifgEnabled) + { + auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights"); + ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName, + "forgetLayerNormWeights", "inputLayerNormWeights"); + } + } + + // Validate OPTIONAL params: Projection (projectionWeights, projectionBias) + bool correctProjectionTensorsPresent = + ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) || + (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) || + (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled)); + + if (!correctProjectionTensorsPresent) + { + throw InvalidArgumentException(descriptorName + + ": If projection is enabled, ProjectionWeights should be present and " + 
"ProjectionBias is optional. If projection is disabled, neither " + "ProjectionWeights nor ProjectionBias should be present."); + } + + if (m_Parameters.m_ProjectionEnabled) + { + auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo(); + ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights"); + ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName); + + if (m_ProjectionBias) + { + auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo(); + ValidateTensorNumDimNumElem(projectionBiasInfo, 1, numUnits, "ProjectionBias"); + ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName); + } + + } + else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) && + outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) { + throw InvalidArgumentException(descriptorName + + ": If projection is disabled, output quantization info (scale, offset) " + "should match HiddenStateScale and HiddenStateZeroPoint."); + } + +} + void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { const std::string descriptorName{"QuantizedLstmQueueDescriptor"}; |