author    James Conroy <james.conroy@arm.com>           2019-07-17 11:27:46 +0100
committer Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> 2019-07-24 10:40:13 +0100
commit    ee18dc8d1725f472850ab0c398fd7cbc4b850891 (patch)
tree      b57738b18781d512f5438ca5154652571393e4e8 /src/armnn
parent    7b1845206d723a91aec811edaf7cb0cf832dfd25 (diff)
IVGCVSW-3469 Add front end for Quantized LSTM layer
* Added new layer QuantizedLstm (Android Q)
* Made necessary changes to APIs
* Added unit tests

Change-Id: I3b9f16b0e7e49f51932cf204c87cb7118798123a
Signed-off-by: James Conroy <james.conroy@arm.com>
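
For orientation before the diff, a minimal sketch of how a client might drive the new front end. The sizes, quantization parameters, and the reuse of one buffer per parameter group are illustrative assumptions, not part of this patch; connecting the layer's three inputs and two outputs is also omitted.

#include <armnn/ArmNN.hpp>
#include <armnn/QuantizedLstmParams.hpp>
#include <vector>

int main()
{
    using namespace armnn;

    // Illustrative sizes only: inputSize = 2, outputSize = 4.
    const TensorInfo inputWeightsInfo({4, 2}, DataType::QuantisedAsymm8, 0.1f, 0);
    const TensorInfo recurrentWeightsInfo({4, 4}, DataType::QuantisedAsymm8, 0.1f, 0);
    const TensorInfo biasInfo({4}, DataType::Signed32, 0.001f, 0);

    std::vector<uint8_t> weightData(4 * 2, 1);
    std::vector<uint8_t> recurrentData(4 * 4, 1);
    std::vector<int32_t> biasData(4, 0);

    // One buffer per group for brevity; real weights differ per gate.
    ConstTensor inputWeights(inputWeightsInfo, weightData);
    ConstTensor recurrentWeights(recurrentWeightsInfo, recurrentData);
    ConstTensor bias(biasInfo, biasData);

    QuantizedLstmInputParams params;
    params.m_InputToInputWeights      = &inputWeights;
    params.m_InputToForgetWeights     = &inputWeights;
    params.m_InputToCellWeights       = &inputWeights;
    params.m_InputToOutputWeights     = &inputWeights;
    params.m_RecurrentToInputWeights  = &recurrentWeights;
    params.m_RecurrentToForgetWeights = &recurrentWeights;
    params.m_RecurrentToCellWeights   = &recurrentWeights;
    params.m_RecurrentToOutputWeights = &recurrentWeights;
    params.m_InputGateBias            = &bias;
    params.m_ForgetGateBias           = &bias;
    params.m_CellBias                 = &bias;
    params.m_OutputGateBias           = &bias;

    // Network::AddQuantizedLstmLayer (below) copies all twelve tensors into
    // ScopedCpuTensorHandles, so the local buffers need not outlive this call.
    INetworkPtr network = INetwork::Create();
    IConnectableLayer* const quantizedLstm = network->AddQuantizedLstmLayer(params, "quantizedLstm");
    (void)quantizedLstm;

    return 0;
}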
Diffstat (limited to 'src/armnn')
-rw-r--r--  src/armnn/InternalTypes.hpp                 |   1
-rw-r--r--  src/armnn/LayerSupport.cpp                  |  23
-rw-r--r--  src/armnn/LayersFwd.hpp                     |   2
-rw-r--r--  src/armnn/Network.cpp                       |  38
-rw-r--r--  src/armnn/Network.hpp                       |   4
-rw-r--r--  src/armnn/layers/QuantizedLstmLayer.cpp     | 290
-rw-r--r--  src/armnn/layers/QuantizedLstmLayer.hpp     |  87
-rw-r--r--  src/armnn/test/ConstTensorLayerVisitor.cpp  | 237
-rw-r--r--  src/armnn/test/ConstTensorLayerVisitor.hpp  |  29
-rw-r--r--  src/armnn/test/InferOutputTests.cpp         |   3
-rw-r--r--  src/armnn/test/InferOutputTests.hpp         |  48
11 files changed, 757 insertions(+), 5 deletions(-)
diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp
index bf095ac8a2..b0fea7c8c2 100644
--- a/src/armnn/InternalTypes.hpp
+++ b/src/armnn/InternalTypes.hpp
@@ -51,6 +51,7 @@ enum class LayerType
PreCompiled,
Prelu,
Quantize,
+ QuantizedLstm,
Reshape,
Resize,
Rsqrt,
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp
index a2908aae33..047c80a8c4 100644
--- a/src/armnn/LayerSupport.cpp
+++ b/src/armnn/LayerSupport.cpp
@@ -446,14 +446,29 @@ bool IsPadSupported(const BackendId& backend,
}
bool IsQuantizeSupported(const BackendId& backend,
- const TensorInfo& input,
- const TensorInfo& output,
- char* reasonIfUnsupported,
- size_t reasonIfUnsupportedMaxLength)
+ const TensorInfo& input,
+ const TensorInfo& output,
+ char* reasonIfUnsupported,
+ size_t reasonIfUnsupportedMaxLength)
{
FORWARD_LAYER_SUPPORT_FUNC(backend, IsQuantizeSupported, input, output);
}
+bool IsQuantizedLstmSupported(const BackendId& backend,
+ const TensorInfo& input,
+ const TensorInfo& previousCellStateIn,
+ const TensorInfo& previousOutputIn,
+ const TensorInfo& cellStateOut,
+ const TensorInfo& output,
+ const QuantizedLstmInputParamsInfo& paramsInfo,
+ char* reasonIfUnsupported,
+ size_t reasonIfUnsupportedMaxLength)
+{
+ FORWARD_LAYER_SUPPORT_FUNC(backend, IsQuantizedLstmSupported, input, previousCellStateIn, previousOutputIn,
+ cellStateOut, output, paramsInfo);
+}
+
bool IsPermuteSupported(const BackendId& backend,
const TensorInfo& input,
const TensorInfo& output,
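
As a rough sketch, the new free function can be queried like the other IsXxxSupported helpers. The QuantizedLstmInputParamsInfo constructor used here (one TensorInfo per mandatory parameter, in the order the parameters appear in the signature above) and the QSymm16 cell-state data type follow the Android Q QuantizedLstm convention and are assumptions, not something this hunk shows.

#include <armnn/LayerSupport.hpp>
#include <armnn/QuantizedLstmParams.hpp>
#include <armnn/Tensor.hpp>

bool CheckQuantizedLstmOnBackend(const armnn::BackendId& backend)
{
    using namespace armnn;

    // Illustrative shapes: numBatches = 1, inputSize = 2, outputSize = 4.
    const TensorInfo input({1, 2}, DataType::QuantisedAsymm8, 0.0078125f, 128);
    const TensorInfo cellState({1, 4}, DataType::QuantisedSymm16, 0.00048828125f, 0);
    const TensorInfo outputState({1, 4}, DataType::QuantisedAsymm8, 0.0078125f, 128);

    const TensorInfo weights({4, 2}, DataType::QuantisedAsymm8, 0.1f, 0);
    const TensorInfo recurrent({4, 4}, DataType::QuantisedAsymm8, 0.1f, 0);
    const TensorInfo bias({4}, DataType::Signed32, 0.001f, 0);

    // Assumed constructor: twelve TensorInfos, InputToX then RecurrentToX then biases.
    const QuantizedLstmInputParamsInfo paramsInfo(weights, weights, weights, weights,
                                                  recurrent, recurrent, recurrent, recurrent,
                                                  bias, bias, bias, bias);

    char reason[1024] = "";
    return IsQuantizedLstmSupported(backend,
                                    input,
                                    cellState,    // previousCellStateIn
                                    outputState,  // previousOutputIn
                                    cellState,    // cellStateOut
                                    outputState,  // output
                                    paramsInfo,
                                    reason,
                                    sizeof(reason));
}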
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index b3f7adc02c..2c8d5d2e07 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -43,6 +43,7 @@
#include "layers/PreCompiledLayer.hpp"
#include "layers/PreluLayer.hpp"
#include "layers/QuantizeLayer.hpp"
+#include "layers/QuantizedLstmLayer.hpp"
#include "layers/ReshapeLayer.hpp"
#include "layers/ResizeLayer.hpp"
#include "layers/RsqrtLayer.hpp"
@@ -120,6 +121,7 @@ DECLARE_LAYER(Pooling2d)
DECLARE_LAYER(PreCompiled)
DECLARE_LAYER(Prelu)
DECLARE_LAYER(Quantize)
+DECLARE_LAYER(QuantizedLstm)
DECLARE_LAYER(Reshape)
DECLARE_LAYER(Resize)
DECLARE_LAYER(Rsqrt)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index a43800827f..2195c71735 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1458,6 +1458,44 @@ IConnectableLayer* Network::AddStackLayer(const StackDescriptor& stackDescriptor
return m_Graph->AddLayer<StackLayer>(stackDescriptor, name);
}
+IConnectableLayer* Network::AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
+ const char* name)
+{
+ const auto layer = m_Graph->AddLayer<QuantizedLstmLayer>(name);
+
+ // InputToX weights
+ layer->m_QuantizedLstmParameters.m_InputToInputWeights =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_InputToInputWeights());
+ layer->m_QuantizedLstmParameters.m_InputToForgetWeights =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_InputToForgetWeights());
+ layer->m_QuantizedLstmParameters.m_InputToCellWeights =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_InputToCellWeights());
+ layer->m_QuantizedLstmParameters.m_InputToOutputWeights =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_InputToOutputWeights());
+
+ // RecurrentToX weights
+ layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToInputWeights());
+ layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToForgetWeights());
+ layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToCellWeights());
+ layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToOutputWeights());
+
+ // Bias
+ layer->m_QuantizedLstmParameters.m_InputGateBias =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_InputGateBias());
+ layer->m_QuantizedLstmParameters.m_ForgetGateBias =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_ForgetGateBias());
+ layer->m_QuantizedLstmParameters.m_CellBias =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_CellBias());
+ layer->m_QuantizedLstmParameters.m_OutputGateBias =
+ std::make_unique<ScopedCpuTensorHandle>(params.get_OutputGateBias());
+
+ return layer;
+}
+
void Network::Accept(ILayerVisitor& visitor) const
{
for (auto layer : GetGraph())
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index 8a99debb47..679ab51d43 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -6,6 +6,7 @@
#include <armnn/DescriptorsFwd.hpp>
#include <armnn/LstmParams.hpp>
+#include <armnn/QuantizedLstmParams.hpp>
#include <armnn/TensorFwd.hpp>
#include <armnn/Types.hpp>
@@ -200,6 +201,9 @@ public:
IConnectableLayer* AddStackLayer(const StackDescriptor& stackDescriptor,
const char* name = nullptr) override;
+ IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
+ const char* name = nullptr) override;
+
void Accept(ILayerVisitor& visitor) const override;
private:
diff --git a/src/armnn/layers/QuantizedLstmLayer.cpp b/src/armnn/layers/QuantizedLstmLayer.cpp
new file mode 100644
index 0000000000..1d8540d563
--- /dev/null
+++ b/src/armnn/layers/QuantizedLstmLayer.cpp
@@ -0,0 +1,290 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "QuantizedLstmLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <armnn/TypesUtils.hpp>
+#include <backendsCommon/CpuTensorHandle.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+namespace armnn
+{
+
+QuantizedLstmLayer::QuantizedLstmLayer(const char* name)
+ : Layer(3, 2, LayerType::QuantizedLstm, name)
+{
+}
+
+std::unique_ptr<IWorkload> QuantizedLstmLayer::CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const
+{
+ QuantizedLstmQueueDescriptor descriptor;
+
+ // QuantizedLstmLayer parameters - there are no optional params
+ descriptor.m_InputToInputWeights = m_QuantizedLstmParameters.m_InputToInputWeights.get();
+ descriptor.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights.get();
+ descriptor.m_InputToCellWeights = m_QuantizedLstmParameters.m_InputToCellWeights.get();
+ descriptor.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights.get();
+
+ descriptor.m_RecurrentToInputWeights = m_QuantizedLstmParameters.m_RecurrentToInputWeights.get();
+ descriptor.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights.get();
+ descriptor.m_RecurrentToCellWeights = m_QuantizedLstmParameters.m_RecurrentToCellWeights.get();
+ descriptor.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights.get();
+
+ descriptor.m_InputGateBias = m_QuantizedLstmParameters.m_InputGateBias.get();
+ descriptor.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias.get();
+ descriptor.m_CellBias = m_QuantizedLstmParameters.m_CellBias.get();
+ descriptor.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias.get();
+
+ return factory.CreateQuantizedLstm(descriptor, PrepInfoAndDesc(descriptor, graph));
+}
+
+QuantizedLstmLayer* QuantizedLstmLayer::Clone(Graph& graph) const
+{
+ auto layer = CloneBase<QuantizedLstmLayer>(graph, GetName());
+
+ layer->m_QuantizedLstmParameters.m_InputToInputWeights = m_QuantizedLstmParameters.m_InputToInputWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToInputWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToForgetWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_InputToCellWeights = m_QuantizedLstmParameters.m_InputToCellWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToCellWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToOutputWeights) : nullptr;
+
+ layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights = m_QuantizedLstmParameters.m_RecurrentToInputWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToInputWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights
+ ? std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToForgetWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights = m_QuantizedLstmParameters.m_RecurrentToCellWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToCellWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights
+ ? std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToOutputWeights) : nullptr;
+
+ layer->m_QuantizedLstmParameters.m_InputGateBias = m_QuantizedLstmParameters.m_InputGateBias ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputGateBias) : nullptr;
+ layer->m_QuantizedLstmParameters.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_ForgetGateBias) : nullptr;
+ layer->m_QuantizedLstmParameters.m_CellBias = m_QuantizedLstmParameters.m_CellBias ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_CellBias) : nullptr;
+ layer->m_QuantizedLstmParameters.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_OutputGateBias) : nullptr;
+
+ return std::move(layer);
+}
+
+std::vector<TensorShape> QuantizedLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+ BOOST_ASSERT(inputShapes.size() == 3);
+
+ // Get input values for validation
+ unsigned int numBatches = inputShapes[0][0];
+ unsigned int outputSize = inputShapes[1][1];
+
+ std::vector<TensorShape> outShapes;
+ outShapes.push_back(TensorShape({numBatches, outputSize})); // cellStateOut
+ outShapes.push_back(TensorShape({numBatches, outputSize})); // output
+
+ return outShapes;
+}
+
+void QuantizedLstmLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(3, CHECK_LOCATION());
+
+ auto inferredShapes = InferOutputShapes(
+ {
+ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), // input
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), // previousCellStateIn
+ GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape() // previousOutputIn
+ });
+
+ BOOST_ASSERT(inferredShapes.size() == 2);
+
+ // Check weights and bias for nullptr
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToInputWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToInputWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToForgetWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToCellWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToCellWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToOutputWeights should not be null.");
+
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToInputWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToForgetWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToCellWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToOutputWeights should not be null.");
+
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_InputGateBias != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputGateBias should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_ForgetGateBias != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_ForgetGateBias should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_CellBias != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_CellBias should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_OutputGateBias != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_OutputGateBias should not be null.");
+
+ // Check output TensorShape(s) match inferred shape
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "QuantizedLstmLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+ GetOutputSlot(0).GetTensorInfo().GetShape(),
+ inferredShapes[0]);
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "QuantizedLstmLayer: TensorShape set on OutputSlot[1] does not match the inferred shape.",
+ GetOutputSlot(1).GetTensorInfo().GetShape(),
+ inferredShapes[1]);
+}
+
+Layer::ConstantTensors QuantizedLstmLayer::GetConstantTensorsByRef()
+{
+ return
+ {
+ m_QuantizedLstmParameters.m_InputToInputWeights,
+ m_QuantizedLstmParameters.m_InputToForgetWeights,
+ m_QuantizedLstmParameters.m_InputToCellWeights,
+ m_QuantizedLstmParameters.m_InputToOutputWeights,
+
+ m_QuantizedLstmParameters.m_RecurrentToInputWeights,
+ m_QuantizedLstmParameters.m_RecurrentToForgetWeights,
+ m_QuantizedLstmParameters.m_RecurrentToCellWeights,
+ m_QuantizedLstmParameters.m_RecurrentToOutputWeights,
+
+ m_QuantizedLstmParameters.m_InputGateBias,
+ m_QuantizedLstmParameters.m_ForgetGateBias,
+ m_QuantizedLstmParameters.m_CellBias,
+ m_QuantizedLstmParameters.m_OutputGateBias
+ };
+}
+
+void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
+{
+ QuantizedLstmInputParams inputParams;
+
+ // InputToX weight tensors
+ ConstTensor inputToInputWeightsTensor;
+ if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr)
+ {
+ ConstTensor inputToInputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_InputToInputWeights->Map(true));
+ inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
+ inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
+ }
+
+ ConstTensor inputToForgetWeightsTensor;
+ if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr)
+ {
+ ConstTensor inputToForgetWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_InputToForgetWeights->Map(true));
+ inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
+ inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
+ }
+
+ ConstTensor inputToCellWeightsTensor;
+ if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr)
+ {
+ ConstTensor inputToCellWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_InputToCellWeights->Map(true));
+ inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
+ inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
+ }
+
+ ConstTensor inputToOutputWeightsTensor;
+ if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr)
+ {
+ ConstTensor inputToOutputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_InputToOutputWeights->Map(true));
+ inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
+ inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
+ }
+
+ // RecurrentToX weight tensors
+ ConstTensor recurrentToInputWeightsTensor;
+ if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr)
+ {
+ ConstTensor recurrentToInputWeightsTensorCopy(
+ m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_RecurrentToInputWeights->Map(true));
+ recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
+ inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
+ }
+
+ ConstTensor recurrentToForgetWeightsTensor;
+ if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr)
+ {
+ ConstTensor recurrentToForgetWeightsTensorCopy(
+ m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_RecurrentToForgetWeights->Map(true));
+ recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
+ inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
+ }
+
+ ConstTensor recurrentToCellWeightsTensor;
+ if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr)
+ {
+ ConstTensor recurrentToCellWeightsTensorCopy(
+ m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_RecurrentToCellWeights->Map(true));
+ recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
+ inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
+ }
+
+ ConstTensor recurrentToOutputWeightsTensor;
+ if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr)
+ {
+ ConstTensor recurrentToOutputWeightsTensorCopy(
+ m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_RecurrentToOutputWeights->Map(true));
+ recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
+ inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
+ }
+
+ // Bias tensors
+ ConstTensor inputGateBiasTensor;
+ if (m_QuantizedLstmParameters.m_InputGateBias != nullptr)
+ {
+ ConstTensor inputGateBiasTensorCopy(m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_InputGateBias->Map(true));
+ inputGateBiasTensor = inputGateBiasTensorCopy;
+ inputParams.m_InputGateBias = &inputGateBiasTensor;
+ }
+
+ ConstTensor forgetGateBiasTensor;
+ if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr)
+ {
+ ConstTensor forgetGateBiasTensorCopy(m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_ForgetGateBias->Map(true));
+ forgetGateBiasTensor = forgetGateBiasTensorCopy;
+ inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
+ }
+
+ ConstTensor cellBiasTensor;
+ if (m_QuantizedLstmParameters.m_CellBias != nullptr)
+ {
+ ConstTensor cellBiasTensorCopy(m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_CellBias->Map(true));
+ cellBiasTensor = cellBiasTensorCopy;
+ inputParams.m_CellBias = &cellBiasTensor;
+ }
+
+ ConstTensor outputGateBiasTensor;
+ if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr)
+ {
+ ConstTensor outputGateBiasCopy(m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_OutputGateBias->Map(true));
+ outputGateBiasTensor = outputGateBiasCopy;
+ inputParams.m_OutputGateBias = &outputGateBiasTensor;
+ }
+
+ visitor.VisitQuantizedLstmLayer(this, inputParams, GetName());
+}
+
+} // namespace armnn
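
Accept rebuilds a QuantizedLstmInputParams from the layer's ScopedCpuTensorHandles and forwards it to VisitQuantizedLstmLayer. A minimal sketch of a visitor that hooks this follows; it assumes the patch also adds the matching entry to ILayerVisitor and LayerVisitorBase under include/armnn, which this src/armnn-limited diff does not show.

#include <armnn/LayerVisitorBase.hpp>
#include <armnn/QuantizedLstmParams.hpp>
#include <iostream>

struct QuantizedLstmInspector : armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
{
    void VisitQuantizedLstmLayer(const armnn::IConnectableLayer* /*layer*/,
                                 const armnn::QuantizedLstmInputParams& params,
                                 const char* name = nullptr) override
    {
        // All twelve parameters are mandatory for this layer, so each pointer
        // reconstructed by Accept should be non-null.
        std::cout << "QuantizedLstm '" << (name ? name : "<unnamed>")
                  << "', InputToInputWeights set: "
                  << (params.m_InputToInputWeights != nullptr) << std::endl;
    }
};

// Usage sketch: QuantizedLstmInspector inspector; network->Accept(inspector);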
diff --git a/src/armnn/layers/QuantizedLstmLayer.hpp b/src/armnn/layers/QuantizedLstmLayer.hpp
new file mode 100644
index 0000000000..4602f71114
--- /dev/null
+++ b/src/armnn/layers/QuantizedLstmLayer.hpp
@@ -0,0 +1,87 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <Layer.hpp>
+
+namespace armnn
+{
+
+class ScopedCpuTensorHandle;
+
+struct QuantizedLstmParameters
+{
+ /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
+ std::unique_ptr<ScopedCpuTensorHandle> m_InputToInputWeights;
+ /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
+ std::unique_ptr<ScopedCpuTensorHandle> m_InputToForgetWeights;
+ /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
+ std::unique_ptr<ScopedCpuTensorHandle> m_InputToCellWeights;
+ /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
+ std::unique_ptr<ScopedCpuTensorHandle> m_InputToOutputWeights;
+
+ /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
+ std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeights;
+ /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
+ std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToForgetWeights;
+ /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
+ std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToCellWeights;
+ /// A unique pointer to represent 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
+ std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToOutputWeights;
+
+ /// A unique pointer to represent 1D bias tensor with dimensions [outputSize] (int32).
+ std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBias;
+ /// A unique pointer to represent 1D bias tensor with dimensions [outputSize] (int32).
+ std::unique_ptr<ScopedCpuTensorHandle> m_ForgetGateBias;
+ /// A unique pointer to represent 1D bias tensor with dimensions [outputSize] (int32).
+ std::unique_ptr<ScopedCpuTensorHandle> m_CellBias;
+ /// A unique pointer to represent 1D bias tensor with dimensions [outputSize] (int32).
+ std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBias;
+};
+
+/// This layer represents a QuantizedLstm operation.
+class QuantizedLstmLayer : public Layer
+{
+public:
+
+ QuantizedLstmParameters m_QuantizedLstmParameters;
+
+ /// Makes a workload for the QuantizedLstm type.
+ /// @param [in] graph The graph where this layer can be found.
+ /// @param [in] factory The workload factory which will create the workload.
+ /// @return A pointer to the created workload, or nullptr if not created.
+ virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const override;
+
+ /// Creates a dynamically-allocated copy of this layer.
+ /// @param [in] graph The graph into which this layer is being cloned.
+ QuantizedLstmLayer* Clone(Graph& graph) const override;
+
+ /// Check if the input tensor shape(s)
+ /// will lead to a valid configuration of @ref QuantizedLstmLayer.
+ void ValidateTensorShapesFromInputs() override;
+
+ /// By default returns inputShapes if the number of inputs is equal to the number of outputs,
+ /// otherwise infers the output shapes from the given input shapes and layer properties.
+ /// @param [in] inputShapes The input shapes the layer has.
+ /// @return A vector of the inferred output shapes.
+ std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
+
+ void Accept(ILayerVisitor& visitor) const override;
+
+protected:
+ /// Constructor to create a QuantizedLstmLayer.
+ /// @param [in] name Optional name for the layer.
+ QuantizedLstmLayer(const char* name);
+
+ /// Default destructor
+ ~QuantizedLstmLayer() = default;
+
+ /// Retrieve the handles to the constant values stored by the layer.
+ /// @return A vector of the constant tensors stored by this layer.
+ Layer::ConstantTensors GetConstantTensorsByRef() override;
+};
+
+} // namespace armnn
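
Pulling the doc comments above together with InferOutputShapes from the .cpp: for numBatches = 2, inputSize = 5, outputSize = 10 (the sizes the InferOutputTests below use), the tensors line up as follows. The input/output data types are the Android Q QuantizedLstm convention and are not fixed by this header.

    input                [2, 5]    QAsymm8
    previousCellStateIn  [2, 10]   QSymm16
    previousOutputIn     [2, 10]   QAsymm8
    cellStateOut         [2, 10]   QSymm16
    output               [2, 10]   QAsymm8

    InputToX weights     [10, 5]   QAsymm8
    RecurrentToX weights [10, 10]  QAsymm8
    gate biases          [10]      Signed32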
diff --git a/src/armnn/test/ConstTensorLayerVisitor.cpp b/src/armnn/test/ConstTensorLayerVisitor.cpp
index e17ee46c81..cfcdb1d2ff 100644
--- a/src/armnn/test/ConstTensorLayerVisitor.cpp
+++ b/src/armnn/test/ConstTensorLayerVisitor.cpp
@@ -107,6 +107,64 @@ void TestLstmLayerVisitor::CheckInputParameters(const LstmInputParams& inputPara
CheckConstTensorPtrs("CellBias", m_InputParams.m_CellBias, inputParams.m_CellBias);
}
+void TestQuantizedLstmLayerVisitor::CheckConstTensorPtrs(const std::string& name,
+ const ConstTensor* expected,
+ const ConstTensor* actual)
+{
+ if (expected == nullptr)
+ {
+ BOOST_CHECK_MESSAGE(actual == nullptr, name + " actual should have been a nullptr");
+ }
+ else
+ {
+ BOOST_CHECK_MESSAGE(actual != nullptr, name + " actual should have been set");
+ if (actual != nullptr)
+ {
+ CheckConstTensors(*expected, *actual);
+ }
+ }
+}
+
+void TestQuantizedLstmLayerVisitor::CheckInputParameters(const QuantizedLstmInputParams& inputParams)
+{
+ CheckConstTensorPtrs("InputToInputWeights",
+ m_InputParams.m_InputToInputWeights,
+ inputParams.m_InputToInputWeights);
+
+ CheckConstTensorPtrs("InputToForgetWeights",
+ m_InputParams.m_InputToForgetWeights,
+ inputParams.m_InputToForgetWeights);
+
+ CheckConstTensorPtrs("InputToCellWeights",
+ m_InputParams.m_InputToCellWeights,
+ inputParams.m_InputToCellWeights);
+
+ CheckConstTensorPtrs("InputToOutputWeights",
+ m_InputParams.m_InputToOutputWeights,
+ inputParams.m_InputToOutputWeights);
+
+ CheckConstTensorPtrs("RecurrentToInputWeights",
+ m_InputParams.m_RecurrentToInputWeights,
+ inputParams.m_RecurrentToInputWeights);
+
+ CheckConstTensorPtrs("RecurrentToForgetWeights",
+ m_InputParams.m_RecurrentToForgetWeights,
+ inputParams.m_RecurrentToForgetWeights);
+
+ CheckConstTensorPtrs("RecurrentToCellWeights",
+ m_InputParams.m_RecurrentToCellWeights,
+ inputParams.m_RecurrentToCellWeights);
+
+ CheckConstTensorPtrs("RecurrentToOutputWeights",
+ m_InputParams.m_RecurrentToOutputWeights,
+ inputParams.m_RecurrentToOutputWeights);
+
+ CheckConstTensorPtrs("InputGateBias", m_InputParams.m_InputGateBias, inputParams.m_InputGateBias);
+ CheckConstTensorPtrs("ForgetGateBias", m_InputParams.m_ForgetGateBias, inputParams.m_ForgetGateBias);
+ CheckConstTensorPtrs("CellBias", m_InputParams.m_CellBias, inputParams.m_CellBias);
+ CheckConstTensorPtrs("OutputGateBias", m_InputParams.m_OutputGateBias, inputParams.m_OutputGateBias);
+}
+
BOOST_AUTO_TEST_SUITE(TestConstTensorLayerVisitor)
BOOST_AUTO_TEST_CASE(CheckConvolution2dLayer)
@@ -1185,6 +1243,185 @@ BOOST_AUTO_TEST_CASE(CheckNamedLstmLayerProjection)
layer->Accept(visitor);
}
+BOOST_AUTO_TEST_CASE(CheckQuantizedLstmLayer)
+{
+ std::vector<uint8_t> inputToInputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> inputToInputWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor inputToInputWeights(
+ TensorInfo(4, inputToInputWeightsDimensions.data(), DataType::QuantisedAsymm8), inputToInputWeightsData);
+
+ std::vector<uint8_t> inputToForgetWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> inputToForgetWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor inputToForgetWeights(
+ TensorInfo(4, inputToForgetWeightsDimensions.data(), DataType::QuantisedAsymm8), inputToForgetWeightsData);
+
+ std::vector<uint8_t> inputToCellWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> inputToCellWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor inputToCellWeights(
+ TensorInfo(4, inputToCellWeightsDimensions.data(), DataType::QuantisedAsymm8), inputToCellWeightsData);
+
+ std::vector<uint8_t> inputToOutputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> inputToOutputWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor inputToOutputWeights(
+ TensorInfo(4, inputToOutputWeightsDimensions.data(), DataType::QuantisedAsymm8), inputToOutputWeightsData);
+
+
+ std::vector<uint8_t> recurrentToInputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> recurrentToInputWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor recurrentToInputWeights(TensorInfo(
+ 4, recurrentToInputWeightsDimensions.data(), DataType::QuantisedAsymm8), recurrentToInputWeightsData);
+
+ std::vector<uint8_t> recurrentToForgetWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> recurrentToForgetWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor recurrentToForgetWeights(TensorInfo(
+ 4, recurrentToForgetWeightsDimensions.data(), DataType::QuantisedAsymm8), recurrentToForgetWeightsData);
+
+ std::vector<uint8_t> recurrentToCellWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> recurrentToCellWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor recurrentToCellWeights(TensorInfo(
+ 4, recurrentToCellWeightsDimensions.data(), DataType::QuantisedAsymm8), recurrentToCellWeightsData);
+
+ std::vector<uint8_t> recurrentToOutputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> recurrentToOutputWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor recurrentToOutputWeights(TensorInfo(
+ 4, recurrentToOutputWeightsDimensions.data(), DataType::QuantisedAsymm8), recurrentToOutputWeightsData);
+
+
+ std::vector<int32_t> inputGateBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> inputGateBiasDimensions = {1, 1, 3, 3};
+ ConstTensor inputGateBias(
+ TensorInfo(4, inputGateBiasDimensions.data(), DataType::Signed32), inputGateBiasData);
+
+ std::vector<int32_t> forgetGateBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> forgetGateBiasDimensions = {1, 1, 3, 3};
+ ConstTensor forgetGateBias(TensorInfo(
+ 4, forgetGateBiasDimensions.data(), DataType::Signed32), forgetGateBiasData);
+
+ std::vector<int32_t> cellBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> cellBiasDimensions = {1, 1, 3, 3};
+ ConstTensor cellBias(TensorInfo(
+ 4, cellBiasDimensions.data(), DataType::Signed32), cellBiasData);
+
+ std::vector<int32_t> outputGateBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> outputGateBiasDimensions = {1, 1, 3, 3};
+ ConstTensor outputGateBias(TensorInfo(
+ 4, outputGateBiasDimensions.data(), DataType::Signed32), outputGateBiasData);
+
+ QuantizedLstmInputParams params;
+
+ params.m_InputToInputWeights = &inputToInputWeights;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+
+ params.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+
+ params.m_InputGateBias = &inputGateBias;
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ TestQuantizedLstmLayerVisitor visitor(params);
+
+ Network net;
+
+ IConnectableLayer* const layer = net.AddQuantizedLstmLayer(params);
+ layer->Accept(visitor);
+}
+
+BOOST_AUTO_TEST_CASE(CheckNamedQuantizedLstmLayer)
+{
+ const char* layerName = "LstmLayer";
+ std::vector<uint8_t> inputToInputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> inputToInputWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor inputToInputWeights(
+ TensorInfo(4, inputToInputWeightsDimensions.data(), DataType::QuantisedAsymm8), inputToInputWeightsData);
+
+ std::vector<uint8_t> inputToForgetWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> inputToForgetWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor inputToForgetWeights(
+ TensorInfo(4, inputToForgetWeightsDimensions.data(), DataType::QuantisedAsymm8), inputToForgetWeightsData);
+
+ std::vector<uint8_t> inputToCellWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> inputToCellWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor inputToCellWeights(
+ TensorInfo(4, inputToCellWeightsDimensions.data(), DataType::QuantisedAsymm8), inputToCellWeightsData);
+
+ std::vector<uint8_t> inputToOutputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> inputToOutputWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor inputToOutputWeights(
+ TensorInfo(4, inputToOutputWeightsDimensions.data(), DataType::QuantisedAsymm8), inputToOutputWeightsData);
+
+
+ std::vector<uint8_t> recurrentToInputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> recurrentToInputWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor recurrentToInputWeights(TensorInfo(
+ 4, recurrentToInputWeightsDimensions.data(), DataType::QuantisedAsymm8), recurrentToInputWeightsData);
+
+ std::vector<uint8_t> recurrentToForgetWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> recurrentToForgetWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor recurrentToForgetWeights(TensorInfo(
+ 4, recurrentToForgetWeightsDimensions.data(), DataType::QuantisedAsymm8), recurrentToForgetWeightsData);
+
+ std::vector<uint8_t> recurrentToCellWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> recurrentToCellWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor recurrentToCellWeights(TensorInfo(
+ 4, recurrentToCellWeightsDimensions.data(), DataType::QuantisedAsymm8), recurrentToCellWeightsData);
+
+ std::vector<uint8_t> recurrentToOutputWeightsData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> recurrentToOutputWeightsDimensions = {1, 1, 3, 3};
+ ConstTensor recurrentToOutputWeights(TensorInfo(
+ 4, recurrentToOutputWeightsDimensions.data(), DataType::QuantisedAsymm8), recurrentToOutputWeightsData);
+
+
+ std::vector<int32_t> inputGateBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> inputGateBiasDimensions = {1, 1, 3, 3};
+ ConstTensor inputGateBias(
+ TensorInfo(4, inputGateBiasDimensions.data(), DataType::Signed32), inputGateBiasData);
+
+ std::vector<int32_t> forgetGateBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> forgetGateBiasDimensions = {1, 1, 3, 3};
+ ConstTensor forgetGateBias(TensorInfo(
+ 4, forgetGateBiasDimensions.data(), DataType::Signed32), forgetGateBiasData);
+
+ std::vector<int32_t> cellBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> cellBiasDimensions = {1, 1, 3, 3};
+ ConstTensor cellBias(TensorInfo(
+ 4, cellBiasDimensions.data(), DataType::Signed32), cellBiasData);
+
+ std::vector<int32_t> outputGateBiasData = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<unsigned int> outputGateBiasDimensions = {1, 1, 3, 3};
+ ConstTensor outputGateBias(TensorInfo(
+ 4, outputGateBiasDimensions.data(), DataType::Signed32), outputGateBiasData);
+
+ QuantizedLstmInputParams params;
+
+ params.m_InputToInputWeights = &inputToInputWeights;
+ params.m_InputToForgetWeights = &inputToForgetWeights;
+ params.m_InputToCellWeights = &inputToCellWeights;
+ params.m_InputToOutputWeights = &inputToOutputWeights;
+
+ params.m_RecurrentToInputWeights = &recurrentToInputWeights;
+ params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ params.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+
+ params.m_InputGateBias = &inputGateBias;
+ params.m_ForgetGateBias = &forgetGateBias;
+ params.m_CellBias = &cellBias;
+ params.m_OutputGateBias = &outputGateBias;
+
+ TestQuantizedLstmLayerVisitor visitor(params, layerName);
+
+ Network net;
+
+ IConnectableLayer* const layer = net.AddQuantizedLstmLayer(params, layerName);
+ layer->Accept(visitor);
+}
+
BOOST_AUTO_TEST_SUITE_END()
} // namespace armnn
diff --git a/src/armnn/test/ConstTensorLayerVisitor.hpp b/src/armnn/test/ConstTensorLayerVisitor.hpp
index 80409b331f..203c5fd91b 100644
--- a/src/armnn/test/ConstTensorLayerVisitor.hpp
+++ b/src/armnn/test/ConstTensorLayerVisitor.hpp
@@ -7,6 +7,7 @@
#include "TestLayerVisitor.hpp"
#include <armnn/Descriptors.hpp>
#include <armnn/LstmParams.hpp>
+#include <armnn/QuantizedLstmParams.hpp>
namespace armnn
{
@@ -220,4 +221,32 @@ private:
LstmInputParams m_InputParams;
};
+
+class TestQuantizedLstmLayerVisitor : public TestLayerVisitor
+{
+public:
+ explicit TestQuantizedLstmLayerVisitor(const QuantizedLstmInputParams& params,
+ const char* name = nullptr)
+ : TestLayerVisitor(name)
+ , m_InputParams(params)
+ {}
+
+ void VisitQuantizedLstmLayer(const IConnectableLayer* layer,
+ const QuantizedLstmInputParams& params,
+ const char* name = nullptr)
+ {
+ CheckLayerPointer(layer);
+ CheckLayerName(name);
+ CheckInputParameters(params);
+ }
+
+protected:
+ void CheckInputParameters(const QuantizedLstmInputParams& inputParams);
+ void CheckConstTensorPtrs(const std::string& name, const ConstTensor* expected, const ConstTensor* actual);
+
+private:
+ QuantizedLstmInputParams m_InputParams;
+};
+
+
} // namespace armnn
diff --git a/src/armnn/test/InferOutputTests.cpp b/src/armnn/test/InferOutputTests.cpp
index 4581d87a5b..8606745623 100644
--- a/src/armnn/test/InferOutputTests.cpp
+++ b/src/armnn/test/InferOutputTests.cpp
@@ -40,4 +40,7 @@ ARMNN_SIMPLE_TEST_CASE(DepthwiseConvolution2dInferOutputShape, DepthwiseConvolut
// TransposeConvolution2D
ARMNN_SIMPLE_TEST_CASE(TransposeConvolution2dInferOutputShape, TransposeConvolution2dInferOutputShapeTest)
+// QuantizedLstm
+ARMNN_SIMPLE_TEST_CASE(QuantizedLstmInferOutputShape, QuantizedLstmInferOutputShapeTest)
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnn/test/InferOutputTests.hpp b/src/armnn/test/InferOutputTests.hpp
index 58a081a130..2dd2ff0e73 100644
--- a/src/armnn/test/InferOutputTests.hpp
+++ b/src/armnn/test/InferOutputTests.hpp
@@ -443,4 +443,50 @@ void DepthwiseConvolution2dInferOutputShapeTest()
armnn::TensorShape expectedOutputShape(4, expectedOutputSizes.data());
BOOST_CHECK(expectedOutputShape == depthwiseConvolution2dLayer->InferOutputShapes(shapes).at(0));
-}
\ No newline at end of file
+}
+
+// QuantizedLstm
+void QuantizedLstmInferOutputShapeImpl(const std::vector<armnn::TensorShape>& inputShapes,
+ std::vector<armnn::TensorShape>& outputShapes)
+{
+ armnn::Graph graph;
+ armnn::QuantizedLstmLayer* const quantizedLstmLayer = graph.AddLayer<armnn::QuantizedLstmLayer>("quantizedLstm");
+ outputShapes = quantizedLstmLayer->InferOutputShapes(inputShapes);
+}
+
+void QuantizedLstmInferOutputShapeTest()
+{
+ // Input shapes
+ const std::vector<unsigned int> inputShape{ 2, 5 };
+ const std::vector<unsigned int> previousCellStateInShape{ 2, 10 };
+ const std::vector<unsigned int> previousOutputInShape{ 2, 10 };
+ armnn::TensorShape inputTensorShape(2, inputShape.data());
+ armnn::TensorShape previousCellStateInTensorShape(2, previousCellStateInShape.data());
+ armnn::TensorShape previousOutputInTensorShape(2, previousOutputInShape.data());
+
+ std::vector<armnn::TensorShape> inShapes
+ {
+ inputTensorShape,
+ previousCellStateInTensorShape,
+ previousOutputInTensorShape
+ };
+
+ // Output shapes
+ const std::vector<unsigned int> cellStateOutShape{ 2, 10 };
+ const std::vector<unsigned int> outputShape{ 2, 10 };
+ armnn::TensorShape cellStateOutTensorShape(2, cellStateOutShape.data());
+ armnn::TensorShape outputTensorShape(2, outputShape.data());
+
+ std::vector<armnn::TensorShape> expectedOutShapes
+ {
+ cellStateOutTensorShape,
+ outputTensorShape
+ };
+
+ std::vector<armnn::TensorShape> actualOutShapes;
+ BOOST_CHECK_NO_THROW(QuantizedLstmInferOutputShapeImpl(inShapes, actualOutShapes));
+
+ BOOST_CHECK(actualOutShapes.size() == 2);
+ BOOST_CHECK(expectedOutShapes[0] == actualOutShapes[0]);
+ BOOST_CHECK(expectedOutShapes[1] == actualOutShapes[1]);
+}