3045 const std::string descriptorName{
"QLstmQueueDescriptor"};
3048 ValidateNumInputs(workloadInfo, descriptorName, 3);
3049 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3061 std::vector<DataType> inputOutputSupportedTypes =
3066 std::vector<DataType> cellStateSupportedTypes =
3071 std::vector<DataType> weightsSupportedTypes =
3076 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3081 std::vector<DataType> biasSupportedTypes =
3087 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3088 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3089 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3091 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3092 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3093 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3096 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName,
"input",
"outputStateIn");
3097 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3098 "outputStateIn",
"outputStateOut");
3099 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName,
"cellStateIn",
"cellStateOut");
3102 const uint32_t numBatches = inputInfo.GetShape()[0];
3103 const uint32_t inputSize = inputInfo.GetShape()[1];
3104 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3105 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3108 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName +
" input");
3109 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName +
" outputStateIn");
3110 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName +
" cellStateIn");
3112 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName +
" outputStateOut");
3113 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName +
" cellStateOut");
3114 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName +
" output");
3119 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize),
" InputToForgetWeights");
3123 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize),
" InputToCellWeights");
3127 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize),
" InputToOutputWeights");
3131 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3132 " RecurrentToForgetWeights");
3136 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize),
" RecurrentToCellWeights");
3140 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize),
" RecurrentToCellWeights");
3143 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3145 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3146 "inputToForgetWeights",
"inputToCellWeights");
3147 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3148 "inputToForgetWeights",
"inputToOutputWeights");
3150 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3151 "inputToForgetWeights",
"recurrentToForgeteights");
3152 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3153 "inputToForgetWeights",
"recurrentToCellWeights");
3154 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3155 "inputToForgetWeights",
"recurrentToOutputWeights");
3160 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits,
" ForgetGateBias");
3162 ValidatePointer(
m_CellBias, descriptorName,
"CellBias");
3164 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits,
" CellBias");
3168 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits,
" OutputGateBias");
3171 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3173 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3174 "forgetGateBias",
"cellBias");
3175 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3176 "forgetGateBias",
"outputGateBias");
3184 if (!allCifgParamsPresentOrNot)
3187 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present " 3188 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be " 3189 "set appropriately.");
3196 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize),
" InputToInputWeights");
3199 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3200 " RecurrentToInputWeights");
3203 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits,
" InputGateBias");
3206 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3207 "inputToForgetWeights",
"inputToInputWeights");
3208 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3209 "inputToForgetWeights",
"recurrentToInputWeights");
3210 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3211 "forgetGateBias",
"inputGateBias");
3215 bool allPeepholeWeightsPresentOrNot =
3221 if (!allPeepholeWeightsPresentOrNot)
3224 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole " 3225 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present " 3226 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set " 3233 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits,
" cellToForgetWeights");
3234 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3237 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits,
" cellToOutputWeights");
3238 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3239 "cellToForgetWeight",
"cellToOutputWeights");
3244 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits,
" cellToInputWeights");
3245 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3246 "cellToForgetWeights",
"cellToInputWeights");
3251 bool allLayerNormWeightsPresentOrNot =
3257 if (!allLayerNormWeightsPresentOrNot)
3260 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights " 3261 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not " 3262 "be present at all (Layer Norm disabled). InputLayerNormWeights should " 3263 "only be present when Layer Norm is enabled and CIFG is disabled. " 3264 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3270 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits,
" forgetLayerNormWeights");
3271 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3274 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits,
" cellLayerNormWeights");
3275 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3276 "forgetLayerNormWeights",
"cellLayerNormWeights");
3279 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits,
" outputLayerNormWeights");
3280 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3281 "forgetLayerNormWeights",
"outputLayerNormWeights");
3286 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits,
" inputLayerNormWeights");
3287 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3288 "forgetLayerNormWeights",
"inputLayerNormWeights");
3293 bool correctProjectionTensorsPresent =
3298 if (!correctProjectionTensorsPresent)
3301 ": If projection is enabled, ProjectionWeights should be present and " 3302 "ProjectionBias is optional. If projection is disabled, neither " 3303 "ProjectionWeights nor ProjectionBias should be present.");
3309 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize),
"ProjectionWeights");
3310 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3315 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize,
"ProjectionBias");
3316 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3323 ": If projection is disabled, output quantization info (scale, offset) " 3324 "should match HiddenStateScale and HiddenStateZeroPoint.");
const ConstTensorHandle * m_CellLayerNormWeights
const ConstTensorHandle * m_ProjectionWeights
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_InputToOutputWeights
bool m_PeepholeEnabled
Enable/disable peephole.
float m_HiddenStateScale
Hidden State quantization scale.
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_CellToOutputWeights
const ConstTensorHandle * m_CellToInputWeights
QLstmDescriptor m_Parameters
const ConstTensorHandle * m_ForgetLayerNormWeights
const TensorInfo & GetTensorInfo() const
std::vector< TensorInfo > m_InputTensorInfos
bool m_LayerNormEnabled
Enable/disable layer normalization.
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_CellBias
std::vector< TensorInfo > m_OutputTensorInfos
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_CellToForgetWeights
const ConstTensorHandle * m_ProjectionBias
const ConstTensorHandle * m_RecurrentToCellWeights
bool m_ProjectionEnabled
Enable/disable the projection layer.
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_OutputLayerNormWeights
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_RecurrentToInputWeights
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
const ConstTensorHandle * m_RecurrentToForgetWeights
int32_t m_HiddenStateZeroPoint
Hidden State zero point.