ArmNN
 20.08
QuantizedLstmQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for QuantizedLstmQueueDescriptor:
QueueDescriptor

Public Member Functions

 QuantizedLstmQueueDescriptor ()
 
void Validate (const WorkloadInfo &workloadInfo) const
 
- Public Member Functions inherited from QueueDescriptor
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
 

Public Attributes

const ConstCpuTensorHandle * m_InputToInputWeights
 
const ConstCpuTensorHandle * m_InputToForgetWeights
 
const ConstCpuTensorHandle * m_InputToCellWeights
 
const ConstCpuTensorHandle * m_InputToOutputWeights
 
const ConstCpuTensorHandle * m_RecurrentToInputWeights
 
const ConstCpuTensorHandle * m_RecurrentToForgetWeights
 
const ConstCpuTensorHandle * m_RecurrentToCellWeights
 
const ConstCpuTensorHandle * m_RecurrentToOutputWeights
 
const ConstCpuTensorHandle * m_InputGateBias
 
const ConstCpuTensorHandle * m_ForgetGateBias
 
const ConstCpuTensorHandle * m_CellBias
 
const ConstCpuTensorHandle * m_OutputGateBias
 
- Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
 
std::vector< ITensorHandle * > m_Outputs
 

Additional Inherited Members

- Protected Member Functions inherited from QueueDescriptor
 ~QueueDescriptor ()=default
 
 QueueDescriptor ()=default
 
 QueueDescriptor (QueueDescriptor const &)=default
 
QueueDescriptor & operator= (QueueDescriptor const &)=default
 

Detailed Description

Definition at line 586 of file WorkloadData.hpp.

Constructor & Destructor Documentation

◆ QuantizedLstmQueueDescriptor()

Definition at line 588 of file WorkloadData.hpp.

589  : m_InputToInputWeights(nullptr)
590  , m_InputToForgetWeights(nullptr)
591  , m_InputToCellWeights(nullptr)
592  , m_InputToOutputWeights(nullptr)
593 
594  , m_RecurrentToInputWeights(nullptr)
595  , m_RecurrentToForgetWeights(nullptr)
596  , m_RecurrentToCellWeights(nullptr)
597  , m_RecurrentToOutputWeights(nullptr)
598 
599  , m_InputGateBias(nullptr)
600  , m_ForgetGateBias(nullptr)
601  , m_CellBias(nullptr)
602  , m_OutputGateBias(nullptr)
603  {}
const ConstCpuTensorHandle * m_RecurrentToForgetWeights
const ConstCpuTensorHandle * m_InputGateBias
const ConstCpuTensorHandle * m_InputToCellWeights
const ConstCpuTensorHandle * m_ForgetGateBias
const ConstCpuTensorHandle * m_RecurrentToInputWeights
const ConstCpuTensorHandle * m_RecurrentToCellWeights
const ConstCpuTensorHandle * m_RecurrentToOutputWeights
const ConstCpuTensorHandle * m_CellBias
const ConstCpuTensorHandle * m_OutputGateBias
const ConstCpuTensorHandle * m_InputToForgetWeights
const ConstCpuTensorHandle * m_InputToOutputWeights
const ConstCpuTensorHandle * m_InputToInputWeights

Member Function Documentation

◆ Validate()

void Validate (const WorkloadInfo & workloadInfo) const

Definition at line 3161 of file WorkloadData.cpp.

References WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, armnn::QAsymmU8, armnn::QSymmS16, and armnn::Signed32.

