ArmNN 21.05: LstmLayer.cpp
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "LstmLayer.hpp"

#include "LayerCloneBase.hpp"

#include <armnn/LstmParams.hpp>
#include <armnn/TypesUtils.hpp>
#include <backendsCommon/TensorHandle.hpp>
#include <backendsCommon/WorkloadFactory.hpp>

namespace armnn
{

LstmLayer::LstmLayer(const LstmDescriptor& param, const char* name)
    : LayerWithParameters(3, 4, LayerType::Lstm, param, name)
{
}

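// The (3, 4) arguments above reflect the LSTM layer's connectivity: three inputs
// (input, outputStateIn, cellStateIn) and four outputs (scratchBuffer,
// outputStateOut, cellStateOut, output).
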
std::unique_ptr<IWorkload> LstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
    LstmQueueDescriptor descriptor;

    // Basic parameters
    descriptor.m_InputToForgetWeights     = m_BasicParameters.m_InputToForgetWeights.get();
    descriptor.m_InputToCellWeights       = m_BasicParameters.m_InputToCellWeights.get();
    descriptor.m_InputToOutputWeights     = m_BasicParameters.m_InputToOutputWeights.get();
    descriptor.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights.get();
    descriptor.m_RecurrentToCellWeights   = m_BasicParameters.m_RecurrentToCellWeights.get();
    descriptor.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights.get();
    descriptor.m_ForgetGateBias           = m_BasicParameters.m_ForgetGateBias.get();
    descriptor.m_CellBias                 = m_BasicParameters.m_CellBias.get();
    descriptor.m_OutputGateBias           = m_BasicParameters.m_OutputGateBias.get();

    // Cifg parameters
    if (!m_Param.m_CifgEnabled)
    {
        descriptor.m_InputToInputWeights     = m_CifgParameters.m_InputToInputWeights.get();
        descriptor.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights.get();
        descriptor.m_InputGateBias           = m_CifgParameters.m_InputGateBias.get();
    }

    // Projection parameters
    if (m_Param.m_ProjectionEnabled)
    {
        descriptor.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights.get();
        descriptor.m_ProjectionBias    = m_ProjectionParameters.m_ProjectionBias.get();
    }

    // Peephole parameters
    if (m_Param.m_PeepholeEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            descriptor.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights.get();
        }
        descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
        descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
    }

    // Layer normalisation parameters
    if (m_Param.m_LayerNormEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
        }
        descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
        descriptor.m_CellLayerNormWeights   = m_LayerNormParameters.m_CellLayerNormWeights.get();
        descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
    }

    SetAdditionalInfo(descriptor);

    return factory.CreateLstm(descriptor, PrepInfoAndDesc(descriptor));
}

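// The queue descriptor only borrows the weights: each .get() call hands the
// backend a raw pointer into the layer-owned ConstTensorHandle, and the optional
// CIFG/projection/peephole/layer-norm fields stay null unless the corresponding
// flag in the LstmDescriptor is set.
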
LstmLayer* LstmLayer::Clone(Graph& graph) const
{
    auto layer = CloneBase<LstmLayer>(graph, m_Param, GetName());

    layer->m_BasicParameters.m_InputToForgetWeights = m_BasicParameters.m_InputToForgetWeights ?
        m_BasicParameters.m_InputToForgetWeights : nullptr;
    layer->m_BasicParameters.m_InputToCellWeights = m_BasicParameters.m_InputToCellWeights ?
        m_BasicParameters.m_InputToCellWeights : nullptr;
    layer->m_BasicParameters.m_InputToOutputWeights = m_BasicParameters.m_InputToOutputWeights ?
        m_BasicParameters.m_InputToOutputWeights : nullptr;
    layer->m_BasicParameters.m_RecurrentToForgetWeights = m_BasicParameters.m_RecurrentToForgetWeights ?
        m_BasicParameters.m_RecurrentToForgetWeights : nullptr;
    layer->m_BasicParameters.m_RecurrentToCellWeights = m_BasicParameters.m_RecurrentToCellWeights ?
        m_BasicParameters.m_RecurrentToCellWeights : nullptr;
    layer->m_BasicParameters.m_RecurrentToOutputWeights = m_BasicParameters.m_RecurrentToOutputWeights ?
        m_BasicParameters.m_RecurrentToOutputWeights : nullptr;
    layer->m_BasicParameters.m_ForgetGateBias = m_BasicParameters.m_ForgetGateBias ?
        m_BasicParameters.m_ForgetGateBias : nullptr;
    layer->m_BasicParameters.m_CellBias = m_BasicParameters.m_CellBias ?
        m_BasicParameters.m_CellBias : nullptr;
    layer->m_BasicParameters.m_OutputGateBias = m_BasicParameters.m_OutputGateBias ?
        m_BasicParameters.m_OutputGateBias : nullptr;

    if (!m_Param.m_CifgEnabled)
    {
        layer->m_CifgParameters.m_InputToInputWeights = m_CifgParameters.m_InputToInputWeights ?
            m_CifgParameters.m_InputToInputWeights : nullptr;
        layer->m_CifgParameters.m_RecurrentToInputWeights = m_CifgParameters.m_RecurrentToInputWeights ?
            m_CifgParameters.m_RecurrentToInputWeights : nullptr;
        layer->m_CifgParameters.m_InputGateBias = m_CifgParameters.m_InputGateBias ?
            m_CifgParameters.m_InputGateBias : nullptr;
    }

    if (m_Param.m_ProjectionEnabled)
    {
        layer->m_ProjectionParameters.m_ProjectionWeights = m_ProjectionParameters.m_ProjectionWeights ?
            m_ProjectionParameters.m_ProjectionWeights : nullptr;
        layer->m_ProjectionParameters.m_ProjectionBias = m_ProjectionParameters.m_ProjectionBias ?
            m_ProjectionParameters.m_ProjectionBias : nullptr;
    }

    if (m_Param.m_PeepholeEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            layer->m_PeepholeParameters.m_CellToInputWeights = m_PeepholeParameters.m_CellToInputWeights ?
                m_PeepholeParameters.m_CellToInputWeights : nullptr;
        }
        layer->m_PeepholeParameters.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights ?
            m_PeepholeParameters.m_CellToForgetWeights : nullptr;
        layer->m_PeepholeParameters.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights ?
            m_PeepholeParameters.m_CellToOutputWeights : nullptr;
    }

    if (m_Param.m_LayerNormEnabled)
    {
        layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
            m_LayerNormParameters.m_InputLayerNormWeights : nullptr;
        layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
            m_LayerNormParameters.m_ForgetLayerNormWeights : nullptr;
        layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
            m_LayerNormParameters.m_CellLayerNormWeights : nullptr;
        layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
            m_LayerNormParameters.m_OutputLayerNormWeights : nullptr;
    }

    return std::move(layer);
}

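// Note that cloning copies the std::shared_ptr handles rather than the tensor
// data, so the cloned layer shares the original layer's constant weight storage.
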
std::vector<TensorShape> LstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
    ARMNN_ASSERT(inputShapes.size() == 3);

    // Get input values for validation
    unsigned int batchSize = inputShapes[0][0];
    unsigned int outputSize = inputShapes[1][1];
    unsigned int numUnits = inputShapes[2][1];

    std::vector<TensorShape> outShapes;
    outShapes.push_back(TensorShape({batchSize, numUnits * (m_Param.m_CifgEnabled ? 3 : 4)}));
    outShapes.push_back(TensorShape({batchSize, outputSize}));
    outShapes.push_back(TensorShape({batchSize, numUnits}));
    outShapes.push_back(TensorShape({batchSize, outputSize}));

    return outShapes;
}

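// Worked example for the shapes above: with batchSize = 2, outputSize = 16 and
// numUnits = 20 the four outputs are inferred as
//   scratchBuffer  : {2, 80}   ({2, 60} when CIFG is enabled)
//   outputStateOut : {2, 16}
//   cellStateOut   : {2, 20}
//   output         : {2, 16}
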
void LstmLayer::ValidateTensorShapesFromInputs()
{
    VerifyLayerConnections(3, CHECK_LOCATION());

    const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();

    VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);

    auto inferredShapes = InferOutputShapes( {
        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
        GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()
    });

    ARMNN_ASSERT(inferredShapes.size() == 4);

    // Check if the weights are nullptr
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToForgetWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_InputToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToCellWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_InputToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_InputToOutputWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_InputToOutputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToForgetWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_RecurrentToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToCellWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_RecurrentToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_RecurrentToOutputWeights != nullptr,
                     "LstmLayer: m_BasicParameters.m_RecurrentToOutputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_ForgetGateBias != nullptr,
                     "LstmLayer: m_BasicParameters.m_ForgetGateBias should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_CellBias != nullptr,
                     "LstmLayer: m_BasicParameters.m_CellBias should not be null.");
    ARMNN_ASSERT_MSG(m_BasicParameters.m_OutputGateBias != nullptr,
                     "LstmLayer: m_BasicParameters.m_OutputGateBias should not be null.");

    if (!m_Param.m_CifgEnabled)
    {
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights != nullptr,
                         "LstmLayer: m_CifgParameters.m_InputToInputWeights should not be null.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights != nullptr,
                         "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not be null.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias != nullptr,
                         "LstmLayer: m_CifgParameters.m_InputGateBias should not be null.");

        ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "LstmLayer");
    }
    else
    {
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputToInputWeights == nullptr,
                         "LstmLayer: m_CifgParameters.m_InputToInputWeights should not have a value when CIFG is enabled.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_RecurrentToInputWeights == nullptr,
                         "LstmLayer: m_CifgParameters.m_RecurrentToInputWeights should not have a value when CIFG is enabled.");
        ARMNN_ASSERT_MSG(m_CifgParameters.m_InputGateBias == nullptr,
                         "LstmLayer: m_CifgParameters.m_InputGateBias should not have a value when CIFG is enabled.");

        ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "LstmLayer");
    }

    if (m_Param.m_ProjectionEnabled)
    {
        ARMNN_ASSERT_MSG(m_ProjectionParameters.m_ProjectionWeights != nullptr,
                         "LstmLayer: m_ProjectionParameters.m_ProjectionWeights should not be null.");
    }

    if (m_Param.m_PeepholeEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToInputWeights != nullptr,
                             "LstmLayer: m_PeepholeParameters.m_CellToInputWeights should not be null "
                             "when Peephole is enabled and CIFG is disabled.");
        }
        ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToForgetWeights != nullptr,
                         "LstmLayer: m_PeepholeParameters.m_CellToForgetWeights should not be null.");
        ARMNN_ASSERT_MSG(m_PeepholeParameters.m_CellToOutputWeights != nullptr,
                         "LstmLayer: m_PeepholeParameters.m_CellToOutputWeights should not be null.");
    }

    ValidateAndCopyShape(
        GetOutputSlot(1).GetTensorInfo().GetShape(), inferredShapes[1], m_ShapeInferenceMethod, "LstmLayer", 1);
    ValidateAndCopyShape(
        GetOutputSlot(2).GetTensorInfo().GetShape(), inferredShapes[2], m_ShapeInferenceMethod, "LstmLayer", 2);
    ValidateAndCopyShape(
        GetOutputSlot(3).GetTensorInfo().GetShape(), inferredShapes[3], m_ShapeInferenceMethod, "LstmLayer", 3);

    if (m_Param.m_LayerNormEnabled)
    {
        if (!m_Param.m_CifgEnabled)
        {
            ARMNN_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
                             "LstmLayer: m_LayerNormParameters.m_inputLayerNormWeights should not be null.");
        }
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
                         "LstmLayer: m_LayerNormParameters.m_forgetLayerNormWeights should not be null.");
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
                         "LstmLayer: m_LayerNormParameters.m_cellLayerNormWeights should not be null.");
        ARMNN_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
                         "LstmLayer: m_LayerNormParameters.m_outputLayerNormWeights should not be null.");
    }
}

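// The shape checks above compare every output slot against the shapes inferred
// from the three inputs; with ShapeInferenceMethod::InferAndValidate the
// inferred shape can also be copied back onto the output slot by
// ValidateAndCopyShape.
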
Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef()
{
    return {m_BasicParameters.m_InputToForgetWeights,
            m_BasicParameters.m_InputToCellWeights,
            m_BasicParameters.m_InputToOutputWeights,
            m_BasicParameters.m_RecurrentToForgetWeights,
            m_BasicParameters.m_RecurrentToCellWeights,
            m_BasicParameters.m_RecurrentToOutputWeights,
            m_BasicParameters.m_ForgetGateBias,
            m_BasicParameters.m_CellBias,
            m_BasicParameters.m_OutputGateBias,
            // Cifg parameters
            m_CifgParameters.m_InputToInputWeights,
            m_CifgParameters.m_RecurrentToInputWeights,
            m_CifgParameters.m_InputGateBias,
            // Projection parameters
            m_ProjectionParameters.m_ProjectionWeights,
            m_ProjectionParameters.m_ProjectionBias,
            // Peephole parameters
            m_PeepholeParameters.m_CellToInputWeights,
            m_PeepholeParameters.m_CellToForgetWeights,
            m_PeepholeParameters.m_CellToOutputWeights,
            // Layer normalisation parameters
            m_LayerNormParameters.m_InputLayerNormWeights,
            m_LayerNormParameters.m_ForgetLayerNormWeights,
            m_LayerNormParameters.m_CellLayerNormWeights,
            m_LayerNormParameters.m_OutputLayerNormWeights};
}

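// The references returned above are used by graph optimizations (for example the
// Fp32 to Fp16 constant conversion) to inspect and, where necessary, replace the
// layer's weight handles in place.
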
void LstmLayer::Accept(ILayerVisitor& visitor) const
{
    LstmInputParams inputParams;
    ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
    ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
    ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
    ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
    ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
    ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
    ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
    ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
    ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);

    // Cifg parameters
    ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
    ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
    ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);

    // Projection parameters
    ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
    ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);

    // Peephole parameters
    ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
    ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
    ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);

    // Layer normalisation parameters
    ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
    ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
    ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
    ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);

    ConstTensor inputToInputWeightsTensor;
    if (m_CifgParameters.m_InputToInputWeights != nullptr)
    {
        ConstTensor inputToInputWeightsTensorCopy(managedInputToInputWeights.GetTensorInfo(),
                                                  managedInputToInputWeights.Map());
        inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
        inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
    }
    ConstTensor inputToForgetWeightsTensor;
    if (m_BasicParameters.m_InputToForgetWeights != nullptr)
    {
        ConstTensor inputToForgetWeightsTensorCopy(managedInputToForgetWeights.GetTensorInfo(),
                                                   managedInputToForgetWeights.Map());
        inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
        inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
    }
    ConstTensor inputToCellWeightsTensor;
    if (m_BasicParameters.m_InputToCellWeights != nullptr)
    {
        ConstTensor inputToCellWeightsTensorCopy(managedInputToCellWeights.GetTensorInfo(),
                                                 managedInputToCellWeights.Map());
        inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
        inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
    }
    ConstTensor inputToOutputWeightsTensor;
    if (m_BasicParameters.m_InputToOutputWeights != nullptr)
    {
        ConstTensor inputToOutputWeightsTensorCopy(managedInputToOutputWeights.GetTensorInfo(),
                                                   managedInputToOutputWeights.Map());
        inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
        inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
    }
    ConstTensor recurrentToInputWeightsTensor;
    if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
    {
        ConstTensor recurrentToInputWeightsTensorCopy(
            managedRecurrentToInputWeights.GetTensorInfo(),
            managedRecurrentToInputWeights.Map());
        recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
        inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
    }
    ConstTensor recurrentToForgetWeightsTensor;
    if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
    {
        ConstTensor recurrentToForgetWeightsTensorCopy(
            managedRecurrentToForgetWeights.GetTensorInfo(),
            managedRecurrentToForgetWeights.Map());
        recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
        inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
    }
    ConstTensor recurrentToCellWeightsTensor;
    if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
    {
        ConstTensor recurrentToCellWeightsTensorCopy(
            managedRecurrentToCellWeights.GetTensorInfo(),
            managedRecurrentToCellWeights.Map());
        recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
        inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
    }
    ConstTensor recurrentToOutputWeightsTensor;
    if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
    {
        ConstTensor recurrentToOutputWeightsTensorCopy(
            managedRecurrentToOutputWeights.GetTensorInfo(),
            managedRecurrentToOutputWeights.Map());
        recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
        inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
    }
    ConstTensor cellToInputWeightsTensor;
    if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
    {
        ConstTensor cellToInputWeightsTensorCopy(managedCellToInputWeights.GetTensorInfo(),
                                                 managedCellToInputWeights.Map());
        cellToInputWeightsTensor = cellToInputWeightsTensorCopy;
        inputParams.m_CellToInputWeights = &cellToInputWeightsTensor;
    }
    ConstTensor cellToForgetWeightsTensor;
    if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
    {
        ConstTensor cellToForgetWeightsTensorCopy(managedCellToForgetWeights.GetTensorInfo(),
                                                  managedCellToForgetWeights.Map());
        cellToForgetWeightsTensor = cellToForgetWeightsTensorCopy;
        inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor;
    }
    ConstTensor cellToOutputWeightsTensor;
    if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
    {
        ConstTensor cellToOutputWeightsTensorCopy(managedCellToOutputWeights.GetTensorInfo(),
                                                  managedCellToOutputWeights.Map());
        cellToOutputWeightsTensor = cellToOutputWeightsTensorCopy;
        inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor;
    }
    ConstTensor inputGateBiasTensor;
    if (m_CifgParameters.m_InputGateBias != nullptr)
    {
        ConstTensor inputGateBiasTensorCopy(managedInputGateBias.GetTensorInfo(),
                                            managedInputGateBias.Map());
        inputGateBiasTensor = inputGateBiasTensorCopy;
        inputParams.m_InputGateBias = &inputGateBiasTensor;
    }
    ConstTensor forgetGateBiasTensor;
    if (m_BasicParameters.m_ForgetGateBias != nullptr)
    {
        ConstTensor forgetGateBiasTensorCopy(managedForgetGateBias.GetTensorInfo(),
                                             managedForgetGateBias.Map());
        forgetGateBiasTensor = forgetGateBiasTensorCopy;
        inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
    }
    ConstTensor cellBiasTensor;
    if (m_BasicParameters.m_CellBias != nullptr)
    {
        ConstTensor cellBiasTensorCopy(managedCellBias.GetTensorInfo(),
                                       managedCellBias.Map());
        cellBiasTensor = cellBiasTensorCopy;
        inputParams.m_CellBias = &cellBiasTensor;
    }
    ConstTensor outputGateBias;
    if (m_BasicParameters.m_OutputGateBias != nullptr)
    {
        ConstTensor outputGateBiasCopy(managedOutputGateBias.GetTensorInfo(),
                                       managedOutputGateBias.Map());
        outputGateBias = outputGateBiasCopy;
        inputParams.m_OutputGateBias = &outputGateBias;
    }
    ConstTensor projectionWeightsTensor;
    if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
    {
        ConstTensor projectionWeightsTensorCopy(managedProjectionWeights.GetTensorInfo(),
                                                managedProjectionWeights.Map());
        projectionWeightsTensor = projectionWeightsTensorCopy;
        inputParams.m_ProjectionWeights = &projectionWeightsTensor;
    }
    ConstTensor projectionBiasTensor;
    if (m_ProjectionParameters.m_ProjectionBias != nullptr)
    {
        ConstTensor projectionBiasTensorCopy(managedProjectionBias.GetTensorInfo(),
                                             managedProjectionBias.Map());
        projectionBiasTensor = projectionBiasTensorCopy;
        inputParams.m_ProjectionBias = &projectionBiasTensor;
    }
    ConstTensor inputLayerNormTensor;
    if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
    {
        ConstTensor inputLayerNormTensorCopy(managedInputLayerNormWeights.GetTensorInfo(),
                                             managedInputLayerNormWeights.Map());
        inputLayerNormTensor = inputLayerNormTensorCopy;
        inputParams.m_InputLayerNormWeights = &inputLayerNormTensor;
    }
    ConstTensor forgetLayerNormTensor;
    if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
    {
        ConstTensor forgetLayerNormTensorCopy(managedForgetLayerNormWeights.GetTensorInfo(),
                                              managedForgetLayerNormWeights.Map());
        forgetLayerNormTensor = forgetLayerNormTensorCopy;
        inputParams.m_ForgetLayerNormWeights = &forgetLayerNormTensor;
    }
    ConstTensor cellLayerNormTensor;
    if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
    {
        ConstTensor cellLayerNormTensorCopy(managedCellLayerNormWeights.GetTensorInfo(),
                                            managedCellLayerNormWeights.Map());
        cellLayerNormTensor = cellLayerNormTensorCopy;
        inputParams.m_CellLayerNormWeights = &cellLayerNormTensor;
    }
    ConstTensor outputLayerNormTensor;
    if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
    {
        ConstTensor outputLayerNormTensorCopy(managedOutputLayerNormWeights.GetTensorInfo(),
                                              managedOutputLayerNormWeights.Map());
        outputLayerNormTensor = outputLayerNormTensorCopy;
        inputParams.m_OutputLayerNormWeights = &outputLayerNormTensor;
    }

    visitor.VisitLstmLayer(this, GetParameters(), inputParams, GetName());
}

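// Accept maps every available weight handle into a temporary ConstTensor so the
// visitor sees plain, CPU-accessible constant data; the managed handle wrappers
// unmap the underlying memory again when they go out of scope at the end of the
// function.
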
void LstmLayer::ExecuteStrategy(IStrategy& strategy) const
{
    std::vector<ConstTensor> constTensors;

    LstmDescriptor descriptor = GetParameters();

    ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
    ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
    ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
    ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
    ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
    ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
    ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
    ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
    ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);

    // Cifg parameters
    ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
    ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
    ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);

    // Projection parameters
    ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
    ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);

    // Peephole parameters
    ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
    ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
    ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);

    // Layer normalisation parameters
    ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
    ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
    ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
    ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);

    // First add mandatory/basic parameters
    if (m_BasicParameters.m_InputToForgetWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
                                              managedInputToForgetWeights.Map()));
    }
    if (m_BasicParameters.m_InputToCellWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
                                              managedInputToCellWeights.Map()));
    }
    if (m_BasicParameters.m_InputToOutputWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
                                              managedInputToOutputWeights.Map()));
    }
    if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(
            managedRecurrentToForgetWeights.GetTensorInfo(),
            managedRecurrentToForgetWeights.Map()));
    }
    if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(
            managedRecurrentToCellWeights.GetTensorInfo(),
            managedRecurrentToCellWeights.Map()));
    }
    if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(
            managedRecurrentToOutputWeights.GetTensorInfo(),
            managedRecurrentToOutputWeights.Map()));
    }
    if (m_BasicParameters.m_ForgetGateBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
                                              managedForgetGateBias.Map()));
    }
    if (m_BasicParameters.m_CellBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
                                              managedCellBias.Map()));
    }
    if (m_BasicParameters.m_OutputGateBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
                                              managedOutputGateBias.Map()));
    }

    // Add cifg parameters
    if (!descriptor.m_CifgEnabled)
    {
        if (m_CifgParameters.m_InputToInputWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
                                                  managedInputToInputWeights.Map()));
        }
        if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(
                managedRecurrentToInputWeights.GetTensorInfo(),
                managedRecurrentToInputWeights.Map()));
        }
        if (m_CifgParameters.m_InputGateBias != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
                                                  managedInputGateBias.Map()));
        }
    }

    // Add peephole parameters
    if (descriptor.m_PeepholeEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
            {
                constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
                                                      managedCellToInputWeights.Map()));
            }
        }
        if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
                                                  managedCellToForgetWeights.Map()));
        }
        if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
                                                  managedCellToOutputWeights.Map()));
        }
    }

    // Add projection parameters
    if (descriptor.m_ProjectionEnabled)
    {
        if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
                                                  managedProjectionWeights.Map()));
        }
        if (m_ProjectionParameters.m_ProjectionBias != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
                                                  managedProjectionBias.Map()));
        }
    }

    // Add norm parameters
    if (descriptor.m_LayerNormEnabled)
    {
        if (!descriptor.m_CifgEnabled)
        {
            if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
            {
                constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
                                                      managedInputLayerNormWeights.Map()));
            }
        }
        if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
                                                  managedForgetLayerNormWeights.Map()));
        }
        if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
                                                  managedCellLayerNormWeights.Map()));
        }
        if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
        {
            constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
                                                  managedOutputLayerNormWeights.Map()));
        }
    }

    strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
}

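// ExecuteStrategy passes the constant tensors as a flat vector: only tensors
// that are actually present are appended, in the fixed order basic, CIFG,
// peephole, projection and layer normalisation, so a strategy typically
// interprets the vector together with the descriptor flags.
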
} // namespace armnn