author     James Conroy <james.conroy@arm.com>    2020-04-29 20:01:10 +0100
committer  James Conroy <james.conroy@arm.com>    2020-05-02 16:44:33 +0000
commit     4f1f899da140bb0490cf7e404daeaf1206f4db8b (patch)
tree       dc6d1215440e0efa677d47a4b944882d72e12cc9 /src/backends/backendsCommon/WorkloadData.cpp
parent     56e1a5f68213c9134826ad14c6e1fb4c0d41fb46 (diff)
download   armnn-4f1f899da140bb0490cf7e404daeaf1206f4db8b.tar.gz
IVGCVSW-4449 Add QLstm ref implementation
* Adds ref implementation for new HAL 1.3 operator, QLstm.
* Adds Layer and CreateWorkload unit tests.
* Adds WorkloadData validation for QLstm.

Signed-off-by: James Conroy <james.conroy@arm.com>
Change-Id: I8a721f07ff06105e6495a1a0561b9503aa8146dc
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp  286
1 file changed, 286 insertions, 0 deletions
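The new QLstmQueueDescriptor::Validate below applies the same all-or-none rule to each optional tensor group (CIFG, peephole, layer norm, projection). A minimal standalone sketch of that rule, for illustration only (CheckOptionalGroup and the flags in main are hypothetical names, not part of this patch):

#include <stdexcept>
#include <string>

// Throws unless an optional tensor group is either fully supplied with its feature
// flag enabled, or fully absent with the flag disabled. Partially supplied groups
// and mismatched flags are rejected, mirroring the inline checks in Validate().
void CheckOptionalGroup(bool allPresent, bool nonePresent, bool featureEnabled, const std::string& name)
{
    const bool consistent = (allPresent && featureEnabled) || (nonePresent && !featureEnabled);
    if (!consistent)
    {
        throw std::invalid_argument(name + ": optional tensors and descriptor flag are inconsistent");
    }
}

int main()
{
    // CIFG is the inverted case: its tensors are expected only when CIFG is *disabled*.
    bool cifgEnabled = true;
    bool allCifgTensorsPresent = false;
    bool noCifgTensorsPresent = true;
    CheckOptionalGroup(allCifgTensorsPresent, noCifgTensorsPresent, !cifgEnabled, "CIFG group");
    return 0;
}

In the patch itself this rule is written out inline as boolean expressions such as allCifgParamsPresentOrNot, with inconsistent groups rejected via InvalidArgumentException.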
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index d1249a492f..5796fc7c77 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -2844,6 +2844,292 @@ void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
}
+void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string descriptorName{"QLstmQueueDescriptor"};
+
+ // Validate number of inputs/outputs
+ ValidateNumInputs(workloadInfo, descriptorName, 3);
+ ValidateNumOutputs(workloadInfo, descriptorName, 3);
+
+ // Input/output tensor info
+ auto inputInfo = workloadInfo.m_InputTensorInfos[0];
+ auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
+ auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
+
+ auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
+ auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
+ auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
+
+ // Supported types for various tensors in QLSTM
+ std::vector<DataType> inputOutputSupportedTypes =
+ {
+ DataType::QAsymmS8
+ };
+
+ std::vector<DataType> cellStateSupportedTypes =
+ {
+ DataType::QSymmS16
+ };
+
+ std::vector<DataType> weightsSupportedTypes =
+ {
+ DataType::QSymmS8
+ };
+
+ std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
+ {
+ DataType::QSymmS16
+ };
+
+ std::vector<DataType> biasSupportedTypes =
+ {
+ DataType::Signed32
+ };
+
+ // Validate types of input/output tensors
+ ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
+ ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
+ ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
+
+ ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
+ ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
+ ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
+
+ // Validate matching types of input/output tensors
+ ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
+ ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
+ "outputStateIn", "outputStateOut");
+ ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
+
+ // Infer number of batches, number of units, input size and output size from tensor dimensions
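+ // Shapes are: input [numBatches, inputSize], outputStateIn [numBatches, outputSize],
+ // cellStateIn [numBatches, numUnits]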
+ const uint32_t numBatches = inputInfo.GetShape()[0];
+ const uint32_t inputSize = inputInfo.GetShape()[1];
+ const uint32_t outputSize = outputStateInInfo.GetShape()[1];
+ const uint32_t numUnits = cellStateInInfo.GetShape()[1];
+
+ // Validate number of dimensions and number of elements for input/output tensors
+ ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
+ ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
+ ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
+
+ ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
+ ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
+ ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
+
+ // Validate number of dimensions and number of elements for MANDATORY weight tensors
+ ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
+ auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
+
+ ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
+ auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
+
+ ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
+ auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
+
+ ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
+ auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
+ " RecurrentToForgetWeights");
+
+ ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
+ auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
+
+ ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
+ auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToOutputWeights");
+
+ // Validate data types for MANDATORY weights tensors (all should match each other)
+ ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
+
+ ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
+ "inputToForgetWeights", "inputToCellWeights");
+ ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
+ "inputToForgetWeights", "inputToOutputWeights");
+
+ ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
+ "inputToForgetWeights", "recurrentToForgeteights");
+ ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
+ "inputToForgetWeights", "recurrentToCellWeights");
+ ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
+ "inputToForgetWeights", "recurrentToOutputWeights");
+
+ // Validate number of dimensions and number of elements for MANDATORY bias tensors
+ ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
+ auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
+ ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
+
+ ValidatePointer(m_CellBias, descriptorName, "CellBias");
+ auto cellBiasInfo = m_CellBias->GetTensorInfo();
+ ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
+
+ ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
+ auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
+ ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
+
+ // Validate data types for MANDATORY bias tensors
+ ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
+
+ ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
+ "forgetGateBias", "cellBias");
+ ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
+ "forgetGateBias", "outputGateBias");
+
+ // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
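+ // With CIFG (Coupled Input and Forget Gate) the input gate is derived from the forget gate,
+ // so the input gate tensors must be omitted; with CIFG disabled all three must be supplied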
+ const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
+ !m_Parameters.m_CifgEnabled) ||
+ (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
+ !m_InputGateBias && m_Parameters.m_CifgEnabled));
+
+ if (!allCifgParamsPresentOrNot)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
+ "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
+ "set appropriately.");
+ }
+
+ if (!m_Parameters.m_CifgEnabled)
+ {
+ // Validate number of dimensions and number of elements
+ auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
+
+ auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
+ " RecurrentToInputWeights");
+
+ auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
+ ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
+
+ // Validate data types
+ ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
+ "inputToForgetWeights", "inputToInputWeights");
+ ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
+ "inputToForgetWeights", "recurrentToInputWeights");
+ ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
+ "forgetGateBias", "inputGateBias");
+ }
+
+ // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
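+ // Peephole connections feed the cell state into the gates; cellToInputWeights is only
+ // required when CIFG is disabled, since there is no separate input gate otherwise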
+ bool allPeepholeWeightsPresentOrNot =
+ (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
+ && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
+ || (!m_CellToInputWeights && !m_CellToForgetWeights
+ && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
+
+ if (!allPeepholeWeightsPresentOrNot)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
+ "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
+ "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
+ "appropriately.");
+ }
+
+ if (m_Parameters.m_PeepholeEnabled)
+ {
+ auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
+ ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
+
+ auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
+ ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
+ "cellToForgetWeight", "cellToOutputWeights");
+
+ if (!m_Parameters.m_CifgEnabled)
+ {
+ auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
+ ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
+ "cellToForgetWeights", "cellToInputWeights");
+ }
+ }
+
+ // Validate OPTIONAL params: Layer Norm Weights
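+ // Layer normalisation weights are applied per gate; inputLayerNormWeights is only
+ // required when CIFG is disabled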
+ bool allLayerNormWeightsPresentOrNot =
+ (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
+ && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
+ || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
+ && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
+
+ if (!allLayerNormWeightsPresentOrNot)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
+ "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
+ "be present at all (Layer Norm disabled). InputLayerNormWeights should "
+ "only be present when Layer Norm is enabled and CIFG is disabled. "
+ "m_Parameters.m_LayerNormEnabled should be set appropriately.");
+ }
+
+ if (m_Parameters.m_LayerNormEnabled)
+ {
+ auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
+ ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
+
+ auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
+ ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
+ "forgetLayerNormWeights", "cellLayerNormWeights");
+
+ auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
+ ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
+ "forgetLayerNormWeights", "outputLayerNormWeights");
+
+ if (!m_Parameters.m_CifgEnabled)
+ {
+ auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
+ ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
+ "forgetLayerNormWeights", "inputLayerNormWeights");
+ }
+ }
+
+ // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
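+ // The projection layer maps the cell output (numUnits) down to outputSize;
+ // projectionBias is optional even when projection is enabled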
+ bool correctProjectionTensorsPresent =
+ ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
+ (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
+ (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
+
+ if (!correctProjectionTensorsPresent)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": If projection is enabled, ProjectionWeights should be present and "
+ "ProjectionBias is optional. If projection is disabled, neither "
+ "ProjectionWeights nor ProjectionBias should be present.");
+ }
+
+ if (m_Parameters.m_ProjectionEnabled)
+ {
+ auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
+ ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
+ ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
+
+ if (m_ProjectionBias)
+ {
+ auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
+ ValidateTensorNumDimNumElem(projectionBiasInfo, 1, numUnits, "ProjectionBias");
+ ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
+ }
+
+ }
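+ // With no projection layer the hidden state is written directly to the output tensor,
+ // so the output must carry the hidden state quantization parameters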
+ else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) ||
+ (outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint))
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": If projection is disabled, output quantization info (scale, offset) "
+ "should match HiddenStateScale and HiddenStateZeroPoint.");
+ }
+
+}
+
void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
const std::string descriptorName{"QuantizedLstmQueueDescriptor"};