ArmNN
 23.11
QuantizedLstmQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for QuantizedLstmQueueDescriptor:
[legend]
Collaboration diagram for QuantizedLstmQueueDescriptor:
[legend]

Public Member Functions

 QuantizedLstmQueueDescriptor ()
 
void Validate (const WorkloadInfo &workloadInfo) const
 
- Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
 
void ValidateTensorNumDimensions (const TensorInfo &tensor, std::string const &descName, unsigned int numDimensions, std::string const &tensorName) const
 
void ValidateTensorNumDimNumElem (const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
 
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
 
template<typename T >
const T * GetAdditionalInformation () const
 

Public Attributes

const ConstTensorHandle * m_InputToInputWeights
 
const ConstTensorHandle * m_InputToForgetWeights
 
const ConstTensorHandle * m_InputToCellWeights
 
const ConstTensorHandle * m_InputToOutputWeights
 
const ConstTensorHandle * m_RecurrentToInputWeights
 
const ConstTensorHandle * m_RecurrentToForgetWeights
 
const ConstTensorHandle * m_RecurrentToCellWeights
 
const ConstTensorHandle * m_RecurrentToOutputWeights
 
const ConstTensorHandle * m_InputGateBias
 
const ConstTensorHandle * m_ForgetGateBias
 
const ConstTensorHandle * m_CellBias
 
const ConstTensorHandle * m_OutputGateBias
 
- Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
 
std::vector< ITensorHandle * > m_Outputs
 
void * m_AdditionalInfoObject
 
bool m_AllowExpandedDims = false
 

Additional Inherited Members

- Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 
 QueueDescriptor (QueueDescriptor const &)=default
 
QueueDescriptor & operator= (QueueDescriptor const &)=default
 

Detailed Description

Definition at line 614 of file WorkloadData.hpp.

Constructor & Destructor Documentation

◆ QuantizedLstmQueueDescriptor()

Definition at line 616 of file WorkloadData.hpp.

616  QuantizedLstmQueueDescriptor()
617  : m_InputToInputWeights(nullptr)
618  , m_InputToForgetWeights(nullptr)
619  , m_InputToCellWeights(nullptr)
620  , m_InputToOutputWeights(nullptr)
621 
622  , m_RecurrentToInputWeights(nullptr)
623  , m_RecurrentToForgetWeights(nullptr)
624  , m_RecurrentToCellWeights(nullptr)
625  , m_RecurrentToOutputWeights(nullptr)
626 
627  , m_InputGateBias(nullptr)
628  , m_ForgetGateBias(nullptr)
629  , m_CellBias(nullptr)
630  , m_OutputGateBias(nullptr)
631  {}

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo) const

Definition at line 3431 of file WorkloadData.cpp.

