ArmNN  NotReleased
QuantizedLstmQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for QuantizedLstmQueueDescriptor:
QuantizedLstmQueueDescriptor derives from QueueDescriptor.

Public Member Functions

 QuantizedLstmQueueDescriptor ()
 
void Validate (const WorkloadInfo &workloadInfo) const
 
- Public Member Functions inherited from QueueDescriptor
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
 

Public Attributes

const ConstCpuTensorHandle * m_InputToInputWeights
 
const ConstCpuTensorHandle * m_InputToForgetWeights
 
const ConstCpuTensorHandle * m_InputToCellWeights
 
const ConstCpuTensorHandle * m_InputToOutputWeights
 
const ConstCpuTensorHandle * m_RecurrentToInputWeights
 
const ConstCpuTensorHandle * m_RecurrentToForgetWeights
 
const ConstCpuTensorHandle * m_RecurrentToCellWeights
 
const ConstCpuTensorHandle * m_RecurrentToOutputWeights
 
const ConstCpuTensorHandle * m_InputGateBias
 
const ConstCpuTensorHandle * m_ForgetGateBias
 
const ConstCpuTensorHandle * m_CellBias
 
const ConstCpuTensorHandle * m_OutputGateBias
 
- Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
 
std::vector< ITensorHandle * > m_Outputs
 

Additional Inherited Members

- Protected Member Functions inherited from QueueDescriptor
 ~QueueDescriptor ()=default
 
 QueueDescriptor ()=default
 
 QueueDescriptor (QueueDescriptor const &)=default
 
QueueDescriptor & operator= (QueueDescriptor const &)=default
 

Detailed Description

Definition at line 507 of file WorkloadData.hpp.

Constructor & Destructor Documentation

◆ QuantizedLstmQueueDescriptor()

Definition at line 509 of file WorkloadData.hpp.

510  : m_InputToInputWeights(nullptr)
511  , m_InputToForgetWeights(nullptr)
512  , m_InputToCellWeights(nullptr)
513  , m_InputToOutputWeights(nullptr)
514 
515  , m_RecurrentToInputWeights(nullptr)
516  , m_RecurrentToForgetWeights(nullptr)
517  , m_RecurrentToCellWeights(nullptr)
518  , m_RecurrentToOutputWeights(nullptr)
519 
520  , m_InputGateBias(nullptr)
521  , m_ForgetGateBias(nullptr)
522  , m_CellBias(nullptr)
523  , m_OutputGateBias(nullptr)
524  {}
const ConstCpuTensorHandle * m_CellBias
const ConstCpuTensorHandle * m_ForgetGateBias
const ConstCpuTensorHandle * m_InputToForgetWeights
const ConstCpuTensorHandle * m_InputToInputWeights
const ConstCpuTensorHandle * m_RecurrentToInputWeights
const ConstCpuTensorHandle * m_OutputGateBias
const ConstCpuTensorHandle * m_InputGateBias
const ConstCpuTensorHandle * m_RecurrentToForgetWeights
const ConstCpuTensorHandle * m_RecurrentToOutputWeights
const ConstCpuTensorHandle * m_InputToOutputWeights
const ConstCpuTensorHandle * m_RecurrentToCellWeights
const ConstCpuTensorHandle * m_InputToCellWeights

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo) const

Definition at line 2683 of file WorkloadData.cpp.

References WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, armnn::QAsymmU8, armnn::QSymmS16, and armnn::Signed32.

2684 {
2685  const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
2686 
2687  // Validate number of inputs/outputs
2688  ValidateNumInputs(workloadInfo, descriptorName, 3);
2689  ValidateNumOutputs(workloadInfo, descriptorName, 2);
2690 
2691  // Input/output tensor infos
2692  auto inputInfo = workloadInfo.m_InputTensorInfos[0];
2693  auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
2694  auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
2695 
2696  auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
2697  auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
2698 
2699  std::vector<DataType> inputOutputSupportedTypes =
2700  {
2701  DataType::QAsymmU8
2702  };
2703 
2704  std::vector<DataType> cellStateSupportedTypes =
2705  {
2706  DataType::QSymmS16
2707  };
2708 
2709  std::vector<DataType> weightsSupportedTypes =
2710  {
2711  DataType::QAsymmU8
2712  };
2713 
2714  std::vector<DataType> biasSupportedTypes =
2715  {
2716  DataType::Signed32
2717  };
2718 
2719  // Validate types of input/output tensors
2720  ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
2721  ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
2722  ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
2723 
2724  ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
2725  ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
2726 
2727  // Validate matching types of input/output tensors
2728  ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2729  ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
2730  "outputStateIn", "outputStateOut");
2731  ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2732 
2733  // Validate matching quantization info for input/output tensors
2734  ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
2735  ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
2736  ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
2737 
2738  // Infer number of batches, input size and output size from tensor dimensions
2739  const uint32_t numBatches = inputInfo.GetShape()[0];
2740  const uint32_t inputSize = inputInfo.GetShape()[1];
2741  const uint32_t outputSize = cellStateInInfo.GetShape()[1];
2742 
2743  // Validate number of dimensions and number of elements for input/output tensors
2744  ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
2745  ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
2746  ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
2747  ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
2748  ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
2749 
2750  // Validate number of dimensions and number of elements for weights tensors
2751  ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
2752  auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
2753  ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
2754 
2755  ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
2756  auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
2757  ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
2758 
2759  ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
2760  auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
2761  ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
2762 
2763  ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
2764  auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
2765  ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
2766 
2767  ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
2768  auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
2769  ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
2770 
2771  ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
2772  auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
2773  ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
2774  " RecurrentToForgetWeights");
2775 
2776  ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
2777  auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
2778  ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2779 
2780  ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
2781  auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
2782  ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
2783 
2784  // Validate data types for weights tensors (all should match each other)
2785  ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
2786 
2787  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
2788  "inputToInputWeights", "inputToForgetWeights");
2789  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
2790  "inputToInputWeights", "inputToCellWeights");
2791  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
2792  "inputToInputWeights", "inputToOutputWeights");
2793 
2794  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
2795  "inputToInputWeights", "recurrentToInputWeights");
2796  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
2797  "inputToInputWeights", "recurrentToForgeteights");
2798  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
2799  "inputToInputWeights", "recurrentToCellWeights");
2800  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
2801  "inputToInputWeights", "recurrentToOutputWeights");
2802 
2803  // Validate matching quantization info for weight tensors (all should match each other)
2804  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
2805  descriptorName, "inputToInputWeights", "inputToForgetWeights");
2806  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
2807  descriptorName, "inputToInputWeights", "inputToCellWeights");
2808  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
2809  descriptorName, "inputToInputWeights", "inputToOutputWeights");
2810 
2811  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
2812  descriptorName, "inputToInputWeights", "recurrentToInputWeights");
2813  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
2814  descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
2815  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
2816  descriptorName, "inputToInputWeights", "recurrentToCellWeights");
2817  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
2818  descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
2819 
2820  // Validate number of dimensions and number of elements in bias tensors
2821  ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
2822  auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
2823  ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
2824 
2825  ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
2826  auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
2827  ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
2828 
2829  ValidatePointer(m_CellBias, descriptorName, "CellBias");
2830  auto cellBiasInfo = m_CellBias->GetTensorInfo();
2831  ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
2832 
2833  ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
2834  auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
2835  ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
2836 
2837  // Validate data types for bias tensors (all should match each other)
2838  ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
2839 
2840  ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
2841  "inputGateBias", "forgetGateBias");
2842  ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
2843  "inputGateBias", "cellBias");
2844  ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
2845  "inputGateBias", "outputGateBias");
2846 
2847  // Validate bias tensor quantization info
2848  ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2849  ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2850  ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2851  ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
2852 }
const ConstCpuTensorHandle * m_CellBias
const ConstCpuTensorHandle * m_ForgetGateBias
const ConstCpuTensorHandle * m_InputToForgetWeights
const TensorInfo & GetTensorInfo() const
const ConstCpuTensorHandle * m_InputToInputWeights
std::vector< TensorInfo > m_OutputTensorInfos
const ConstCpuTensorHandle * m_RecurrentToInputWeights
const ConstCpuTensorHandle * m_OutputGateBias
const ConstCpuTensorHandle * m_InputGateBias
std::vector< TensorInfo > m_InputTensorInfos
const ConstCpuTensorHandle * m_RecurrentToForgetWeights
const ConstCpuTensorHandle * m_RecurrentToOutputWeights
const ConstCpuTensorHandle * m_InputToOutputWeights
const ConstCpuTensorHandle * m_RecurrentToCellWeights
const ConstCpuTensorHandle * m_InputToCellWeights

Member Data Documentation

◆ m_CellBias

const ConstCpuTensorHandle* m_CellBias

Definition at line 538 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_ForgetGateBias

const ConstCpuTensorHandle* m_ForgetGateBias

Definition at line 537 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputGateBias

const ConstCpuTensorHandle* m_InputGateBias

Definition at line 536 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToCellWeights

const ConstCpuTensorHandle* m_InputToCellWeights

Definition at line 528 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToForgetWeights

const ConstCpuTensorHandle* m_InputToForgetWeights

Definition at line 527 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToInputWeights

const ConstCpuTensorHandle* m_InputToInputWeights

Definition at line 526 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToOutputWeights

const ConstCpuTensorHandle* m_InputToOutputWeights

Definition at line 529 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_OutputGateBias

const ConstCpuTensorHandle* m_OutputGateBias

Definition at line 539 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToCellWeights

const ConstCpuTensorHandle* m_RecurrentToCellWeights

Definition at line 533 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToForgetWeights

const ConstCpuTensorHandle* m_RecurrentToForgetWeights

Definition at line 532 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToInputWeights

const ConstCpuTensorHandle* m_RecurrentToInputWeights

Definition at line 531 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToOutputWeights

const ConstCpuTensorHandle* m_RecurrentToOutputWeights

Definition at line 534 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().


The documentation for this struct was generated from the following files: