ArmNN
 22.05
LstmVisitor Class Reference

#include <ConstTensorLayerVisitor.hpp>

Inheritance diagram for LstmVisitor:
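IStrategy → StrategyBase< NoThrowStrategy > → TestLayerVisitor → LstmVisitor → TestLstmLayerVisitor, TestQLstmLayerVisitor (base classes on the left, derived classes on the right; TestLstmLayerVisitor and TestQLstmLayerVisitor both derive from LstmVisitor)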

Public Member Functions

 LstmVisitor (const LstmInputParams &params, const char *name=nullptr)
 
- Public Member Functions inherited from TestLayerVisitor
 TestLayerVisitor (const char *name)
 
- Public Member Functions inherited from StrategyBase< NoThrowStrategy >
virtual void ExecuteStrategy (const armnn::IConnectableLayer *layer, const armnn::BaseDescriptor &descriptor, const std::vector< armnn::ConstTensor > &constants, const char *name, const armnn::LayerBindingId id=0) override
 
- Public Member Functions inherited from IStrategy
virtual void FinishStrategy ()
 

Protected Member Functions

template<typename LayerType >
void CheckInputParameters (const LayerType *layer, const LstmInputParams &inputParams)
 
- Protected Member Functions inherited from TestLayerVisitor
virtual ~TestLayerVisitor ()
 
void CheckLayerName (const char *name)
 
void CheckLayerPointer (const IConnectableLayer *layer)
 
void CheckConstTensors (const ConstTensor &expected, const ConstTensor &actual)
 
void CheckConstTensors (const ConstTensor &expected, const ConstTensorHandle &actual)
 
void CheckConstTensorPtrs (const std::string &name, const ConstTensor *expected, const ConstTensor *actual)
 
void CheckConstTensorPtrs (const std::string &name, const ConstTensor *expected, const std::shared_ptr< ConstTensorHandle > actual)
 
void CheckOptionalConstTensors (const Optional< ConstTensor > &expected, const Optional< ConstTensor > &actual)
 
- Protected Member Functions inherited from StrategyBase< NoThrowStrategy >
virtual ~StrategyBase ()
 
- Protected Member Functions inherited from IStrategy
 IStrategy ()
 
virtual ~IStrategy ()
 

Protected Attributes

LstmInputParams m_InputParams
 
- Protected Attributes inherited from StrategyBase< NoThrowStrategy >
NoThrowStrategy m_DefaultStrategy
 

Detailed Description

Definition at line 234 of file ConstTensorLayerVisitor.hpp.

Constructor & Destructor Documentation

◆ LstmVisitor()

LstmVisitor ( const LstmInputParams & params,
const char *  name = nullptr 
)
inline, explicit

Definition at line 237 of file ConstTensorLayerVisitor.hpp.

237  explicit LstmVisitor(const LstmInputParams& params,
238  const char* name = nullptr)
239  : TestLayerVisitor(name)
240  , m_InputParams(params) {}
See also: TestLayerVisitor(const char *name)

Member Function Documentation

◆ CheckInputParameters()

template<typename LayerType >
void CheckInputParameters ( const LayerType * layer,
const LstmInputParams & inputParams 
)
protected

Definition at line 250 of file ConstTensorLayerVisitor.hpp.

References TestLayerVisitor::CheckConstTensorPtrs(), LstmInputParams::m_CellBias, LstmInputParams::m_CellLayerNormWeights, LstmInputParams::m_CellToForgetWeights, LstmInputParams::m_CellToInputWeights, LstmInputParams::m_CellToOutputWeights, LstmInputParams::m_ForgetGateBias, LstmInputParams::m_ForgetLayerNormWeights, LstmInputParams::m_InputGateBias, LstmInputParams::m_InputLayerNormWeights, LstmInputParams::m_InputToCellWeights, LstmInputParams::m_InputToForgetWeights, LstmInputParams::m_InputToInputWeights, LstmInputParams::m_InputToOutputWeights, LstmInputParams::m_OutputGateBias, LstmInputParams::m_OutputLayerNormWeights, LstmInputParams::m_ProjectionBias, LstmInputParams::m_ProjectionWeights, LstmInputParams::m_RecurrentToCellWeights, LstmInputParams::m_RecurrentToForgetWeights, LstmInputParams::m_RecurrentToInputWeights, and LstmInputParams::m_RecurrentToOutputWeights.

251 {
252  CheckConstTensorPtrs("OutputGateBias",
253  inputParams.m_OutputGateBias,
254  layer->m_BasicParameters.m_OutputGateBias);
255  CheckConstTensorPtrs("InputToForgetWeights",
256  inputParams.m_InputToForgetWeights,
257  layer->m_BasicParameters.m_InputToForgetWeights);
258  CheckConstTensorPtrs("InputToCellWeights",
259  inputParams.m_InputToCellWeights,
260  layer->m_BasicParameters.m_InputToCellWeights);
261  CheckConstTensorPtrs("InputToOutputWeights",
262  inputParams.m_InputToOutputWeights,
263  layer->m_BasicParameters.m_InputToOutputWeights);
264  CheckConstTensorPtrs("RecurrentToForgetWeights",
265  inputParams.m_RecurrentToForgetWeights,
266  layer->m_BasicParameters.m_RecurrentToForgetWeights);
267  CheckConstTensorPtrs("RecurrentToCellWeights",
268  inputParams.m_RecurrentToCellWeights,
269  layer->m_BasicParameters.m_RecurrentToCellWeights);
270  CheckConstTensorPtrs("RecurrentToOutputWeights",
271  inputParams.m_RecurrentToOutputWeights,
272  layer->m_BasicParameters.m_RecurrentToOutputWeights);
273  CheckConstTensorPtrs("ForgetGateBias",
274  inputParams.m_ForgetGateBias,
275  layer->m_BasicParameters.m_ForgetGateBias);
276  CheckConstTensorPtrs("CellBias",
277  inputParams.m_CellBias,
278  layer->m_BasicParameters.m_CellBias);
279 
280  CheckConstTensorPtrs("InputToInputWeights",
281  inputParams.m_InputToInputWeights,
282  layer->m_CifgParameters.m_InputToInputWeights);
283  CheckConstTensorPtrs("RecurrentToInputWeights",
284  inputParams.m_RecurrentToInputWeights,
285  layer->m_CifgParameters.m_RecurrentToInputWeights);
286  CheckConstTensorPtrs("InputGateBias",
287  inputParams.m_InputGateBias,
288  layer->m_CifgParameters.m_InputGateBias);
289 
290  CheckConstTensorPtrs("ProjectionBias",
291  inputParams.m_ProjectionBias,
292  layer->m_ProjectionParameters.m_ProjectionBias);
293  CheckConstTensorPtrs("ProjectionWeights",
294  inputParams.m_ProjectionWeights,
295  layer->m_ProjectionParameters.m_ProjectionWeights);
296 
297  CheckConstTensorPtrs("CellToInputWeights",
298  inputParams.m_CellToInputWeights,
299  layer->m_PeepholeParameters.m_CellToInputWeights);
300  CheckConstTensorPtrs("CellToForgetWeights",
301  inputParams.m_CellToForgetWeights,
302  layer->m_PeepholeParameters.m_CellToForgetWeights);
303  CheckConstTensorPtrs("CellToOutputWeights",
304  inputParams.m_CellToOutputWeights,
305  layer->m_PeepholeParameters.m_CellToOutputWeights);
306 
307  CheckConstTensorPtrs("InputLayerNormWeights",
308  inputParams.m_InputLayerNormWeights,
309  layer->m_LayerNormParameters.m_InputLayerNormWeights);
310  CheckConstTensorPtrs("ForgetLayerNormWeights",
311  inputParams.m_ForgetLayerNormWeights,
312  layer->m_LayerNormParameters.m_ForgetLayerNormWeights);
313  CheckConstTensorPtrs("CellLayerNormWeights",
314  inputParams.m_CellLayerNormWeights,
315  layer->m_LayerNormParameters.m_CellLayerNormWeights);
316  CheckConstTensorPtrs("OutputLayerNormWeights",
317  inputParams.m_OutputLayerNormWeights,
318  layer->m_LayerNormParameters.m_OutputLayerNormWeights);
319 }
See also: void CheckConstTensorPtrs(const std::string &name, const ConstTensor *expected, const ConstTensor *actual)

Member Data Documentation

◆ m_InputParams

LstmInputParams m_InputParams
protected

Definition at line 246 of file ConstTensorLayerVisitor.hpp.


The documentation for this class was generated from the following file: ConstTensorLayerVisitor.hpp