3162 {
3163  const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3164 
3165  // Validate number of inputs/outputs
3166  ValidateNumInputs(workloadInfo, descriptorName, 3);
3167  ValidateNumOutputs(workloadInfo, descriptorName, 2);
3168 
3169  // Input/output tensor infos
3170  auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3171  auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3172  auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3173 
3174  auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3175  auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3176 
3177  std::vector<DataType> inputOutputSupportedTypes =
3178  {
3179  DataType::QAsymmU8
3180  };
3181 
3182  std::vector<DataType> cellStateSupportedTypes =
3183  {
3184  DataType::QSymmS16
3185  };
3186 
3187  std::vector<DataType> weightsSupportedTypes =
3188  {
3189  DataType::QAsymmU8
3190  };
3191 
3192  std::vector<DataType> biasSupportedTypes =
3193  {
3194  DataType::Signed32
3195  };
3196 
3197  // Validate types of input/output tensors
3198  ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3199  ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3200  ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3201 
3202  ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3203  ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3204 
3205  // Validate matching types of input/output tensors
3206  ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3207  ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3208  "outputStateIn", "outputStateOut");
3209  ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3210 
3211  // Validate matching quantization info for input/output tensors
3212  ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3213  ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3214  ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3215 
3216  // Infer number of batches, input size and output size from tensor dimensions
3217  const uint32_t numBatches = inputInfo.GetShape()[0];
3218  const uint32_t inputSize = inputInfo.GetShape()[1];
3219  const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3220 
3221  // Validate number of dimensions and number of elements for input/output tensors
3222  ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3223  ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3224  ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3225  ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3226  ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3227 
3228  // Validate number of dimensions and number of elements for weights tensors
3229  ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3230  auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3231  ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3232 
3233  ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3234  auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3235  ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3236 
3237  ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3238  auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3239  ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3240 
3241  ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3242  auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3243  ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3244 
3245  ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3246  auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3247  ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3248 
3249  ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3250  auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3251  ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3252  " RecurrentToForgetWeights");
3253 
3254  ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3255  auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3256  ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3257 
3258  ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3259  auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3260  ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3261 
3262  // Validate data types for weights tensors (all should match each other)
3263  ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3264 
3265  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3266  "inputToInputWeights", "inputToForgetWeights");
3267  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3268  "inputToInputWeights", "inputToCellWeights");
3269  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3270  "inputToInputWeights", "inputToOutputWeights");
3271 
3272  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3273  "inputToInputWeights", "recurrentToInputWeights");
3274  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3275  "inputToInputWeights", "recurrentToForgeteights");
3276  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3277  "inputToInputWeights", "recurrentToCellWeights");
3278  ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3279  "inputToInputWeights", "recurrentToOutputWeights");
3280 
3281  // Validate matching quantization info for weight tensors (all should match each other)
3282  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3283  descriptorName, "inputToInputWeights", "inputToForgetWeights");
3284  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3285  descriptorName, "inputToInputWeights", "inputToCellWeights");
3286  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3287  descriptorName, "inputToInputWeights", "inputToOutputWeights");
3288 
3289  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3290  descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3291  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3292  descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3293  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3294  descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3295  ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3296  descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3297 
3298  // Validate number of dimensions and number of elements in bias tensors
3299  ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3300  auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3301  ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3302 
3303  ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3304  auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3305  ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3306 
3307  ValidatePointer(m_CellBias, descriptorName, "CellBias");
3308  auto cellBiasInfo = m_CellBias->GetTensorInfo();
3309  ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3310 
3311  ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3312  auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3313  ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3314 
3315  // Validate data types for bias tensors (all should match each other)
3316  ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3317 
3318  ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3319  "inputGateBias", "forgetGateBias");
3320  ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3321  "inputGateBias", "cellBias");
3322  ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3323  "inputGateBias", "outputGateBias");
3324 
3325  // Validate bias tensor quantization info
3326  ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3327  ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3328  ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3329  ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3330 }
const ConstCpuTensorHandle * m_RecurrentToForgetWeights
const ConstCpuTensorHandle * m_InputGateBias
const ConstCpuTensorHandle * m_InputToCellWeights
std::vector< TensorInfo > m_InputTensorInfos
const ConstCpuTensorHandle * m_ForgetGateBias
const ConstCpuTensorHandle * m_RecurrentToInputWeights
std::vector< TensorInfo > m_OutputTensorInfos
const ConstCpuTensorHandle * m_RecurrentToCellWeights
const ConstCpuTensorHandle * m_RecurrentToOutputWeights
const ConstCpuTensorHandle * m_CellBias
const ConstCpuTensorHandle * m_OutputGateBias
const ConstCpuTensorHandle * m_InputToForgetWeights
const ConstCpuTensorHandle * m_InputToOutputWeights
const ConstCpuTensorHandle * m_InputToInputWeights
const TensorInfo & GetTensorInfo() const

Member Data Documentation

◆ m_CellBias

const ConstCpuTensorHandle* m_CellBias

Definition at line 617 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_ForgetGateBias

const ConstCpuTensorHandle* m_ForgetGateBias

Definition at line 616 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputGateBias

const ConstCpuTensorHandle* m_InputGateBias

Definition at line 615 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToCellWeights

const ConstCpuTensorHandle* m_InputToCellWeights

Definition at line 607 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToForgetWeights

const ConstCpuTensorHandle* m_InputToForgetWeights

Definition at line 606 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToInputWeights

const ConstCpuTensorHandle* m_InputToInputWeights

Definition at line 605 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_InputToOutputWeights

const ConstCpuTensorHandle* m_InputToOutputWeights

Definition at line 608 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_OutputGateBias

const ConstCpuTensorHandle* m_OutputGateBias

Definition at line 618 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToCellWeights

const ConstCpuTensorHandle* m_RecurrentToCellWeights

Definition at line 612 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToForgetWeights

const ConstCpuTensorHandle* m_RecurrentToForgetWeights

Definition at line 611 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToInputWeights

const ConstCpuTensorHandle* m_RecurrentToInputWeights

Definition at line 610 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().

◆ m_RecurrentToOutputWeights

const ConstCpuTensorHandle* m_RecurrentToOutputWeights

Definition at line 613 of file WorkloadData.hpp.

Referenced by QuantizedLstmLayer::CreateWorkload().


The documentation for this struct was generated from the following files: