ArmNN
 22.11
QuantizedLstmQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for QuantizedLstmQueueDescriptor:
QueueDescriptor

Public Member Functions

 QuantizedLstmQueueDescriptor ()
 
void Validate (const WorkloadInfo &workloadInfo) const
 
- Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
 
void ValidateTensorNumDimensions (const TensorInfo &tensor, std::string const &descName, unsigned int numDimensions, std::string const &tensorName) const
 
void ValidateTensorNumDimNumElem (const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
 
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
 
template<typename T >
const T * GetAdditionalInformation () const
 

Public Attributes

const ConstTensorHandle * m_InputToInputWeights
 
const ConstTensorHandle * m_InputToForgetWeights
 
const ConstTensorHandle * m_InputToCellWeights
 
const ConstTensorHandle * m_InputToOutputWeights
 
const ConstTensorHandle * m_RecurrentToInputWeights
 
const ConstTensorHandle * m_RecurrentToForgetWeights
 
const ConstTensorHandle * m_RecurrentToCellWeights
 
const ConstTensorHandle * m_RecurrentToOutputWeights
 
const ConstTensorHandle * m_InputGateBias
 
const ConstTensorHandle * m_ForgetGateBias
 
const ConstTensorHandle * m_CellBias
 
const ConstTensorHandle * m_OutputGateBias
 
- Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
 
std::vector< ITensorHandle * > m_Outputs
 
void * m_AdditionalInfoObject
 
bool m_AllowExpandedDims = false
 

Additional Inherited Members

- Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 
 QueueDescriptor (QueueDescriptor const &)=default
 
QueueDescriptor & operator= (QueueDescriptor const &)=default
 

Detailed Description

Definition at line 646 of file WorkloadData.hpp.

Constructor & Destructor Documentation

◆ QuantizedLstmQueueDescriptor()

Definition at line 648 of file WorkloadData.hpp.

649  : m_InputToInputWeights(nullptr)
650  , m_InputToForgetWeights(nullptr)
651  , m_InputToCellWeights(nullptr)
652  , m_InputToOutputWeights(nullptr)
653 
654  , m_RecurrentToInputWeights(nullptr)
655  , m_RecurrentToForgetWeights(nullptr)
656  , m_RecurrentToCellWeights(nullptr)
657  , m_RecurrentToOutputWeights(nullptr)
658 
659  , m_InputGateBias(nullptr)
660  , m_ForgetGateBias(nullptr)
661  , m_CellBias(nullptr)
662  , m_OutputGateBias(nullptr)
663  {}
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_RecurrentToInputWeights
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_InputToCellWeights

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo) const

Definition at line 3387 of file WorkloadData.cpp.

References WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, armnn::QAsymmU8, armnn::QSymmS16, armnn::Signed32, and QueueDescriptor::ValidateTensorNumDimNumElem().

3388 {
3389  const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3390 
3391  // Validate number of inputs/outputs
3392  ValidateNumInputs(workloadInfo, descriptorName, 3);
3393  ValidateNumOutputs(workloadInfo, descriptorName, 2);
3394 
3395  // Input/output tensor infos
3396  auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3397  auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3398  auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3399 
3400  auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3401  auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3402 
3403  std::vector<DataType> inputOutputSupportedTypes =
3404  {
3406  };
3407 
3408  std::vector<DataType> cellStateSupportedTypes =
3409  {
3411  };
3412 
3413  std::vector<DataType> weightsSupportedTypes =
3414  {
3416  };
3417 
3418  std::vector<DataType> biasSupportedTypes =
3419  {
3421  };
3422 
3423  // Validate types of input/output tensors
3424  ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3425  ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3426  ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3427 
3428  ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3429  ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3430 
3431  // Validate matching types of input/output tensors
3432  ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3433  ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3434  "outputStateIn", "outputStateOut");
3435  ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3436 
3437  // Validate matching quantization info for input/output tensors
3438  ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3439  ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3440  ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3441 
3442  // Infer number of batches, input size and output size from tensor dimensions
3443  const uint32_t numBatches = inputInfo.GetShape()[0];
3444  const uint32_t inputSize = inputInfo.GetShape()[1];
3445  const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3446 
3447  // Validate number of dimensions and number of elements for input/output tensors
3448  ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3449  ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3450  ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3451  ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3452  ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3453 
3454  // Validate number of dimensions and number of elements for weights tensors
3455  ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3456  auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3457  ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3458 
3459  ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3460  auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3461  ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3462 
3463  ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3464  auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3465  ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3466 
3467  ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3468  auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3469  ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3470 
3471  ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3472  auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3473  ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3474 
3475  ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3476  auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3477  ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3478  " RecurrentToForgetWeights");
3479 
3480  ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3481  auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3482  ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3483 
3484  ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3485  auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3486  ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3487 
3488  // Validate data types for weights tensors (all should match each other)
3489  ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3490 
3491  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3492  "inputToInputWeights", "inputToForgetWeights");
3493  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3494  "inputToInputWeights", "inputToCellWeights");
3495  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3496  "inputToInputWeights", "inputToOutputWeights");
3497 
3498  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3499  "inputToInputWeights", "recurrentToInputWeights");
3500  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3501  "inputToInputWeights", "recurrentToForgeteights");
3502  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3503  "inputToInputWeights", "recurrentToCellWeights");
3504  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3505  "inputToInputWeights", "recurrentToOutputWeights");
3506 
3507  // Validate matching quantization info for weight tensors (all should match each other)
3508  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3509  descriptorName, "inputToInputWeights", "inputToForgetWeights");
3510  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3511  descriptorName, "inputToInputWeights", "inputToCellWeights");
3512  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3513  descriptorName, "inputToInputWeights", "inputToOutputWeights");
3514 
3515  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3516  descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3517  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3518  descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3519  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3520  descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3521  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3522  descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3523 
3524  // Validate number of dimensions and number of elements in bias tensors
3525  ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3526  auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3527  ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3528 
3529  ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3530  auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3531  ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3532 
3533  ValidatePointer(m_CellBias, descriptorName, "CellBias");
3534  auto cellBiasInfo = m_CellBias->GetTensorInfo();
3535  ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3536 
3537  ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3538  auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3539  ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3540 
3541  // Validate data types for bias tensors (all should match each other)
3542  ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3543 
3544  ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3545  "inputGateBias", "forgetGateBias");
3546  ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3547  "inputGateBias", "cellBias");
3548  ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3549  "inputGateBias", "outputGateBias");
3550 
3551  // Validate bias tensor quantization info
3552  ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3553  ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3554  ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3555  ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3556 }
void ValidateTensorNumDimNumElem(const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_RecurrentToInputWeights
const TensorInfo & GetTensorInfo() const
std::vector< TensorInfo > m_InputTensorInfos
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_ForgetGateBias
std::vector< TensorInfo > m_OutputTensorInfos
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_InputToCellWeights

Member Data Documentation

◆ m_CellBias

const ConstTensorHandle* m_CellBias

Definition at line 677 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_ForgetGateBias

const ConstTensorHandle* m_ForgetGateBias

Definition at line 676 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputGateBias

const ConstTensorHandle* m_InputGateBias

Definition at line 675 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToCellWeights

const ConstTensorHandle* m_InputToCellWeights

Definition at line 667 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToForgetWeights

const ConstTensorHandle* m_InputToForgetWeights

Definition at line 666 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToInputWeights

const ConstTensorHandle* m_InputToInputWeights

Definition at line 665 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToOutputWeights

const ConstTensorHandle* m_InputToOutputWeights

Definition at line 668 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_OutputGateBias

const ConstTensorHandle* m_OutputGateBias

Definition at line 678 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToCellWeights

const ConstTensorHandle* m_RecurrentToCellWeights

Definition at line 672 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToForgetWeights

const ConstTensorHandle* m_RecurrentToForgetWeights

Definition at line 671 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToInputWeights

const ConstTensorHandle* m_RecurrentToInputWeights

Definition at line 670 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToOutputWeights

const ConstTensorHandle* m_RecurrentToOutputWeights

Definition at line 673 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().


The documentation for this struct was generated from the following files: