ArmNN
 22.02
LstmQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for LstmQueueDescriptor:
QueueDescriptorWithParameters< LstmDescriptor > QueueDescriptor

Public Member Functions

 LstmQueueDescriptor ()
 
void Validate (const WorkloadInfo &workloadInfo) const
 
- Public Member Functions inherited from QueueDescriptorWithParameters< LstmDescriptor >
virtual ~QueueDescriptorWithParameters ()=default
 
- Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
 
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
 
template<typename T >
const T * GetAdditionalInformation () const
 

Public Attributes

const ConstTensorHandle * m_InputToInputWeights
 
const ConstTensorHandle * m_InputToForgetWeights
 
const ConstTensorHandle * m_InputToCellWeights
 
const ConstTensorHandle * m_InputToOutputWeights
 
const ConstTensorHandle * m_RecurrentToInputWeights
 
const ConstTensorHandle * m_RecurrentToForgetWeights
 
const ConstTensorHandle * m_RecurrentToCellWeights
 
const ConstTensorHandle * m_RecurrentToOutputWeights
 
const ConstTensorHandle * m_CellToInputWeights
 
const ConstTensorHandle * m_CellToForgetWeights
 
const ConstTensorHandle * m_CellToOutputWeights
 
const ConstTensorHandle * m_InputGateBias
 
const ConstTensorHandle * m_ForgetGateBias
 
const ConstTensorHandle * m_CellBias
 
const ConstTensorHandle * m_OutputGateBias
 
const ConstTensorHandle * m_ProjectionWeights
 
const ConstTensorHandle * m_ProjectionBias
 
const ConstTensorHandle * m_InputLayerNormWeights
 
const ConstTensorHandle * m_ForgetLayerNormWeights
 
const ConstTensorHandle * m_CellLayerNormWeights
 
const ConstTensorHandle * m_OutputLayerNormWeights
 
- Public Attributes inherited from QueueDescriptorWithParameters< LstmDescriptor >
LstmDescriptor m_Parameters
 
- Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
 
std::vector< ITensorHandle * > m_Outputs
 
void * m_AdditionalInfoObject
 

Additional Inherited Members

- Protected Member Functions inherited from QueueDescriptorWithParameters< LstmDescriptor >
 QueueDescriptorWithParameters ()=default
 
 QueueDescriptorWithParameters (QueueDescriptorWithParameters const &)=default
 
QueueDescriptorWithParameters & operator= (QueueDescriptorWithParameters const &)=default
 
- Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 
 QueueDescriptor (QueueDescriptor const &)=default
 
QueueDescriptor & operator= (QueueDescriptor const &)=default
 

Detailed Description

Definition at line 420 of file WorkloadData.hpp.

Constructor & Destructor Documentation

◆ LstmQueueDescriptor()

LstmQueueDescriptor ( )
inline

Definition at line 422 of file WorkloadData.hpp.

423  : m_InputToInputWeights(nullptr)
424  , m_InputToForgetWeights(nullptr)
425  , m_InputToCellWeights(nullptr)
426  , m_InputToOutputWeights(nullptr)
427  , m_RecurrentToInputWeights(nullptr)
428  , m_RecurrentToForgetWeights(nullptr)
429  , m_RecurrentToCellWeights(nullptr)
430  , m_RecurrentToOutputWeights(nullptr)
431  , m_CellToInputWeights(nullptr)
432  , m_CellToForgetWeights(nullptr)
433  , m_CellToOutputWeights(nullptr)
434  , m_InputGateBias(nullptr)
435  , m_ForgetGateBias(nullptr)
436  , m_CellBias(nullptr)
437  , m_OutputGateBias(nullptr)
438  , m_ProjectionWeights(nullptr)
439  , m_ProjectionBias(nullptr)
440  , m_InputLayerNormWeights(nullptr)
441  , m_ForgetLayerNormWeights(nullptr)
442  , m_CellLayerNormWeights(nullptr)
443  , m_OutputLayerNormWeights(nullptr)
444  {
445  }
const ConstTensorHandle * m_ProjectionWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_OutputLayerNormWeights
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_CellLayerNormWeights
const ConstTensorHandle * m_CellToOutputWeights
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_RecurrentToInputWeights
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_CellToForgetWeights
const ConstTensorHandle * m_ProjectionBias
const ConstTensorHandle * m_ForgetLayerNormWeights
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_CellToInputWeights
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_RecurrentToOutputWeights

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo ) const

Definition at line 1965 of file WorkloadData.cpp.

References armnn::BFloat16, armnn::Float16, armnn::Float32, WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, and armnn::QSymmS16.

