From 9c3cae8683e4b24932446b88d3ecbc02f9f9fa08 Mon Sep 17 00:00:00 2001
From: James Conroy
Date: Thu, 1 Aug 2019 16:01:48 +0100
Subject: IVGCVSW-3470 Add Quantized_LSTM tests

* Added Layer and Create Workload tests for the new Quantized LSTM layer.
* Tests to be enabled on NEON and CL in their respective patches.

Signed-off-by: James Conroy
Change-Id: I7e9e9768dd63010ab58367c45fffcff452377cfb
---
 src/backends/backendsCommon/WorkloadData.cpp | 171 +++++++++++++++++++++++++++
 1 file changed, 171 insertions(+)

diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 88cd6a69d6..a4d35827fa 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -2266,4 +2266,175 @@ void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloa
     }
 }
 
+void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+    const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
+
+    // Validate number of inputs/outputs
+    ValidateNumInputs(workloadInfo, descriptorName, 3);
+    ValidateNumOutputs(workloadInfo, descriptorName, 2);
+
+    // Input/output tensor infos
+    auto inputInfo         = workloadInfo.m_InputTensorInfos[0];
+    auto cellStateInInfo   = workloadInfo.m_InputTensorInfos[1];
+    auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
+
+    auto cellStateOutInfo   = workloadInfo.m_OutputTensorInfos[0];
+    auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
+
+    std::vector<DataType> inputOutputSupportedTypes =
+    {
+        DataType::QuantisedAsymm8
+    };
+
+    std::vector<DataType> cellStateSupportedTypes =
+    {
+        DataType::QuantisedSymm16
+    };
+
+    std::vector<DataType> weightsSupportedTypes =
+    {
+        DataType::QuantisedAsymm8
+    };
+
+    std::vector<DataType> biasSupportedTypes =
+    {
+        DataType::Signed32
+    };
+
+    // Validate types of input/output tensors
+    ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
+    ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
+    ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
+
+    ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
+    ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
+
+    // Validate matching types of input/output tensors
+    ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
+    ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
+                                 "outputStateIn", "outputStateOut");
+    ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
+
+    // Validate matching quantization info for input/output tensors
+    ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
+    ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
+    ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
+
+    // Infer number of batches, input size and output size from tensor dimensions
+    const uint32_t numBatches = inputInfo.GetShape()[0];
+    const uint32_t inputSize  = inputInfo.GetShape()[1];
+    const uint32_t outputSize = cellStateInInfo.GetShape()[1];
+
+    // Validate number of dimensions and number of elements for input/output tensors
+    ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
+    ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
+    ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
+    ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
+    ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
+
+    // Validate number of dimensions and number of elements for weights tensors
+    ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
+    auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
+
+    ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
+    auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
+
+    ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
+    auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
+
+    ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
+    auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
+
+    ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
+    auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
+
+    ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
+    auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
+                                " RecurrentToForgetWeights");
+
+    ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
+    auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
+
+    ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
+    auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
+    ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToOutputWeights");
+
+    // Validate data types for weights tensors (all should match each other)
+    ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
+
+    ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
+                                 "inputToInputWeights", "inputToForgetWeights");
+    ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
+                                 "inputToInputWeights", "inputToCellWeights");
+    ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
+                                 "inputToInputWeights", "inputToOutputWeights");
+
+    ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
+                                 "inputToInputWeights", "recurrentToInputWeights");
+    ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
+                                 "inputToInputWeights", "recurrentToForgetWeights");
+    ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
+                                 "inputToInputWeights", "recurrentToCellWeights");
+    ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
+                                 "inputToInputWeights", "recurrentToOutputWeights");
+
+    // Validate matching quantization info for weight tensors (all should match each other)
+    ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
+                                    descriptorName, "inputToInputWeights", "inputToForgetWeights");
+    ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
+                                    descriptorName, "inputToInputWeights", "inputToCellWeights");
+    ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
+                                    descriptorName, "inputToInputWeights", "inputToOutputWeights");
+
+    ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
+                                    descriptorName, "inputToInputWeights", "recurrentToInputWeights");
+    ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
+                                    descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
+    ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
+                                    descriptorName, "inputToInputWeights", "recurrentToCellWeights");
+    ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
+                                    descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
+
+    // Validate number of dimensions and number of elements in bias tensors
+    ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
+    auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
+    ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
+
+    ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
+    auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
+    ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
+
+    ValidatePointer(m_CellBias, descriptorName, "CellBias");
+    auto cellBiasInfo = m_CellBias->GetTensorInfo();
+    ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
+
+    ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
+    auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
+    ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
+
+    // Validate data types for bias tensors (all should match each other)
+    ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
+
+    ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
+                                 "inputGateBias", "forgetGateBias");
+    ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
+                                 "inputGateBias", "cellBias");
+    ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
+                                 "inputGateBias", "outputGateBias");
+
+    // Validate bias tensor quantization info
+    ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
+    ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
+    ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
+    ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
+}
+
 } // namespace armnn
-- 
cgit v1.2.1
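
For reference, the checks added above imply a specific tensor configuration for a QuantizedLstm workload: a QAsymm8 input of shape [numBatches, inputSize], QSymm16 cell state and QAsymm8 output state of shape [numBatches, outputSize], 2D QAsymm8 weights, and 1D Signed32 biases. The sketch below is not part of the patch or the ArmNN tests; it is a minimal illustration of tensor infos that should satisfy the validation, assuming numBatches = 2, inputSize = 2 and outputSize = 4. The helper name MakeQuantizedLstmTensorInfos and all scale/offset values are arbitrary example choices, and the bias scale is set to inputScale * weightsScale on the assumption that ValidateBiasTensorQuantization() requires that relationship.

#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

using namespace armnn;

// Illustrative helper (hypothetical name): builds tensor infos matching the
// shape and data-type checks in QuantizedLstmQueueDescriptor::Validate().
void MakeQuantizedLstmTensorInfos()
{
    constexpr unsigned int numBatches = 2;
    constexpr unsigned int inputSize  = 2;
    constexpr unsigned int outputSize = 4;

    // Example quantization parameters (arbitrary values for this sketch)
    const float   inputOutputScale  = 0.0078125f;     // QAsymm8 input and output state
    const int32_t inputOutputOffset = 128;
    const float   cellStateScale    = 0.00048828125f; // QSymm16 cell state
    const float   weightsScale      = 0.00784314f;    // QAsymm8 weights
    const int32_t weightsOffset     = 0;

    // Inputs: [numBatches, inputSize] input, [numBatches, outputSize] states
    TensorInfo inputInfo({numBatches, inputSize}, DataType::QuantisedAsymm8,
                         inputOutputScale, inputOutputOffset);
    TensorInfo cellStateInInfo({numBatches, outputSize}, DataType::QuantisedSymm16, cellStateScale, 0);
    TensorInfo outputStateInInfo({numBatches, outputSize}, DataType::QuantisedAsymm8,
                                 inputOutputScale, inputOutputOffset);

    // Outputs must match the corresponding state inputs' type and quantization space
    TensorInfo cellStateOutInfo   = cellStateInInfo;
    TensorInfo outputStateOutInfo = outputStateInInfo;

    // Weights: 2D QAsymm8, [outputSize, inputSize] for input-to-* and
    // [outputSize, outputSize] for recurrent-to-*, all sharing one quantization space
    TensorInfo inputWeightsInfo({outputSize, inputSize}, DataType::QuantisedAsymm8,
                                weightsScale, weightsOffset);
    TensorInfo recurrentWeightsInfo({outputSize, outputSize}, DataType::QuantisedAsymm8,
                                    weightsScale, weightsOffset);

    // Biases: 1D Signed32 of size outputSize; scale chosen as inputScale * weightsScale
    // (assumed requirement of ValidateBiasTensorQuantization())
    TensorInfo biasInfo({outputSize}, DataType::Signed32, inputOutputScale * weightsScale, 0);

    // The descriptor's m_InputTo*/m_RecurrentTo* weight handles and m_*Bias handles
    // would then reference constant tensor handles built from these infos.
}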