aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-08 09:57:55 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-10 09:15:04 +0000
commitad5293a86e315049de36afd723dcd1a7e70681a7 (patch)
treeb9003cd1fba00c267a971d899284b3fcbd5ce6f5
parent8b797a84f1e8f9d1d5d064afbc4fc12c21b8ffed (diff)
downloadarmnn-ad5293a86e315049de36afd723dcd1a7e70681a7.tar.gz
IVGCVSW-3337 Add Neon backend support for LSTM layer normalisation
* Update neon lstm workload * Add unit tests * Add isLstmSupported Change-Id: I493c159137f6544b0f2532d16d4fafd7a7e587e5 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
-rw-r--r--src/backends/neon/NeonLayerSupport.cpp25
-rw-r--r--src/backends/neon/NeonLayerSupport.hpp11
-rw-r--r--src/backends/neon/test/NeonCreateWorkloadTests.cpp23
-rw-r--r--src/backends/neon/test/NeonLayerTests.cpp2
-rw-r--r--src/backends/neon/workloads/NeonLstmFloatWorkload.cpp148
-rw-r--r--src/backends/neon/workloads/NeonLstmFloatWorkload.hpp22
6 files changed, 164 insertions, 67 deletions
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index 4fee53f51f..ea875f6926 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -26,6 +26,7 @@
#include "workloads/NeonDequantizeWorkload.hpp"
#include "workloads/NeonGreaterWorkload.hpp"
#include "workloads/NeonL2NormalizationFloatWorkload.hpp"
+#include "workloads/NeonLstmFloatWorkload.hpp"
#include "workloads/NeonMaximumWorkload.hpp"
#include "workloads/NeonMeanWorkload.hpp"
#include "workloads/NeonConcatWorkload.hpp"
@@ -334,6 +335,30 @@ bool NeonLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
FORWARD_WORKLOAD_VALIDATE_FUNC(NeonL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}
+bool NeonLayerSupport::IsLstmSupported(const TensorInfo& input,
+ const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateIn,
+ const TensorInfo& scratchBuffer,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
+ const TensorInfo& output,
+ const LstmDescriptor& descriptor,
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ FORWARD_WORKLOAD_VALIDATE_FUNC(NeonLstmFloatWorkloadValidate,
+ reasonIfUnsupported,
+ input,
+ outputStateIn,
+ cellStateIn,
+ scratchBuffer,
+ outputStateOut,
+ cellStateOut,
+ output,
+ descriptor,
+ paramsInfo);
+}
+
bool NeonLayerSupport::IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp
index 315248c79d..318cad7424 100644
--- a/src/backends/neon/NeonLayerSupport.hpp
+++ b/src/backends/neon/NeonLayerSupport.hpp
@@ -96,6 +96,17 @@ public:
const L2NormalizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsLstmSupported(const TensorInfo& input,
+ const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateIn,
+ const TensorInfo& scratchBuffer,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
+ const TensorInfo& output,
+ const LstmDescriptor& descriptor,
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsMaximumSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
index 4968d0ed90..49c5a72a90 100644
--- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp
+++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp
@@ -710,6 +710,29 @@ BOOST_AUTO_TEST_CASE(CreateL2NormalizationNhwcWorkload)
NeonCreateL2NormalizationWorkloadTest<NeonL2NormalizationFloatWorkload, DataType::Float32>(DataLayout::NHWC);
}
+template <typename LstmWorkloadType>
+static void NeonCreateLstmWorkloadTest()
+{
+ Graph graph;
+ NeonWorkloadFactory factory =
+ NeonWorkloadFactoryHelper::GetFactory(NeonWorkloadFactoryHelper::GetMemoryManager());
+
+ auto workload = CreateLstmWorkloadTest<LstmWorkloadType>(factory, graph);
+
+ LstmQueueDescriptor queueDescriptor = workload->GetData();
+
+ auto inputHandle = boost::polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Inputs[0]);
+ auto outputHandle = boost::polymorphic_downcast<IAclTensorHandle*>(queueDescriptor.m_Outputs[1]);
+
+ BOOST_TEST(TestNeonTensorHandleInfo(inputHandle, TensorInfo({ 2, 2 }, DataType::Float32)));
+ BOOST_TEST(TestNeonTensorHandleInfo(outputHandle, TensorInfo({ 2, 4 }, DataType::Float32)));
+}
+
+BOOST_AUTO_TEST_CASE(CreateLSTMWorkloadFloatWorkload)
+{
+ NeonCreateLstmWorkloadTest<NeonLstmFloatWorkload>();
+}
+
template <typename ConcatWorkloadType, armnn::DataType DataType>
static void NeonCreateConcatWorkloadTest(std::initializer_list<unsigned int> outputShape,
unsigned int concatAxis)
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index 51fd219365..049680aafe 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -469,6 +469,8 @@ ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgNoPeepholeNoProjection,
LstmLayerFloat32NoCifgNoPeepholeNoProjectionTest)
ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjection,
LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest)
+ARMNN_AUTO_TEST_CASE(LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNorm,
+ LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest)
// Mean
ARMNN_AUTO_TEST_CASE(MeanSimpleFloat32, MeanSimpleTest<armnn::DataType::Float32>)
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
index c7f5f090ce..6dd9f4f698 100644
--- a/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.cpp
@@ -97,6 +97,30 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
lstm_param.set_peephole_params(m_CellToForgetWeightsTensor.get(), m_CellToOutputWeightsTensor.get());
}
+ if (m_Data.m_Parameters.m_LayerNormEnabled)
+ {
+ m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ if (!m_Data.m_Parameters.m_CifgEnabled)
+ {
+ BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
+ }
+
+ m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
+
+ m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
+
+ m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::Tensor>();
+ BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
+
+ lstm_param.set_layer_normalization_params(m_Data.m_Parameters.m_CifgEnabled ?
+ nullptr : m_InputLayerNormWeightsTensor.get(),
+ m_ForgetLayerNormWeightsTensor.get(),
+ m_CellLayerNormWeightsTensor.get(),
+ m_OutputLayerNormWeightsTensor.get());
+ }
+
const arm_compute::ITensor& input = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
const arm_compute::ITensor& output_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
const arm_compute::ITensor& cell_state_in = static_cast<IAclTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
@@ -113,13 +137,13 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
m_ScratchBuffer = std::make_unique<arm_compute::Tensor>();
if (m_Data.m_Parameters.m_CifgEnabled)
{
- // 2D tensor with dimensions [num_units * 4, batch_size] with CIFG
+ // 2D tensor with dimensions [num_units * 3, batch_size] with CIFG
armnn::TensorInfo scratchBuffer1({ batch_size, num_units * 3 }, DataType::Float32);
BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer1);
}
else
{
- // scratch_buffer [num_units * 3, batch_size] without CIFG
+ // scratch_buffer [num_units * 4, batch_size] without CIFG
armnn::TensorInfo scratchBuffer2({ batch_size, num_units * 4 }, DataType::Float32);
BuildArmComputeTensor(*m_ScratchBuffer, scratchBuffer2);
}
@@ -222,6 +246,17 @@ NeonLstmFloatWorkload::NeonLstmFloatWorkload(const LstmQueueDescriptor &descript
m_Data.m_CellToOutputWeights);
}
+ if (m_Data.m_Parameters.m_LayerNormEnabled)
+ {
+ if (!m_Data.m_Parameters.m_CifgEnabled)
+ {
+ InitializeArmComputeTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
+ }
+ InitializeArmComputeTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
+ InitializeArmComputeTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
+ InitializeArmComputeTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
+ }
+
// Force Compute Library to perform the necessary copying and reshaping, after which
// delete all the input tensors that will no longer be needed
m_LstmLayer.prepare();
@@ -241,27 +276,11 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights)
+ const LstmInputParamsInfo& paramsInfo)
{
arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
- // The inputs and the outputs
+ // The inputs and outputs
const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
@@ -271,18 +290,24 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
// Basic parameters
- const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights);
- const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights);
- const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights);
+ const arm_compute::TensorInfo aclInputToForgetWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights());
+ const arm_compute::TensorInfo aclInputToCellWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights());
+ const arm_compute::TensorInfo aclInputToOutputWeightsInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights());
const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToForgetWeights);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights());
const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToCellWeights);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights());
const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
- = BuildArmComputeTensorInfo(recurrentToOutputWeights);
- const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias);
- const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias);
- const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias);
+ = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights());
+ const arm_compute::TensorInfo aclForgetGateBiasInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias());
+ const arm_compute::TensorInfo aclCellBiasInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_CellBias());
+ const arm_compute::TensorInfo aclOutputGateBiasInfo
+ = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias());
arm_compute::TensorInfo aclInputToInputWeightsInfo;
arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
@@ -293,48 +318,65 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input,
arm_compute::TensorInfo aclCellToForgetWeightsInfo;
arm_compute::TensorInfo aclCellToOutputWeightsInfo;
+ arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
+ arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
+
+
if (!descriptor.m_CifgEnabled)
{
- armnn::TensorInfo inputToInputWInfo = *inputToInputWeights;
- aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo);
- armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights;
- aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo);
-
- if (cellToInputWeights != nullptr)
+ if (descriptor.m_PeepholeEnabled)
{
- armnn::TensorInfo cellToInputWInfo = *cellToInputWeights;
- aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo);
+ aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights());
}
- armnn::TensorInfo inputGateBiasInfo = *inputGateBias;
- aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo);
+ aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights());
+ aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights());
+ aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias());
+
lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo,
- cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr,
+ descriptor.m_PeepholeEnabled ? &aclCellToInputWeightsInfo : nullptr,
&aclInputGateBiasInfo);
}
if (descriptor.m_ProjectionEnabled)
{
- const armnn::TensorInfo& projectionWInfo = *projectionWeights;
- aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo);
-
- if (projectionBias != nullptr)
+ if (paramsInfo.m_ProjectionBias != nullptr)
{
- const armnn::TensorInfo& projectionBiasInfo = *projectionBias;
- aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo);
+ aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionBias());
}
+ aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights());
+
lstm_params_info.set_projection_params(&aclProjectionWeightsInfo,
- projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr);
+ paramsInfo.m_ProjectionBias != nullptr ?
+ &aclProjectionBiasInfo : nullptr);
}
if (descriptor.m_PeepholeEnabled)
{
- const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights;
- aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo);
- const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights;
- aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo);
+ aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights());
+ aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights());
+
lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo);
}
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputLayerNormWeights());
+ }
+ aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetLayerNormWeights());
+ aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellLayerNormWeights());
+ aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputLayerNormWeights());
+
+ lstm_params_info.set_layer_normalization_params(descriptor.m_CifgEnabled ?
+ nullptr : &aclInputLayerNormWeightsInfo,
+ &aclForgetLayerNormWeightsInfo,
+ &aclCellLayerNormWeightsInfo,
+ &aclOutputLayerNormWeightsInfo);
+ }
+
float cell_threshold = descriptor.m_ClippingThresCell;
float projection_threshold = descriptor.m_ClippingThresProj;
@@ -407,6 +449,10 @@ void NeonLstmFloatWorkload::FreeUnusedTensors()
FreeTensorIfUnused(m_ProjectionWeightsTensor);
FreeTensorIfUnused(m_ProjectionBiasTensor);
FreeTensorIfUnused(m_ScratchBuffer);
+ FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
+ FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
}
} //namespace armnn
diff --git a/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp b/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
index f87f24d88a..c116cdd967 100644
--- a/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
+++ b/src/backends/neon/workloads/NeonLstmFloatWorkload.hpp
@@ -43,6 +43,11 @@ private:
std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
+ std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
+ std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
+
void FreeUnusedTensors();
};
@@ -50,21 +55,6 @@ arm_compute::Status NeonLstmFloatWorkloadValidate(const TensorInfo& input, const
const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
const TensorInfo& output, const LstmDescriptor &descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights);
+ const LstmInputParamsInfo& paramsInfo);
} //namespace armnn