1966 {
1967  // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1968 
1969  const std::string descriptorName{"LstmQueueDescriptor"};
1970 
1971  // check dimensions of all inputs and outputs
1972  if (workloadInfo.m_InputTensorInfos.size() != 3)
1973  {
1974  throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1975  }
1976  if (workloadInfo.m_OutputTensorInfos.size() != 4)
1977  {
1978  throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1979  }
1980 
1981  std::vector<DataType> supportedTypes =
1982  {
1987  };
1988 
1989  // check for supported type of one input and match them with all the other input and output
1990  ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1991 
1992  // type matches all other inputs
1993  for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
1994  {
1995  ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1996  workloadInfo.m_InputTensorInfos[i],
1997  descriptorName,
1998  "input_0",
1999  "input_" + std::to_string(i));
2000  }
2001  // type matches all other outputs
2002  for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
2003  {
2004  ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2005  workloadInfo.m_OutputTensorInfos[i],
2006  "LstmQueueDescriptor",
2007  "input_0",
2008  "output_" + std::to_string(i));
2009  }
2010 
2011  // Making sure clipping parameters have valid values.
2012  // == 0 means no clipping
2013  // > 0 means clipping
2014  if (m_Parameters.m_ClippingThresCell < 0.0f)
2015  {
2016  throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2017  }
2018  if (m_Parameters.m_ClippingThresProj < 0.0f)
2019  {
2020  throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2021  }
2022 
2023  // Inferring batch size, number of outputs and number of cells from the inputs.
2024  const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2025  const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2026  ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2027  const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2028  ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2029  const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2030 
2031  // input tensor
2032  ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2033  descriptorName + " input_0");
2034  // outputStateInTensor
2035  ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2036  descriptorName + " input_1");
2037  // outputStateInTensor
2038  ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2039  descriptorName + " input_2");
2040  // scratchBufferTensor
2041  unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
2042  ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2043  descriptorName + " output_0");
2044  // outputStateOutTensor
2045  ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2046  descriptorName + " output_1");
2047  // cellStateOutTensor
2048  ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2049  descriptorName + " output_2");
2050  // outputTensor
2051  ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2052  descriptorName + " output_3");
2053 
2054  // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2055  if ( m_InputToInputWeights )
2056  {
2057  ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2058  (n_cell * n_input), "InputLayerNormWeights");
2059  }
2060 
2061  ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2062  ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2063  (n_cell * n_input), "InputToForgetWeights");
2064 
2065  ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2066  ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2067  (n_cell * n_input), "InputToCellWeights");
2068 
2070  {
2071  ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2072  (n_cell * n_output), "RecurrentToInputWeights");
2073  }
2074 
2075  ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2076  ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2077  (n_cell * n_output), "RecurrentToForgetWeights");
2078 
2079  ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2080  ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2081  (n_cell * n_output), "RecurrentToCellWeights");
2082 
2083  // Make sure the input-gate's parameters are either both present (regular
2084  // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2085  bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2089  if (!cifg_weights_all_or_none)
2090  {
2091  throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2092  "RecurrentToInputWeights must either both be present (regular LSTM) "
2093  "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2094  "accordingly.");
2095  }
2096 
2097  if ( m_CellToInputWeights )
2098  {
2099  ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2100  n_cell, "CellToInputWeights");
2101  }
2102  if ( m_CellToForgetWeights )
2103  {
2104  ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2105  n_cell, "CellToForgetWeights");
2106  }
2107  if ( m_CellToOutputWeights )
2108  {
2109  ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2110  n_cell, "CellToOutputWeights");
2111  }
2112 
2113  // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2114  bool peephole_weights_all_or_none =
2119  if (!peephole_weights_all_or_none)
2120  {
2121  throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
2122  }
2123 
2124  // Make sure the input gate bias is present only when not a CIFG-LSTM.
2126  {
2127  if (m_InputGateBias)
2128  {
2129  throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
2130  }
2131  }
2132  else
2133  {
2134  if (!m_InputGateBias)
2135  {
2136  throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2137  "must be present.");
2138  }
2139  ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2140  n_cell, "InputGateBias");
2141  }
2142 
2143  ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2144  ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2145 
2146  ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2147  ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2148 
2149  ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2150  ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2151 
2152  if (m_ProjectionWeights)
2153  {
2154  ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2155  (n_cell * n_output), "ProjectionWeights");
2156  }
2157  if (m_ProjectionBias)
2158  {
2159  ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2160  }
2161 
2162  // Making sure the projection tensors are consistent:
2163  // 1) If projection weight is not present, then projection bias should not be
2164  // present.
2165  // 2) If projection weight is present, then projection bias is optional.
2166  bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2172  if (!projecton_tensors_consistent)
2173  {
2174  throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
2175  }
2176 
2177  // The four layer normalization weights either all have values or none of them have values. Additionally, if
2178  // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2179  // either all have values or none of them have values. Layer normalization is used when the values of all the
2180  // layer normalization weights are present
2182  {
2183  ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2184  }
2186  {
2187  ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2188  }
2190  {
2191  ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2192  }
2194  {
2195  ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2196  }
2197 
2199  {
2201  {
2203  {
2204  throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2205  "disabled but InputLayerNormWeights are not present");
2206  }
2207  ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2208  1, n_cell, "InputLayerNormWeights");
2209  }
2210  else if (m_InputLayerNormWeights)
2211  {
2212  throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2213  "enabled");
2214  }
2215 
2216  ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2217  "ForgetLayerNormWeights");
2218  ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2219 
2220  ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2221  "OutputLayerNormWeights");
2222  ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2223 
2224  ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2225  "CellLayerNormWeights");
2226  ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2227  }
2229  {
2230  throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2231  "normalisation weights are present.");
2232  }
2233 }
bool m_ProjectionEnabled
Enable/disable the projection layer.
const ConstTensorHandle * m_ProjectionWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
float m_ClippingThresProj
Clipping threshold value for the projection.
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_OutputLayerNormWeights
TensorShape GetShape() const override
Get the number of elements for each dimension, ordered from the slowest iterating dimension to the fastest iterating dimension.
const ConstTensorHandle * m_OutputGateBias
const TensorInfo & GetTensorInfo() const
std::vector< TensorInfo > m_InputTensorInfos
const ConstTensorHandle * m_CellLayerNormWeights
bool m_PeepholeEnabled
Enable/disable peephole.
const ConstTensorHandle * m_CellToOutputWeights
std::vector< TensorInfo > m_OutputTensorInfos
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_InputToForgetWeights
float m_ClippingThresCell
Clipping threshold value for the cell state.
const ConstTensorHandle * m_RecurrentToInputWeights
const ConstTensorHandle * m_ForgetGateBias
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
const ConstTensorHandle * m_CellToForgetWeights
const ConstTensorHandle * m_ProjectionBias
bool m_LayerNormEnabled
Enable/disable layer normalization.
const ConstTensorHandle * m_ForgetLayerNormWeights
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_CellToInputWeights
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_RecurrentToOutputWeights

Member Data Documentation

◆ m_CellBias

const ConstTensorHandle* m_CellBias

Definition at line 460 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_CellLayerNormWeights

const ConstTensorHandle* m_CellLayerNormWeights

Definition at line 466 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_CellToForgetWeights

const ConstTensorHandle* m_CellToForgetWeights

Definition at line 456 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_CellToInputWeights

const ConstTensorHandle* m_CellToInputWeights

Definition at line 455 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_CellToOutputWeights

const ConstTensorHandle* m_CellToOutputWeights

Definition at line 457 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_ForgetGateBias

const ConstTensorHandle* m_ForgetGateBias

Definition at line 459 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_ForgetLayerNormWeights

const ConstTensorHandle* m_ForgetLayerNormWeights

Definition at line 465 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_InputGateBias

const ConstTensorHandle* m_InputGateBias

Definition at line 458 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_InputLayerNormWeights

const ConstTensorHandle* m_InputLayerNormWeights

Definition at line 464 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_InputToCellWeights

const ConstTensorHandle* m_InputToCellWeights

Definition at line 449 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_InputToForgetWeights

const ConstTensorHandle* m_InputToForgetWeights

Definition at line 448 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_InputToInputWeights

const ConstTensorHandle* m_InputToInputWeights

Definition at line 447 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_InputToOutputWeights

const ConstTensorHandle* m_InputToOutputWeights

Definition at line 450 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_OutputGateBias

const ConstTensorHandle* m_OutputGateBias

Definition at line 461 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_OutputLayerNormWeights

const ConstTensorHandle* m_OutputLayerNormWeights

Definition at line 467 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_ProjectionBias

const ConstTensorHandle* m_ProjectionBias

Definition at line 463 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_ProjectionWeights

const ConstTensorHandle* m_ProjectionWeights

Definition at line 462 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_RecurrentToCellWeights

const ConstTensorHandle* m_RecurrentToCellWeights

Definition at line 453 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_RecurrentToForgetWeights

const ConstTensorHandle* m_RecurrentToForgetWeights

Definition at line 452 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_RecurrentToInputWeights

const ConstTensorHandle* m_RecurrentToInputWeights

Definition at line 451 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().

◆ m_RecurrentToOutputWeights

const ConstTensorHandle* m_RecurrentToOutputWeights

Definition at line 454 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload().


The documentation for this struct was generated from the following files: