// Validation of a QuantizedLstmQueueDescriptor (excerpt: the enclosing function
// signature and several statements are outside this view).
// descriptorName is passed to every Validate* helper below and is prepended to
// tensor names in the shape checks, so error messages identify this descriptor.
3163 const std::string descriptorName{
"QuantizedLstmQueueDescriptor"};
// The workload must supply exactly 3 input tensors and 2 output tensors.
3166 ValidateNumInputs(workloadInfo, descriptorName, 3);
3167 ValidateNumOutputs(workloadInfo, descriptorName, 2);
// Accepted data types per tensor role.
// NOTE(review): the initializer lists of these vectors are not visible in this
// excerpt; confirm their contents against the full file.
// As used in this function: inputOutputSupportedTypes covers input, outputStateIn
// and outputStateOut; cellStateSupportedTypes covers cellStateIn and cellStateOut;
// weightsSupportedTypes covers inputToInputWeights; biasSupportedTypes covers
// inputGateBias.
3177 std::vector<DataType> inputOutputSupportedTypes =
3182 std::vector<DataType> cellStateSupportedTypes =
3187 std::vector<DataType> weightsSupportedTypes =
3192 std::vector<DataType> biasSupportedTypes =
// Each input/output tensor must use a data type from its role's list.
// NOTE(review): the *Info locals are declared in lines not shown here; presumably
// they are read from workloadInfo's input/output tensor infos.
3198 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3199 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3200 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3202 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3203 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
// Tie data types together. input, outputStateIn and outputStateOut must all share
// one type: input is checked against outputStateIn, and outputStateIn against
// outputStateOut. cellStateIn must have the same type as cellStateOut.
3206 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName,
"input",
"outputStateIn");
3207 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3208 "outputStateIn",
"outputStateOut");
3209 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName,
"cellStateIn",
"cellStateOut");
// input, outputStateIn and outputStateOut must share one quantization space
// (both state tensors are compared directly against input). cellStateIn and
// cellStateOut must share another, possibly different, quantization space.
3212 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName,
"input",
"outputStateIn");
3213 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName,
"input",
"outputStateOut");
3214 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName,
"cellStateIn",
"cellStateOut");
// Derive the problem sizes. numBatches and inputSize come from the input's first
// and second dimensions; outputSize comes from cellStateIn's second dimension.
// NOTE(review): GetShape()[0]/[1] are read before the rank-2 checks below run;
// confirm that TensorShape::operator[] rejects out-of-range indices when a
// tensor has rank < 2.
3217 const uint32_t numBatches = inputInfo.GetShape()[0];
3218 const uint32_t inputSize = inputInfo.GetShape()[1];
3219 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
// Every input/output tensor must be 2-D with the element count implied by the
// derived sizes. Judging by the helper's name, this checks rank and total element
// count rather than each dimension individually.
3222 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName +
" input");
3223 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName +
" cellStateIn");
3224 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName +
" outputStateIn");
3225 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName +
" cellStateOut");
3226 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName +
" outputStateOut");
3231 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize),
" InputToInputWeights");
3235 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize),
" InputToForgetWeights");
3239 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize),
" InputToCellWeights");
3243 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize),
" InputToOutputWeights");
3247 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize),
" RecurrentToInputWeights");
3251 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3252 " RecurrentToForgetWeights");
3256 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize),
" RecurrentToCellWeights");
3260 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize),
" RecurrentToCellWeights");
3263 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3265 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3266 "inputToInputWeights",
"inputToForgetWeights");
3267 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3268 "inputToInputWeights",
"inputToCellWeights");
3269 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3270 "inputToInputWeights",
"inputToOutputWeights");
3272 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3273 "inputToInputWeights",
"recurrentToInputWeights");
3274 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3275 "inputToInputWeights",
"recurrentToForgeteights");
3276 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3277 "inputToInputWeights",
"recurrentToCellWeights");
3278 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3279 "inputToInputWeights",
"recurrentToOutputWeights");
// Every weight tensor must be in the same quantization space as
// inputToInputWeights. This is presumably why the bias checks at the end of the
// function can use inputToInputWeightsInfo as the representative weight for
// every gate.
// Input-to-gate weights.
3282 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3283 descriptorName,
"inputToInputWeights",
"inputToForgetWeights");
3284 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3285 descriptorName,
"inputToInputWeights",
"inputToCellWeights");
3286 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3287 descriptorName,
"inputToInputWeights",
"inputToOutputWeights");
// Recurrent-to-gate weights.
3289 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3290 descriptorName,
"inputToInputWeights",
"recurrentToInputWeights");
3291 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3292 descriptorName,
"inputToInputWeights",
"recurrentToForgetWeights");
3293 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3294 descriptorName,
"inputToInputWeights",
"recurrentToCellWeights");
3295 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3296 descriptorName,
"inputToInputWeights",
"recurrentToOutputWeights");
3301 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize,
" InputGateBias");
3305 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize,
" ForgetGateBias");
3307 ValidatePointer(
m_CellBias, descriptorName,
"CellBias");
3309 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize,
" CellBias");
3313 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize,
" OutputGateBias");
// Bias data types. inputGateBias must use a type from biasSupportedTypes, and the
// forget, cell and output gate biases must all have the same type as it.
3316 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3318 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3319 "inputGateBias",
"forgetGateBias");
3320 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3321 "inputGateBias",
"cellBias");
3322 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3323 "inputGateBias",
"outputGateBias");
// Check each gate bias's quantization against the input tensor and
// inputToInputWeights. inputToInputWeights stands in for every gate's weights
// because all weights were required to share its quantization space.
// NOTE(review): the exact rule (presumably bias scale vs. input scale times weight
// scale) is defined in ValidateBiasTensorQuantization, which is not shown here.
3326 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3327 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3328 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3329 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
const ConstCpuTensorHandle * m_RecurrentToForgetWeights
const ConstCpuTensorHandle * m_InputGateBias
const ConstCpuTensorHandle * m_InputToCellWeights
std::vector< TensorInfo > m_InputTensorInfos
const ConstCpuTensorHandle * m_ForgetGateBias
const ConstCpuTensorHandle * m_RecurrentToInputWeights
std::vector< TensorInfo > m_OutputTensorInfos
const ConstCpuTensorHandle * m_RecurrentToCellWeights
const ConstCpuTensorHandle * m_RecurrentToOutputWeights
const ConstCpuTensorHandle * m_CellBias
const ConstCpuTensorHandle * m_OutputGateBias
const ConstCpuTensorHandle * m_InputToForgetWeights
const ConstCpuTensorHandle * m_InputToOutputWeights
const ConstCpuTensorHandle * m_InputToInputWeights
const TensorInfo & GetTensorInfo() const