ArmNN
 22.02
LstmVisitor Class Reference

#include <ConstTensorLayerVisitor.hpp>

Inheritance diagram for LstmVisitor:
LstmVisitor derives from TestLayerVisitor, which derives from StrategyBase< NoThrowStrategy >, which derives from IStrategy. TestLstmLayerVisitor and TestQLstmLayerVisitor derive from LstmVisitor.

Public Member Functions

 LstmVisitor (const LstmInputParams &params, const char *name=nullptr)
 
- Public Member Functions inherited from TestLayerVisitor
 TestLayerVisitor (const char *name)
 
- Public Member Functions inherited from StrategyBase< NoThrowStrategy >
virtual void ExecuteStrategy (const armnn::IConnectableLayer *layer, const armnn::BaseDescriptor &descriptor, const std::vector< armnn::ConstTensor > &constants, const char *name, const armnn::LayerBindingId id=0) override
 
- Public Member Functions inherited from IStrategy
virtual void FinishStrategy ()
 

Protected Member Functions

template<typename LayerType >
void CheckInputParameters (const LayerType *layer, const LstmInputParams &inputParams)
 
- Protected Member Functions inherited from TestLayerVisitor
virtual ~TestLayerVisitor ()
 
void CheckLayerName (const char *name)
 
void CheckLayerPointer (const IConnectableLayer *layer)
 
void CheckConstTensors (const ConstTensor &expected, const ConstTensor &actual)
 
void CheckConstTensors (const ConstTensor &expected, const ConstTensorHandle &actual)
 
void CheckConstTensorPtrs (const std::string &name, const ConstTensor *expected, const ConstTensor *actual)
 
void CheckConstTensorPtrs (const std::string &name, const ConstTensor *expected, const std::shared_ptr< ConstTensorHandle > actual)
 
void CheckOptionalConstTensors (const Optional< ConstTensor > &expected, const Optional< ConstTensor > &actual)
 
- Protected Member Functions inherited from StrategyBase< NoThrowStrategy >
virtual ~StrategyBase ()
 
- Protected Member Functions inherited from IStrategy
 IStrategy ()
 
virtual ~IStrategy ()
 

Protected Attributes

LstmInputParams m_InputParams
 
- Protected Attributes inherited from StrategyBase< NoThrowStrategy >
NoThrowStrategy m_DefaultStrategy
 

Detailed Description

Definition at line 258 of file ConstTensorLayerVisitor.hpp.

Constructor & Destructor Documentation

◆ LstmVisitor()

LstmVisitor ( const LstmInputParams & params,
const char * name = nullptr 
)
inline explicit

Definition at line 261 of file ConstTensorLayerVisitor.hpp.

263  : TestLayerVisitor(name)
264  , m_InputParams(params) {}
TestLayerVisitor(const char *name)

Member Function Documentation

◆ CheckInputParameters()

template<typename LayerType >
void CheckInputParameters ( const LayerType * layer,
const LstmInputParams & inputParams 
)
protected

Definition at line 274 of file ConstTensorLayerVisitor.hpp.

References TestLayerVisitor::CheckConstTensorPtrs(), LstmInputParams::m_CellBias, LstmInputParams::m_CellLayerNormWeights, LstmInputParams::m_CellToForgetWeights, LstmInputParams::m_CellToInputWeights, LstmInputParams::m_CellToOutputWeights, LstmInputParams::m_ForgetGateBias, LstmInputParams::m_ForgetLayerNormWeights, LstmInputParams::m_InputGateBias, LstmInputParams::m_InputLayerNormWeights, LstmInputParams::m_InputToCellWeights, LstmInputParams::m_InputToForgetWeights, LstmInputParams::m_InputToInputWeights, LstmInputParams::m_InputToOutputWeights, LstmInputParams::m_OutputGateBias, LstmInputParams::m_OutputLayerNormWeights, LstmInputParams::m_ProjectionBias, LstmInputParams::m_ProjectionWeights, LstmInputParams::m_RecurrentToCellWeights, LstmInputParams::m_RecurrentToForgetWeights, LstmInputParams::m_RecurrentToInputWeights, and LstmInputParams::m_RecurrentToOutputWeights.

275 {
276  CheckConstTensorPtrs("OutputGateBias",
277  inputParams.m_OutputGateBias,
278  layer->m_BasicParameters.m_OutputGateBias);
279  CheckConstTensorPtrs("InputToForgetWeights",
280  inputParams.m_InputToForgetWeights,
281  layer->m_BasicParameters.m_InputToForgetWeights);
282  CheckConstTensorPtrs("InputToCellWeights",
283  inputParams.m_InputToCellWeights,
284  layer->m_BasicParameters.m_InputToCellWeights);
285  CheckConstTensorPtrs("InputToOutputWeights",
286  inputParams.m_InputToOutputWeights,
287  layer->m_BasicParameters.m_InputToOutputWeights);
288  CheckConstTensorPtrs("RecurrentToForgetWeights",
289  inputParams.m_RecurrentToForgetWeights,
290  layer->m_BasicParameters.m_RecurrentToForgetWeights);
291  CheckConstTensorPtrs("RecurrentToCellWeights",
292  inputParams.m_RecurrentToCellWeights,
293  layer->m_BasicParameters.m_RecurrentToCellWeights);
294  CheckConstTensorPtrs("RecurrentToOutputWeights",
295  inputParams.m_RecurrentToOutputWeights,
296  layer->m_BasicParameters.m_RecurrentToOutputWeights);
297  CheckConstTensorPtrs("ForgetGateBias",
298  inputParams.m_ForgetGateBias,
299  layer->m_BasicParameters.m_ForgetGateBias);
300  CheckConstTensorPtrs("CellBias",
301  inputParams.m_CellBias,
302  layer->m_BasicParameters.m_CellBias);
303 
304  CheckConstTensorPtrs("InputToInputWeights",
305  inputParams.m_InputToInputWeights,
306  layer->m_CifgParameters.m_InputToInputWeights);
307  CheckConstTensorPtrs("RecurrentToInputWeights",
308  inputParams.m_RecurrentToInputWeights,
309  layer->m_CifgParameters.m_RecurrentToInputWeights);
310  CheckConstTensorPtrs("InputGateBias",
311  inputParams.m_InputGateBias,
312  layer->m_CifgParameters.m_InputGateBias);
313 
314  CheckConstTensorPtrs("ProjectionBias",
315  inputParams.m_ProjectionBias,
316  layer->m_ProjectionParameters.m_ProjectionBias);
317  CheckConstTensorPtrs("ProjectionWeights",
318  inputParams.m_ProjectionWeights,
319  layer->m_ProjectionParameters.m_ProjectionWeights);
320 
321  CheckConstTensorPtrs("CellToInputWeights",
322  inputParams.m_CellToInputWeights,
323  layer->m_PeepholeParameters.m_CellToInputWeights);
324  CheckConstTensorPtrs("CellToForgetWeights",
325  inputParams.m_CellToForgetWeights,
326  layer->m_PeepholeParameters.m_CellToForgetWeights);
327  CheckConstTensorPtrs("CellToOutputWeights",
328  inputParams.m_CellToOutputWeights,
329  layer->m_PeepholeParameters.m_CellToOutputWeights);
330 
331  CheckConstTensorPtrs("InputLayerNormWeights",
332  inputParams.m_InputLayerNormWeights,
333  layer->m_LayerNormParameters.m_InputLayerNormWeights);
334  CheckConstTensorPtrs("ForgetLayerNormWeights",
335  inputParams.m_ForgetLayerNormWeights,
336  layer->m_LayerNormParameters.m_ForgetLayerNormWeights);
337  CheckConstTensorPtrs("CellLayerNormWeights",
338  inputParams.m_CellLayerNormWeights,
339  layer->m_LayerNormParameters.m_CellLayerNormWeights);
340  CheckConstTensorPtrs("OutputLayerNormWeights",
341  inputParams.m_OutputLayerNormWeights,
342  layer->m_LayerNormParameters.m_OutputLayerNormWeights);
343 }
void CheckConstTensorPtrs(const std::string &name, const ConstTensor *expected, const ConstTensor *actual)

Member Data Documentation

◆ m_InputParams

LstmInputParams m_InputParams
protected

Definition at line 270 of file ConstTensorLayerVisitor.hpp.


The documentation for this class was generated from the following file: ConstTensorLayerVisitor.hpp