3359 const std::string descriptorName{
"QuantizedLstmQueueDescriptor"};
3362 ValidateNumInputs(workloadInfo, descriptorName, 3);
3363 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3373 std::vector<DataType> inputOutputSupportedTypes =
3378 std::vector<DataType> cellStateSupportedTypes =
3383 std::vector<DataType> weightsSupportedTypes =
3388 std::vector<DataType> biasSupportedTypes =
3394 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3395 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3396 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3398 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3399 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3402 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName,
"input",
"outputStateIn");
3403 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3404 "outputStateIn",
"outputStateOut");
3405 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName,
"cellStateIn",
"cellStateOut");
3408 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName,
"input",
"outputStateIn");
3409 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName,
"input",
"outputStateOut");
3410 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName,
"cellStateIn",
"cellStateOut");
3413 const uint32_t numBatches = inputInfo.GetShape()[0];
3414 const uint32_t inputSize = inputInfo.GetShape()[1];
3415 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3418 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName +
" input");
3419 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName +
" cellStateIn");
3420 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName +
" outputStateIn");
3421 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName +
" cellStateOut");
3422 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName +
" outputStateOut");
3427 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize),
" InputToInputWeights");
3431 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize),
" InputToForgetWeights");
3435 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize),
" InputToCellWeights");
3439 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize),
" InputToOutputWeights");
3443 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize),
" RecurrentToInputWeights");
3447 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3448 " RecurrentToForgetWeights");
3452 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize),
" RecurrentToCellWeights");
3456 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize),
" RecurrentToCellWeights");
3459 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3461 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3462 "inputToInputWeights",
"inputToForgetWeights");
3463 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3464 "inputToInputWeights",
"inputToCellWeights");
3465 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3466 "inputToInputWeights",
"inputToOutputWeights");
3468 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3469 "inputToInputWeights",
"recurrentToInputWeights");
3470 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3471 "inputToInputWeights",
"recurrentToForgeteights");
3472 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3473 "inputToInputWeights",
"recurrentToCellWeights");
3474 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3475 "inputToInputWeights",
"recurrentToOutputWeights");
3478 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3479 descriptorName,
"inputToInputWeights",
"inputToForgetWeights");
3480 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3481 descriptorName,
"inputToInputWeights",
"inputToCellWeights");
3482 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3483 descriptorName,
"inputToInputWeights",
"inputToOutputWeights");
3485 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3486 descriptorName,
"inputToInputWeights",
"recurrentToInputWeights");
3487 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3488 descriptorName,
"inputToInputWeights",
"recurrentToForgetWeights");
3489 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3490 descriptorName,
"inputToInputWeights",
"recurrentToCellWeights");
3491 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3492 descriptorName,
"inputToInputWeights",
"recurrentToOutputWeights");
3497 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize,
" InputGateBias");
3501 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize,
" ForgetGateBias");
3503 ValidatePointer(
m_CellBias, descriptorName,
"CellBias");
3505 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize,
" CellBias");
3509 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize,
" OutputGateBias");
3512 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3514 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3515 "inputGateBias",
"forgetGateBias");
3516 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3517 "inputGateBias",
"cellBias");
3518 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3519 "inputGateBias",
"outputGateBias");
3522 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3523 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3524 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3525 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_RecurrentToInputWeights
const TensorInfo & GetTensorInfo() const
std::vector< TensorInfo > m_InputTensorInfos
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_ForgetGateBias
std::vector< TensorInfo > m_OutputTensorInfos
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_InputToCellWeights