aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/QuantizedLstmLayer.cpp
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2019-07-17 11:27:46 +0100
committerAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-07-24 10:40:13 +0100
commitee18dc8d1725f472850ab0c398fd7cbc4b850891 (patch)
treeb57738b18781d512f5438ca5154652571393e4e8 /src/armnn/layers/QuantizedLstmLayer.cpp
parent7b1845206d723a91aec811edaf7cb0cf832dfd25 (diff)
downloadarmnn-ee18dc8d1725f472850ab0c398fd7cbc4b850891.tar.gz
IVGCVSW-3469 Add front end for Quantized LSTM layer
* Added new layer QuantizedLstm (Android Q) * Made necessary changes to APIs * Added unit tests Change-Id: I3b9f16b0e7e49f51932cf204c87cb7118798123a Signed-off-by: James Conroy <james.conroy@arm.com>
Diffstat (limited to 'src/armnn/layers/QuantizedLstmLayer.cpp')
-rw-r--r--src/armnn/layers/QuantizedLstmLayer.cpp290
1 files changed, 290 insertions, 0 deletions
diff --git a/src/armnn/layers/QuantizedLstmLayer.cpp b/src/armnn/layers/QuantizedLstmLayer.cpp
new file mode 100644
index 0000000000..1d8540d563
--- /dev/null
+++ b/src/armnn/layers/QuantizedLstmLayer.cpp
@@ -0,0 +1,290 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "QuantizedLstmLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <armnn/TypesUtils.hpp>
+#include <backendsCommon/CpuTensorHandle.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+namespace armnn
+{
+
+QuantizedLstmLayer::QuantizedLstmLayer(const char* name)
+ : Layer(3, 2, LayerType::QuantizedLstm, name)
+{
+}
+
+std::unique_ptr<IWorkload> QuantizedLstmLayer::CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const
+{
+ QuantizedLstmQueueDescriptor descriptor;
+
+ // QuantizedLstmLayer parameters - there are no optional params
+ descriptor.m_InputToInputWeights = m_QuantizedLstmParameters.m_InputToInputWeights.get();
+ descriptor.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights.get();
+ descriptor.m_InputToCellWeights = m_QuantizedLstmParameters.m_InputToCellWeights.get();
+ descriptor.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights.get();
+
+ descriptor.m_RecurrentToInputWeights = m_QuantizedLstmParameters.m_RecurrentToInputWeights.get();
+ descriptor.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights.get();
+ descriptor.m_RecurrentToCellWeights = m_QuantizedLstmParameters.m_RecurrentToCellWeights.get();
+ descriptor.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights.get();
+
+ descriptor.m_InputGateBias = m_QuantizedLstmParameters.m_InputGateBias.get();
+ descriptor.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias.get();
+ descriptor.m_CellBias = m_QuantizedLstmParameters.m_CellBias.get();
+ descriptor.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias.get();
+
+ return factory.CreateQuantizedLstm(descriptor, PrepInfoAndDesc(descriptor, graph));
+}
+
+QuantizedLstmLayer* QuantizedLstmLayer::Clone(Graph& graph) const
+{
+ auto layer = CloneBase<QuantizedLstmLayer>(graph, GetName());
+
+ layer->m_QuantizedLstmParameters.m_InputToInputWeights = m_QuantizedLstmParameters.m_InputToInputWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToInputWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToForgetWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_InputToCellWeights = m_QuantizedLstmParameters.m_InputToCellWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToCellWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputToOutputWeights) : nullptr;
+
+ layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights = m_QuantizedLstmParameters.m_RecurrentToInputWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToInputWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights
+ ? std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToForgetWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights = m_QuantizedLstmParameters.m_RecurrentToCellWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToCellWeights) : nullptr;
+ layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights
+ ? std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_RecurrentToOutputWeights) : nullptr;
+
+ layer->m_QuantizedLstmParameters.m_InputGateBias = m_QuantizedLstmParameters.m_InputGateBias ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_InputGateBias) : nullptr;
+ layer->m_QuantizedLstmParameters.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_ForgetGateBias) : nullptr;
+ layer->m_QuantizedLstmParameters.m_CellBias = m_QuantizedLstmParameters.m_CellBias ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_CellBias) : nullptr;
+ layer->m_QuantizedLstmParameters.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_QuantizedLstmParameters.m_OutputGateBias) : nullptr;
+
+ return std::move(layer);
+}
+
+std::vector<TensorShape> QuantizedLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+ BOOST_ASSERT(inputShapes.size() == 3);
+
+ // Get input values for validation
+ unsigned int numBatches = inputShapes[0][0];
+ unsigned int outputSize = inputShapes[1][1];
+
+ std::vector<TensorShape> outShapes;
+ outShapes.push_back(TensorShape({numBatches, outputSize})); // cellStateOut
+ outShapes.push_back(TensorShape({numBatches, outputSize})); // output
+
+ return outShapes;
+}
+
+void QuantizedLstmLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(3, CHECK_LOCATION());
+
+ auto inferredShapes = InferOutputShapes(
+ {
+ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), // input
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), // previousCellStateIn
+ GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape() // previousOutputIn
+ });
+
+ BOOST_ASSERT(inferredShapes.size() == 2);
+
+ // Check weights and bias for nullptr
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToInputWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToInputWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToForgetWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToCellWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToCellWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToOutputWeights should not be null.");
+
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToInputWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToForgetWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToCellWeights should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToOutputWeights should not be null.");
+
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_InputGateBias != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputGateBias should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_ForgetGateBias != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_ForgetGateBias should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_CellBias != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_CellBias should not be null.");
+ BOOST_ASSERT_MSG(m_QuantizedLstmParameters.m_OutputGateBias != nullptr,
+ "QuantizedLstmLayer: m_QuantizedLstmParameters.m_OutputGateBias should not be null.");
+
+ // Check output TensorShape(s) match inferred shape
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "QuantizedLstmLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+ GetOutputSlot(0).GetTensorInfo().GetShape(),
+ inferredShapes[0]);
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "QuantizedLstmLayer: TensorShape set on OutputSlot[1] does not match the inferred shape.",
+ GetOutputSlot(1).GetTensorInfo().GetShape(),
+ inferredShapes[1]);
+}
+
+Layer::ConstantTensors QuantizedLstmLayer::GetConstantTensorsByRef()
+{
+ return
+ {
+ m_QuantizedLstmParameters.m_InputToInputWeights,
+ m_QuantizedLstmParameters.m_InputToForgetWeights,
+ m_QuantizedLstmParameters.m_InputToCellWeights,
+ m_QuantizedLstmParameters.m_InputToOutputWeights,
+
+ m_QuantizedLstmParameters.m_RecurrentToInputWeights,
+ m_QuantizedLstmParameters.m_RecurrentToForgetWeights,
+ m_QuantizedLstmParameters.m_RecurrentToCellWeights,
+ m_QuantizedLstmParameters.m_RecurrentToOutputWeights,
+
+ m_QuantizedLstmParameters.m_InputGateBias,
+ m_QuantizedLstmParameters.m_ForgetGateBias,
+ m_QuantizedLstmParameters.m_CellBias,
+ m_QuantizedLstmParameters.m_OutputGateBias
+ };
+}
+
+void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
+{
+ QuantizedLstmInputParams inputParams;
+
+ // InputToX weight tensors
+ ConstTensor inputToInputWeightsTensor;
+ if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr)
+ {
+ ConstTensor inputToInputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_InputToInputWeights->Map(true));
+ inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
+ inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
+ }
+
+ ConstTensor inputToForgetWeightsTensor;
+ if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr)
+ {
+ ConstTensor inputToForgetWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_InputToForgetWeights->Map(true));
+ inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
+ inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
+ }
+
+ ConstTensor inputToCellWeightsTensor;
+ if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr)
+ {
+ ConstTensor inputToCellWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_InputToCellWeights->Map(true));
+ inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
+ inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
+ }
+
+ ConstTensor inputToOutputWeightsTensor;
+ if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr)
+ {
+ ConstTensor inputToOutputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_InputToOutputWeights->Map(true));
+ inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
+ inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
+ }
+
+ // RecurrentToX weight tensors
+ ConstTensor recurrentToInputWeightsTensor;
+ if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr)
+ {
+ ConstTensor recurrentToInputWeightsTensorCopy(
+ m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_RecurrentToInputWeights->Map(true));
+ recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
+ inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
+ }
+
+ ConstTensor recurrentToForgetWeightsTensor;
+ if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr)
+ {
+ ConstTensor recurrentToForgetWeightsTensorCopy(
+ m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_RecurrentToForgetWeights->Map(true));
+ recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
+ inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
+ }
+
+ ConstTensor recurrentToCellWeightsTensor;
+ if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr)
+ {
+ ConstTensor recurrentToCellWeightsTensorCopy(
+ m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_RecurrentToCellWeights->Map(true));
+ recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
+ inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
+ }
+
+ ConstTensor recurrentToOutputWeightsTensor;
+ if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr)
+ {
+ ConstTensor recurrentToOutputWeightsTensorCopy(
+ m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_RecurrentToOutputWeights->Map(true));
+ recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
+ inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
+ }
+
+ // Bias tensors
+ ConstTensor inputGateBiasTensor;
+ if (m_QuantizedLstmParameters.m_InputGateBias != nullptr)
+ {
+ ConstTensor inputGateBiasTensorCopy(m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_InputGateBias->Map(true));
+ inputGateBiasTensor = inputGateBiasTensorCopy;
+ inputParams.m_InputGateBias = &inputGateBiasTensor;
+ }
+
+ ConstTensor forgetGateBiasTensor;
+ if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr)
+ {
+ ConstTensor forgetGateBiasTensorCopy(m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_ForgetGateBias->Map(true));
+ forgetGateBiasTensor = forgetGateBiasTensorCopy;
+ inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
+ }
+
+ ConstTensor cellBiasTensor;
+ if (m_QuantizedLstmParameters.m_CellBias != nullptr)
+ {
+ ConstTensor cellBiasTensorCopy(m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_CellBias->Map(true));
+ cellBiasTensor = cellBiasTensorCopy;
+ inputParams.m_CellBias = &cellBiasTensor;
+ }
+
+ ConstTensor outputGateBiasTensor;
+ if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr)
+ {
+ ConstTensor outputGateBiasCopy(m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(),
+ m_QuantizedLstmParameters.m_OutputGateBias->Map(true));
+ outputGateBiasTensor = outputGateBiasCopy;
+ inputParams.m_OutputGateBias = &outputGateBiasTensor;
+ }
+
+ visitor.VisitQuantizedLstmLayer(this, inputParams, GetName());
+}
+
+} // namespace armnn