aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/Lstm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/Lstm.hpp')
-rw-r--r--delegate/src/Lstm.hpp82
1 files changed, 33 insertions, 49 deletions
diff --git a/delegate/src/Lstm.hpp b/delegate/src/Lstm.hpp
index 829e3bf9c6..8d719ee351 100644
--- a/delegate/src/Lstm.hpp
+++ b/delegate/src/Lstm.hpp
@@ -19,22 +19,6 @@
namespace armnnDelegate
{
-bool IsOptional(TfLiteNode* tfLiteNode, const int index)
-{
- if (tfLiteNode->inputs->data[index] < 0) {
- return true;
- }
- return false;
-
-}
-
-armnn::ConstTensor* CreateConstTensor(const TfLiteTensor* tfLiteTensors, TfLiteNode* tfLiteNode, int index)
-{
- const TfLiteTensor &tfLiteTensor = tfLiteTensors[tfLiteNode->inputs->data[index]];
- armnn::TensorInfo tensorInfo = GetTensorInfoForTfLiteTensor(tfLiteTensor);
- return new armnn::ConstTensor(tensorInfo, tfLiteTensor.data.data);
-}
-
TfLiteStatus VisitLstmOperator(DelegateData& delegateData,
TfLiteContext* tfLiteContext,
TfLiteNode* tfLiteNode,
@@ -68,60 +52,60 @@ TfLiteStatus VisitLstmOperator(DelegateData& delegateData,
// Set the params structure for the AddLstmLayer call
armnn::LstmInputParams params;
- if (!IsOptional(tfLiteNode, 1))
+ if (!IsOptionalOperandPresent(tfLiteNode, 1))
{
- params.m_InputToInputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 1);
+ params.m_InputToInputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 1);
}
- params.m_InputToForgetWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 2);
- params.m_InputToCellWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 3);
- params.m_InputToOutputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 4);
+ params.m_InputToForgetWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 2);
+ params.m_InputToCellWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 3);
+ params.m_InputToOutputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 4);
// Recurrent weight tensors of size {n_cell, n_output}
- if (!IsOptional(tfLiteNode, 5))
+ if (!IsOptionalOperandPresent(tfLiteNode, 5))
{
- params.m_RecurrentToInputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 5);
+ params.m_RecurrentToInputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 5);
}
- params.m_RecurrentToForgetWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 6);
- params.m_RecurrentToCellWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 7);
- params.m_RecurrentToOutputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 8);
+ params.m_RecurrentToForgetWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 6);
+ params.m_RecurrentToCellWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 7);
+ params.m_RecurrentToOutputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 8);
// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
- if (!IsOptional(tfLiteNode, 9))
+ if (!IsOptionalOperandPresent(tfLiteNode, 9))
{
- params.m_CellToInputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 9);
+ params.m_CellToInputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 9);
}
- if (!IsOptional(tfLiteNode, 10))
+ if (!IsOptionalOperandPresent(tfLiteNode, 10))
{
- params.m_CellToForgetWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 10);
+ params.m_CellToForgetWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 10);
}
- if (!IsOptional(tfLiteNode, 11))
+ if (!IsOptionalOperandPresent(tfLiteNode, 11))
{
- params.m_CellToOutputWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 11);
+ params.m_CellToOutputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 11);
}
// Gates bias tensors of size {n_cell}
- if (!IsOptional(tfLiteNode, 12))
+ if (!IsOptionalOperandPresent(tfLiteNode, 12))
{
- params.m_InputGateBias = CreateConstTensor(tfLiteTensors, tfLiteNode, 12);
+ params.m_InputGateBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 12);
}
- params.m_ForgetGateBias = CreateConstTensor(tfLiteTensors, tfLiteNode, 13);
- params.m_CellBias = CreateConstTensor(tfLiteTensors, tfLiteNode, 14);
- params.m_OutputGateBias = CreateConstTensor(tfLiteTensors, tfLiteNode, 15);
+ params.m_ForgetGateBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 13);
+ params.m_CellBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 14);
+ params.m_OutputGateBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 15);
// Projection weight tensor of size {n_output, n_cell}
- if (!IsOptional(tfLiteNode, 16))
+ if (!IsOptionalOperandPresent(tfLiteNode, 16))
{
- params.m_ProjectionWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 16);
+ params.m_ProjectionWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 16);
}
// Projection bias tensor of size {n_output}
- if (!IsOptional(tfLiteNode, 17))
+ if (!IsOptionalOperandPresent(tfLiteNode, 17))
{
- params.m_ProjectionBias = CreateConstTensor(tfLiteTensors, tfLiteNode, 17);
+ params.m_ProjectionBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 17);
}
// These state tensors are defined as variable tensors, and will be modified by this op.
@@ -129,24 +113,24 @@ TfLiteStatus VisitLstmOperator(DelegateData& delegateData,
armnn::TensorInfo cellStateInInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->inputs->data[19]]);
// Layer norm coefficient tensors of size {n_cell}, representing a diagonal matrix.
- if (tfLiteNode->inputs->size >= 21 && !IsOptional(tfLiteNode, 20))
+ if (tfLiteNode->inputs->size >= 21 && !IsOptionalOperandPresent(tfLiteNode, 20))
{
- params.m_InputLayerNormWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 20);
+ params.m_InputLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 20);
}
- if (tfLiteNode->inputs->size >= 22 && !IsOptional(tfLiteNode, 21))
+ if (tfLiteNode->inputs->size >= 22 && !IsOptionalOperandPresent(tfLiteNode, 21))
{
- params.m_ForgetLayerNormWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 21);
+ params.m_ForgetLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 21);
}
- if (tfLiteNode->inputs->size >= 23 && !IsOptional(tfLiteNode, 22))
+ if (tfLiteNode->inputs->size >= 23 && !IsOptionalOperandPresent(tfLiteNode, 22))
{
- params.m_CellLayerNormWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 22);
+ params.m_CellLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 22);
}
- if (tfLiteNode->inputs->size >= 24 && !IsOptional(tfLiteNode, 23))
+ if (tfLiteNode->inputs->size >= 24 && !IsOptionalOperandPresent(tfLiteNode, 23))
{
- params.m_OutputLayerNormWeights = CreateConstTensor(tfLiteTensors, tfLiteNode, 23);
+ params.m_OutputLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 23);
}
// set the layer descriptor