author     Finn Williams <Finn.Williams@arm.com>  2021-03-22 17:51:06 +0000
committer  finn.williams <finn.williams@arm.com>  2021-04-07 16:42:38 +0000
commit     4422ceca976a88aac49b21808a43e465bc87a35e (patch)
tree       d4f7f3d86394f74b679c907ad3f7fc7f4537933f /src/backends/reference/workloads/RefQLstmWorkload.cpp
parent     b70ec417989490a2a72c66ecd6c737df1c094f4c (diff)
download   armnn-4422ceca976a88aac49b21808a43e465bc87a35e.tar.gz
Fix graph copy memory spike
* Change layer storage of ConstTensors to std::shared_ptr<ConstCpuTensorHandle>
* Change clone to share ConstTensor rather than copy
* Remove uses of non-const GetTensor() call
* Reduce scope of non-optimized network in ExeNet, so memory can be released after use

Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: Ibb2c7309d12411d21405bd6024c76bcdf5404545
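The first two bullets describe an ownership change: constant weight tensors are held through a shared pointer, and cloning a layer copies the pointer rather than the tensor data. Below is a minimal, standalone sketch of that idea, not Arm NN's real classes; SketchConstTensorHandle and SketchLayer are hypothetical stand-ins, and only ConstCpuTensorHandle, ConstTensor and GetConstTensor are names taken from the commit itself.

    // Hypothetical sketch of "store weights as std::shared_ptr and share them on clone".
    #include <cstdint>
    #include <memory>
    #include <utility>
    #include <vector>

    class SketchConstTensorHandle   // stand-in for armnn::ConstCpuTensorHandle
    {
    public:
        explicit SketchConstTensorHandle(std::vector<uint8_t> data) : m_Data(std::move(data)) {}

        // Read-only access, analogous to the GetConstTensor<T>() calls in the diff below.
        template <typename T>
        const T* GetConstTensor() const { return reinterpret_cast<const T*>(m_Data.data()); }

    private:
        std::vector<uint8_t> m_Data;   // the (potentially large) constant weight bytes
    };

    class SketchLayer               // stand-in for a layer that owns constant weights
    {
    public:
        explicit SketchLayer(std::shared_ptr<SketchConstTensorHandle> weights)
            : m_Weights(std::move(weights)) {}

        // Cloning shares the handle: only a reference count is bumped and the weight
        // bytes are not duplicated, which is what removes the graph-copy memory spike.
        SketchLayer Clone() const { return SketchLayer(m_Weights); }

    private:
        std::shared_ptr<SketchConstTensorHandle> m_Weights;
    };

With storage shared like this, writable access to the underlying data is no longer appropriate for a shared handle, which is why the diff below replaces every non-const GetTensor<void>() call with the const GetConstTensor<void>() accessor.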
Diffstat (limited to 'src/backends/reference/workloads/RefQLstmWorkload.cpp')
-rw-r--r--  src/backends/reference/workloads/RefQLstmWorkload.cpp  |  50
1 file changed, 27 insertions, 23 deletions
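The last bullet, "Reduce scope of non-optimized network in ExeNet", is a lifetime change rather than a const-correctness one and is not visible in the diff below. A rough, purely illustrative sketch of the idea, with entirely hypothetical names, would be:

    struct BigNetwork {};        // stands in for the parsed, non-optimized network
    struct OptimisedNetwork {};  // stands in for what is actually kept and executed

    OptimisedNetwork Optimise(const BigNetwork&) { return OptimisedNetwork{}; }

    OptimisedNetwork BuildOptimised()
    {
        OptimisedNetwork optimised;
        {
            // The non-optimized network only lives inside this block...
            BigNetwork unoptimised;
            optimised = Optimise(unoptimised);
        }   // ...so its storage is released here, before the rest of the run.
        return optimised;
    }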
diff --git a/src/backends/reference/workloads/RefQLstmWorkload.cpp b/src/backends/reference/workloads/RefQLstmWorkload.cpp
index e11ea55add..bcd6a627de 100644
--- a/src/backends/reference/workloads/RefQLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefQLstmWorkload.cpp
@@ -101,18 +101,20 @@ void RefQLstmWorkload::Execute() const
// Weights decoders
std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
- m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
+ m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
- m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
+ m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
- m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());
+ m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
- m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
+ m_RecurrentToForgetWeightsTensor->GetTensorInfo(),
+ m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
- m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
+ m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
- m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());
+ m_RecurrentToOutputWeightsTensor->GetTensorInfo(),
+ m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
// Optional CIFG params
std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
@@ -198,9 +200,9 @@ void RefQLstmWorkload::Execute() const
if (!cifgEnabled)
{
inputToInputWeightsDecoder = MakeDecoder<float>(
- m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
- recurrentToInputWeightsDecoder = MakeDecoder<float>(
- m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
+ m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
+ recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(),
+ m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
}
if (peepholeEnabled)
@@ -208,22 +210,22 @@ void RefQLstmWorkload::Execute() const
if (!cifgEnabled)
{
cellToInputWeightsDecoder = MakeDecoder<float>(
- m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
+ m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
}
cellToForgetWeightsDecoder = MakeDecoder<float>(
- m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
+ m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
cellToOutputWeightsDecoder = MakeDecoder<float>(
- m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
+ m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
}
if (projectionEnabled)
{
projectionWeightsDecoder = MakeDecoder<float>(
- m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
+ m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
if (m_ProjectionBiasTensor)
{
projectionBiasDecoder = MakeDecoder<float>(
- m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
+ m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
}
}
@@ -231,38 +233,40 @@ void RefQLstmWorkload::Execute() const
{
if (!cifgEnabled)
{
- inputLayerNormWeightsDecoder = MakeDecoder<float>(
- m_InputLayerNormWeightsTensor->GetTensorInfo(), m_InputLayerNormWeightsTensor->GetTensor<void>());
+ inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(),
+ m_InputLayerNormWeightsTensor->GetConstTensor<void>());
// Bias only used if layer norm enabled
armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
inputGateBiasDecoder = MakeDecoder<float>(
- inputGateBiasTensorInfo, m_InputGateBiasTensor->GetTensor<void>());
+ inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<void>());
}
forgetLayerNormWeightsDecoder = MakeDecoder<float>(
- m_ForgetLayerNormWeightsTensor->GetTensorInfo(), m_ForgetLayerNormWeightsTensor->GetTensor<void>());
+ m_ForgetLayerNormWeightsTensor->GetTensorInfo(),
+ m_ForgetLayerNormWeightsTensor->GetConstTensor<void>());
cellLayerNormWeightsDecoder = MakeDecoder<float>(
- m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetTensor<void>());
+ m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<void>());
outputLayerNormWeightsDecoder = MakeDecoder<float>(
- m_OutputLayerNormWeightsTensor->GetTensorInfo(), m_OutputLayerNormWeightsTensor->GetTensor<void>());
+ m_OutputLayerNormWeightsTensor->GetTensorInfo(),
+ m_OutputLayerNormWeightsTensor->GetConstTensor<void>());
// Bias only used if layer norm enabled
armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
forgetGateBiasDecoder = MakeDecoder<float>(
- forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetTensor<void>());
+ forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<void>());
armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
cellGateBiasDecoder = MakeDecoder<float>(
- cellGateBiasTensorInfo, m_CellBiasTensor->GetTensor<void>());
+ cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<void>());
armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
outputGateBiasDecoder = MakeDecoder<float>(
- outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetTensor<void>());
+ outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<void>());
}
// Initialize internal state tensors with zeroes.