3432 {
3433  const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3434 
3435  // Validate number of inputs/outputs
3436  ValidateNumInputs(workloadInfo, descriptorName, 3);
3437  ValidateNumOutputs(workloadInfo, descriptorName, 2);
3438 
3439  // Input/output tensor infos
3440  auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3441  auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3442  auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3443 
3444  auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3445  auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3446 
3447  std::vector<DataType> inputOutputSupportedTypes =
3448  {
3449  DataType::QAsymmU8
3450  };
3451 
3452  std::vector<DataType> cellStateSupportedTypes =
3453  {
3454  DataType::QSymmS16
3455  };
3456 
3457  std::vector<DataType> weightsSupportedTypes =
3458  {
3459  DataType::QAsymmU8
3460  };
3461 
3462  std::vector<DataType> biasSupportedTypes =
3463  {
3464  DataType::Signed32
3465  };
3466 
3467  // Validate types of input/output tensors
3468  ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3469  ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3470  ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3471 
3472  ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3473  ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3474 
3475  // Validate matching types of input/output tensors
3476  ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3477  ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3478  "outputStateIn", "outputStateOut");
3479  ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3480 
3481  // Validate matching quantization info for input/output tensors
3482  ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3483  ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3484  ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3485 
3486  // Infer number of batches, input size and output size from tensor dimensions
3487  const uint32_t numBatches = inputInfo.GetShape()[0];
3488  const uint32_t inputSize = inputInfo.GetShape()[1];
3489  const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3490 
3491  // Validate number of dimensions and number of elements for input/output tensors
3492  ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3493  ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3494  ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3495  ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3496  ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3497 
3498  // Validate number of dimensions and number of elements for weights tensors
3499  ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3500  auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3501  ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3502 
3503  ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3504  auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3505  ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3506 
3507  ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3508  auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3509  ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3510 
3511  ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3512  auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3513  ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3514 
3515  ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3516  auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3517  ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3518 
3519  ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3520  auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3521  ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3522  " RecurrentToForgetWeights");
3523 
3524  ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3525  auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3526  ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3527 
3528  ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3529  auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3530  ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3531 
3532  // Validate data types for weights tensors (all should match each other)
3533  ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3534 
3535  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3536  "inputToInputWeights", "inputToForgetWeights");
3537  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3538  "inputToInputWeights", "inputToCellWeights");
3539  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3540  "inputToInputWeights", "inputToOutputWeights");
3541 
3542  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3543  "inputToInputWeights", "recurrentToInputWeights");
3544  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3545  "inputToInputWeights", "recurrentToForgeteights");
3546  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3547  "inputToInputWeights", "recurrentToCellWeights");
3548  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3549  "inputToInputWeights", "recurrentToOutputWeights");
3550 
3551  // Validate matching quantization info for weight tensors (all should match each other)
3552  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3553  descriptorName, "inputToInputWeights", "inputToForgetWeights");
3554  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3555  descriptorName, "inputToInputWeights", "inputToCellWeights");
3556  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3557  descriptorName, "inputToInputWeights", "inputToOutputWeights");
3558 
3559  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3560  descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3561  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3562  descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3563  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3564  descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3565  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3566  descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3567 
3568  // Validate number of dimensions and number of elements in bias tensors
3569  ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3570  auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3571  ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3572 
3573  ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3574  auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3575  ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3576 
3577  ValidatePointer(m_CellBias, descriptorName, "CellBias");
3578  auto cellBiasInfo = m_CellBias->GetTensorInfo();
3579  ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3580 
3581  ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3582  auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3583  ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3584 
3585  // Validate data types for bias tensors (all should match each other)
3586  ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3587 
3588  ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3589  "inputGateBias", "forgetGateBias");
3590  ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3591  "inputGateBias", "cellBias");
3592  ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3593  "inputGateBias", "outputGateBias");
3594 
3595  // Validate bias tensor quantization info
3596  ValidateBiasTensorQuantization(inputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3597  ValidateBiasTensorQuantization(forgetGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3598  ValidateBiasTensorQuantization(cellBiasInfo, inputToInputWeightsInfo, descriptorName);
3599  ValidateBiasTensorQuantization(outputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3600 }

References TensorInfo::GetShape(), ConstTensorHandle::GetTensorInfo(), QuantizedLstmQueueDescriptor::m_CellBias, QuantizedLstmQueueDescriptor::m_ForgetGateBias, QuantizedLstmQueueDescriptor::m_InputGateBias, WorkloadInfo::m_InputTensorInfos, QuantizedLstmQueueDescriptor::m_InputToCellWeights, QuantizedLstmQueueDescriptor::m_InputToForgetWeights, QuantizedLstmQueueDescriptor::m_InputToInputWeights, QuantizedLstmQueueDescriptor::m_InputToOutputWeights, QuantizedLstmQueueDescriptor::m_OutputGateBias, WorkloadInfo::m_OutputTensorInfos, QuantizedLstmQueueDescriptor::m_RecurrentToCellWeights, QuantizedLstmQueueDescriptor::m_RecurrentToForgetWeights, QuantizedLstmQueueDescriptor::m_RecurrentToInputWeights, QuantizedLstmQueueDescriptor::m_RecurrentToOutputWeights, armnn::QAsymmU8, armnn::QSymmS16, armnn::Signed32, and QueueDescriptor::ValidateTensorNumDimNumElem().

Member Data Documentation

◆ m_CellBias

const ConstTensorHandle* m_CellBias

◆ m_ForgetGateBias

const ConstTensorHandle* m_ForgetGateBias

◆ m_InputGateBias

const ConstTensorHandle* m_InputGateBias

◆ m_InputToCellWeights

const ConstTensorHandle* m_InputToCellWeights

◆ m_InputToForgetWeights

const ConstTensorHandle* m_InputToForgetWeights

◆ m_InputToInputWeights

const ConstTensorHandle* m_InputToInputWeights

◆ m_InputToOutputWeights

const ConstTensorHandle* m_InputToOutputWeights

◆ m_OutputGateBias

const ConstTensorHandle* m_OutputGateBias

◆ m_RecurrentToCellWeights

const ConstTensorHandle* m_RecurrentToCellWeights

◆ m_RecurrentToForgetWeights

const ConstTensorHandle* m_RecurrentToForgetWeights

◆ m_RecurrentToInputWeights

const ConstTensorHandle* m_RecurrentToInputWeights

◆ m_RecurrentToOutputWeights

const ConstTensorHandle* m_RecurrentToOutputWeights

The documentation for this struct was generated from the following files:
WorkloadData.hpp
WorkloadData.cpp
armnn::QuantizedLstmQueueDescriptor::m_CellBias
const ConstTensorHandle * m_CellBias
Definition: WorkloadData.hpp:645
armnn::QuantizedLstmQueueDescriptor::m_RecurrentToInputWeights
const ConstTensorHandle * m_RecurrentToInputWeights
Definition: WorkloadData.hpp:638
armnn::QuantizedLstmQueueDescriptor::m_InputToForgetWeights
const ConstTensorHandle * m_InputToForgetWeights
Definition: WorkloadData.hpp:634
armnn::QuantizedLstmQueueDescriptor::m_RecurrentToOutputWeights
const ConstTensorHandle * m_RecurrentToOutputWeights
Definition: WorkloadData.hpp:641
armnn::QuantizedLstmQueueDescriptor::m_InputGateBias
const ConstTensorHandle * m_InputGateBias
Definition: WorkloadData.hpp:643
armnn::DataType::QAsymmU8
@ QAsymmU8
armnn::ConstTensorHandle::GetTensorInfo
const TensorInfo & GetTensorInfo() const
Definition: TensorHandle.hpp:40
armnn::DataType::QSymmS16
@ QSymmS16
armnn::QuantizedLstmQueueDescriptor::m_RecurrentToCellWeights
const ConstTensorHandle * m_RecurrentToCellWeights
Definition: WorkloadData.hpp:640
armnn::WorkloadInfo::m_OutputTensorInfos
std::vector< TensorInfo > m_OutputTensorInfos
Definition: WorkloadInfo.hpp:19
armnn::QuantizedLstmQueueDescriptor::m_InputToInputWeights
const ConstTensorHandle * m_InputToInputWeights
Definition: WorkloadData.hpp:633
armnn::QuantizedLstmQueueDescriptor::m_RecurrentToForgetWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
Definition: WorkloadData.hpp:639
armnn::QuantizedLstmQueueDescriptor::m_InputToCellWeights
const ConstTensorHandle * m_InputToCellWeights
Definition: WorkloadData.hpp:635
armnn::QuantizedLstmQueueDescriptor::m_InputToOutputWeights
const ConstTensorHandle * m_InputToOutputWeights
Definition: WorkloadData.hpp:636
armnn::DataType::Signed32
@ Signed32
armnn::WorkloadInfo::m_InputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos
Definition: WorkloadInfo.hpp:18
armnn::QueueDescriptor::ValidateTensorNumDimNumElem
void ValidateTensorNumDimNumElem(const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
Definition: WorkloadData.cpp:435
armnn::QuantizedLstmQueueDescriptor::m_OutputGateBias
const ConstTensorHandle * m_OutputGateBias
Definition: WorkloadData.hpp:646
armnn::QuantizedLstmQueueDescriptor::m_ForgetGateBias
const ConstTensorHandle * m_ForgetGateBias
Definition: WorkloadData.hpp:644