2877 const std::string descriptorName{
"QLstmQueueDescriptor"};
2880 ValidateNumInputs(workloadInfo, descriptorName, 3);
2881 ValidateNumOutputs(workloadInfo, descriptorName, 3);
2893 std::vector<DataType> inputOutputSupportedTypes =
2898 std::vector<DataType> cellStateSupportedTypes =
2903 std::vector<DataType> weightsSupportedTypes =
2908 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
2913 std::vector<DataType> biasSupportedTypes =
2919 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2920 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2921 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2923 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2924 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2925 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
2928 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName,
"input",
"outputStateIn");
2929 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2930 "outputStateIn",
"outputStateOut");
2931 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName,
"cellStateIn",
"cellStateOut");
2934 const uint32_t numBatches = inputInfo.GetShape()[0];
2935 const uint32_t inputSize = inputInfo.GetShape()[1];
2936 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
2937 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
2940 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName +
" input");
2941 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName +
" outputStateIn");
2942 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName +
" cellStateIn");
2944 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName +
" outputStateOut");
2945 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName +
" cellStateOut");
2946 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName +
" output");
2951 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize),
" InputToForgetWeights");
2955 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize),
" InputToCellWeights");
2959 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize),
" InputToOutputWeights");
2963 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
2964 " RecurrentToForgetWeights");
2968 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize),
" RecurrentToCellWeights");
2972 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize),
" RecurrentToCellWeights");
2975 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
2977 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
2978 "inputToForgetWeights",
"inputToCellWeights");
2979 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2980 "inputToForgetWeights",
"inputToOutputWeights");
2982 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2983 "inputToForgetWeights",
"recurrentToForgeteights");
2984 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2985 "inputToForgetWeights",
"recurrentToCellWeights");
2986 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2987 "inputToForgetWeights",
"recurrentToOutputWeights");
2992 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits,
" ForgetGateBias");
2994 ValidatePointer(
m_CellBias, descriptorName,
"CellBias");
2996 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits,
" CellBias");
3000 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits,
" OutputGateBias");
3003 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3005 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3006 "forgetGateBias",
"cellBias");
3007 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3008 "forgetGateBias",
"outputGateBias");
3016 if (!allCifgParamsPresentOrNot)
3019 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present " 3020 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be " 3021 "set appropriately.");
3028 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize),
" InputToInputWeights");
3031 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3032 " RecurrentToInputWeights");
3035 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits,
" InputGateBias");
3038 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3039 "inputToForgetWeights",
"inputToInputWeights");
3040 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3041 "inputToForgetWeights",
"recurrentToInputWeights");
3042 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3043 "forgetGateBias",
"inputGateBias");
3047 bool allPeepholeWeightsPresentOrNot =
3053 if (!allPeepholeWeightsPresentOrNot)
3056 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole " 3057 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present " 3058 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set " 3065 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits,
" cellToForgetWeights");
3066 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3069 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits,
" cellToOutputWeights");
3070 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3071 "cellToForgetWeight",
"cellToOutputWeights");
3076 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits,
" cellToInputWeights");
3077 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3078 "cellToForgetWeights",
"cellToInputWeights");
3083 bool allLayerNormWeightsPresentOrNot =
3089 if (!allLayerNormWeightsPresentOrNot)
3092 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights " 3093 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not " 3094 "be present at all (Layer Norm disabled). InputLayerNormWeights should " 3095 "only be present when Layer Norm is enabled and CIFG is disabled. " 3096 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3102 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits,
" forgetLayerNormWeights");
3103 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3106 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits,
" cellLayerNormWeights");
3107 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3108 "forgetLayerNormWeights",
"cellLayerNormWeights");
3111 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits,
" outputLayerNormWeights");
3112 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3113 "forgetLayerNormWeights",
"outputLayerNormWeights");
3118 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits,
" inputLayerNormWeights");
3119 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3120 "forgetLayerNormWeights",
"inputLayerNormWeights");
3125 bool correctProjectionTensorsPresent =
3130 if (!correctProjectionTensorsPresent)
3133 ": If projection is enabled, ProjectionWeights should be present and " 3134 "ProjectionBias is optional. If projection is disabled, neither " 3135 "ProjectionWeights nor ProjectionBias should be present.");
3141 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize),
"ProjectionWeights");
3142 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3147 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize,
"ProjectionBias");
3148 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3155 ": If projection is disabled, output quantization info (scale, offset) " 3156 "should match HiddenStateScale and HiddenStateZeroPoint.");
const ConstCpuTensorHandle * m_CellToForgetWeights
const ConstCpuTensorHandle * m_ProjectionWeights
bool m_PeepholeEnabled
Enable/disable peephole.
const ConstCpuTensorHandle * m_ProjectionBias
float m_HiddenStateScale
Hidden State quantization scale.
const ConstCpuTensorHandle * m_ForgetLayerNormWeights
const ConstCpuTensorHandle * m_CellLayerNormWeights
const ConstCpuTensorHandle * m_RecurrentToCellWeights
const ConstCpuTensorHandle * m_RecurrentToInputWeights
const ConstCpuTensorHandle * m_OutputGateBias
const ConstCpuTensorHandle * m_CellBias
QLstmDescriptor m_Parameters
const ConstCpuTensorHandle * m_CellToOutputWeights
const ConstCpuTensorHandle * m_OutputLayerNormWeights
std::vector< TensorInfo > m_InputTensorInfos
bool m_LayerNormEnabled
Enable/disable layer normalization.
const ConstCpuTensorHandle * m_InputToForgetWeights
std::vector< TensorInfo > m_OutputTensorInfos
const ConstCpuTensorHandle * m_CellToInputWeights
const ConstCpuTensorHandle * m_RecurrentToOutputWeights
bool m_ProjectionEnabled
Enable/disable the projection layer.
const ConstCpuTensorHandle * m_InputGateBias
const ConstCpuTensorHandle * m_InputLayerNormWeights
const ConstCpuTensorHandle * m_RecurrentToForgetWeights
const ConstCpuTensorHandle * m_ForgetGateBias
const ConstCpuTensorHandle * m_InputToOutputWeights
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
const TensorInfo & GetTensorInfo() const
const ConstCpuTensorHandle * m_InputToInputWeights
int32_t m_HiddenStateZeroPoint
Hidden State zero point.
const ConstCpuTensorHandle * m_InputToCellWeights