2685 const std::string descriptorName{
"QuantizedLstmQueueDescriptor"};
2688 ValidateNumInputs(workloadInfo, descriptorName, 3);
2689 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2699 std::vector<DataType> inputOutputSupportedTypes =
2704 std::vector<DataType> cellStateSupportedTypes =
2709 std::vector<DataType> weightsSupportedTypes =
2714 std::vector<DataType> biasSupportedTypes =
2720 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2721 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2722 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2724 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2725 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2728 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName,
"input",
"outputStateIn");
2729 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2730 "outputStateIn",
"outputStateOut");
2731 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName,
"cellStateIn",
"cellStateOut");
2734 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName,
"input",
"outputStateIn");
2735 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName,
"input",
"outputStateOut");
2736 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName,
"cellStateIn",
"cellStateOut");
2739 const uint32_t numBatches = inputInfo.GetShape()[0];
2740 const uint32_t inputSize = inputInfo.GetShape()[1];
2741 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2744 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName +
" input");
2745 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName +
" cellStateIn");
2746 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName +
" outputStateIn");
2747 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName +
" cellStateOut");
2748 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName +
" outputStateOut");
2753 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize),
" InputToInputWeights");
2757 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize),
" InputToForgetWeights");
2761 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize),
" InputToCellWeights");
2765 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize),
" InputToOutputWeights");
2769 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize),
" RecurrentToInputWeights");
2773 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2774 " RecurrentToForgetWeights");
2778 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize),
" RecurrentToCellWeights");
2782 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize),
" RecurrentToCellWeights");
2785 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2787 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2788 "inputToInputWeights",
"inputToForgetWeights");
2789 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2790 "inputToInputWeights",
"inputToCellWeights");
2791 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2792 "inputToInputWeights",
"inputToOutputWeights");
2794 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2795 "inputToInputWeights",
"recurrentToInputWeights");
2796 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2797 "inputToInputWeights",
"recurrentToForgeteights");
2798 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2799 "inputToInputWeights",
"recurrentToCellWeights");
2800 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2801 "inputToInputWeights",
"recurrentToOutputWeights");
2804 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2805 descriptorName,
"inputToInputWeights",
"inputToForgetWeights");
2806 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2807 descriptorName,
"inputToInputWeights",
"inputToCellWeights");
2808 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2809 descriptorName,
"inputToInputWeights",
"inputToOutputWeights");
2811 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2812 descriptorName,
"inputToInputWeights",
"recurrentToInputWeights");
2813 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2814 descriptorName,
"inputToInputWeights",
"recurrentToForgetWeights");
2815 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2816 descriptorName,
"inputToInputWeights",
"recurrentToCellWeights");
2817 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2818 descriptorName,
"inputToInputWeights",
"recurrentToOutputWeights");
2823 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize,
" InputGateBias");
2827 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize,
" ForgetGateBias");
2829 ValidatePointer(
m_CellBias, descriptorName,
"CellBias");
2831 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize,
" CellBias");
2835 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize,
" OutputGateBias");
2838 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2840 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2841 "inputGateBias",
"forgetGateBias");
2842 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2843 "inputGateBias",
"cellBias");
2844 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2845 "inputGateBias",
"outputGateBias");
2848 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2849 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2850 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2851 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
const ConstCpuTensorHandle * m_CellBias
const ConstCpuTensorHandle * m_ForgetGateBias
const ConstCpuTensorHandle * m_InputToForgetWeights
const TensorInfo & GetTensorInfo() const
const ConstCpuTensorHandle * m_InputToInputWeights
std::vector< TensorInfo > m_OutputTensorInfos
const ConstCpuTensorHandle * m_RecurrentToInputWeights
const ConstCpuTensorHandle * m_OutputGateBias
const ConstCpuTensorHandle * m_InputGateBias
std::vector< TensorInfo > m_InputTensorInfos
const ConstCpuTensorHandle * m_RecurrentToForgetWeights
const ConstCpuTensorHandle * m_RecurrentToOutputWeights
const ConstCpuTensorHandle * m_InputToOutputWeights
const ConstCpuTensorHandle * m_RecurrentToCellWeights
const ConstCpuTensorHandle * m_InputToCellWeights