Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r--  src/armnn/Network.cpp | 152
 1 file changed, 151 insertions(+), 1 deletion(-)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index d340f021e2..83eafe7993 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -518,6 +518,14 @@ IConnectableLayer* INetwork::AddLogicalBinaryLayer(const LogicalBinaryDescriptor
    return pNetworkImpl->AddLogicalBinaryLayer(descriptor, name);
}
+IConnectableLayer* INetwork::AddUnidirectionalSequenceLstmLayer(
+    const UnidirectionalSequenceLstmDescriptor& descriptor,
+    const LstmInputParams& params,
+    const char* name)
+{
+    return pNetworkImpl->AddUnidirectionalSequenceLstmLayer(descriptor, params, name);
+}
+
void INetwork::Accept(ILayerVisitor& visitor) const
{
    return pNetworkImpl->Accept(visitor);
@@ -2603,11 +2611,153 @@ IConnectableLayer* NetworkImpl::AddQLstmLayer(const QLstmDescriptor& descriptor
}
IConnectableLayer* NetworkImpl::AddLogicalBinaryLayer(const LogicalBinaryDescriptor& logicalBinaryDescriptor,
-                                                  const char* name)
+                                                      const char* name)
{
    return m_Graph->AddLayer<LogicalBinaryLayer>(logicalBinaryDescriptor, name);
}
+IConnectableLayer* NetworkImpl::AddUnidirectionalSequenceLstmLayer(
+    const UnidirectionalSequenceLstmDescriptor& descriptor,
+    const LstmInputParams& params,
+    const char* name)
+{
+    const auto layer = m_Graph->AddLayer<UnidirectionalSequenceLstmLayer>(descriptor, name);
+
+    // LSTM basic parameters
+    layer->m_BasicParameters.m_InputToForgetWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_InputToForgetWeights));
+    layer->m_BasicParameters.m_InputToCellWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_InputToCellWeights));
+    layer->m_BasicParameters.m_InputToOutputWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_InputToOutputWeights));
+    layer->m_BasicParameters.m_RecurrentToForgetWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToForgetWeights));
+    layer->m_BasicParameters.m_RecurrentToCellWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToCellWeights));
+    layer->m_BasicParameters.m_RecurrentToOutputWeights =
+        std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToOutputWeights));
+    layer->m_BasicParameters.m_ForgetGateBias =
+        std::make_shared<ScopedTensorHandle>(*(params.m_ForgetGateBias));
+    layer->m_BasicParameters.m_CellBias =
+        std::make_shared<ScopedTensorHandle>(*(params.m_CellBias));
+    layer->m_BasicParameters.m_OutputGateBias =
+        std::make_shared<ScopedTensorHandle>(*(params.m_OutputGateBias));
+
+    // LSTM CIFG parameters
+    if(!descriptor.m_CifgEnabled)
+    {
+        if(params.m_InputToInputWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input To Input Weights cannot be NULL "
+                                           "when CIFG is disabled.");
+        }
+        if(params.m_RecurrentToInputWeights == nullptr)
+        {
+            throw InvalidArgumentException(
+                "AddUnidirectionalSequenceLstmLayer: Recurrent To Input Weights cannot be NULL "
+                "when CIFG is disabled.");
+        }
+        if(params.m_InputGateBias == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input Gate Bias cannot be NULL "
+                                           "when CIFG is disabled.");
+        }
+        layer->m_CifgParameters.m_InputToInputWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_InputToInputWeights));
+        layer->m_CifgParameters.m_RecurrentToInputWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToInputWeights));
+        layer->m_CifgParameters.m_InputGateBias =
+            std::make_shared<ScopedTensorHandle>(*(params.m_InputGateBias));
+    }
+
+    // LSTM projection parameters
+    if(descriptor.m_ProjectionEnabled)
+    {
+        if(params.m_ProjectionWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Projection Weights cannot be NULL "
+                                           "when projection is enabled.");
+        }
+        layer->m_ProjectionParameters.m_ProjectionWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_ProjectionWeights));
+        if(params.m_ProjectionBias != nullptr)
+        {
+            layer->m_ProjectionParameters.m_ProjectionBias =
+                std::make_shared<ScopedTensorHandle>(*(params.m_ProjectionBias));
+        }
+    }
+
+    // LSTM peephole parameters
+    if(descriptor.m_PeepholeEnabled)
+    {
+        if(!descriptor.m_CifgEnabled)
+        {
+            if(params.m_CellToInputWeights == nullptr)
+            {
+                throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Input Weights "
+                                               "cannot be NULL when Peephole is enabled and CIFG disabled.");
+            }
+
+            layer->m_PeepholeParameters.m_CellToInputWeights =
+                std::make_shared<ScopedTensorHandle>(*(params.m_CellToInputWeights));
+        }
+
+        if(params.m_CellToForgetWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Forget Weights cannot be NULL "
+                                           "when Peephole is enabled.");
+        }
+        if(params.m_CellToOutputWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Output Weights cannot be NULL "
+                                           "when Peephole is enabled.");
+        }
+
+        layer->m_PeepholeParameters.m_CellToForgetWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_CellToForgetWeights));
+        layer->m_PeepholeParameters.m_CellToOutputWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_CellToOutputWeights));
+    }
+
+    // LSTM layer normalization parameters
+    if(descriptor.m_LayerNormEnabled)
+    {
+        if(!descriptor.m_CifgEnabled)
+        {
+            if(params.m_InputLayerNormWeights == nullptr)
+            {
+                throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input layer normalization weights "
+                                               "cannot be NULL when layer normalization is enabled and CIFG disabled.");
+            }
+            layer->m_LayerNormParameters.m_InputLayerNormWeights =
+                std::make_shared<ScopedTensorHandle>(*(params.m_InputLayerNormWeights));
+        }
+
+        if(params.m_ForgetLayerNormWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Forget layer normalization weights "
+                                           "cannot be NULL when layer normalization is enabled.");
+        }
+        if(params.m_CellLayerNormWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell layer normalization weights "
+                                           "cannot be NULL when layer normalization is enabled.");
+        }
+        if(params.m_OutputLayerNormWeights == nullptr)
+        {
+            throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Output layer normalization weights "
+                                           "cannot be NULL when layer normalization is enabled.");
+        }
+        layer->m_LayerNormParameters.m_ForgetLayerNormWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_ForgetLayerNormWeights));
+        layer->m_LayerNormParameters.m_CellLayerNormWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_CellLayerNormWeights));
+        layer->m_LayerNormParameters.m_OutputLayerNormWeights =
+            std::make_shared<ScopedTensorHandle>(*(params.m_OutputLayerNormWeights));
+    }
+    return layer;
+}
+
void NetworkImpl::Accept(ILayerVisitor& visitor) const
{
    for (auto layer : GetGraph())
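
Usage sketch (not part of this patch): a minimal example of how a caller might exercise the new
INetwork::AddUnidirectionalSequenceLstmLayer API added above. With m_CifgEnabled set to true and
the other optional features disabled, only the nine basic weights and biases validated in
NetworkImpl::AddUnidirectionalSequenceLstmLayer are mandatory; every other LstmInputParams pointer
may stay null. The tensor shapes, fill values, and layer name below are hypothetical assumptions,
not taken from this commit.

// Hedged sketch: build a network containing a single CIFG-enabled
// UnidirectionalSequenceLstm layer with only the mandatory basic parameters.
#include <armnn/Descriptors.hpp>
#include <armnn/INetwork.hpp>
#include <armnn/LstmParams.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <vector>

int main()
{
    using namespace armnn;

    const unsigned int numUnits  = 4;   // hypothetical cell size
    const unsigned int inputSize = 3;   // hypothetical input feature size

    INetworkPtr network = INetwork::Create();

    // CIFG enabled: the input-gate (CIFG) parameters may remain null, matching
    // the validation in the patch above.
    UnidirectionalSequenceLstmDescriptor descriptor;
    descriptor.m_CifgEnabled       = true;
    descriptor.m_PeepholeEnabled   = false;
    descriptor.m_ProjectionEnabled = false;
    descriptor.m_LayerNormEnabled  = false;

    const TensorInfo inputWeightsInfo({numUnits, inputSize}, DataType::Float32);
    const TensorInfo recurrentWeightsInfo({numUnits, numUnits}, DataType::Float32);
    const TensorInfo biasInfo({numUnits}, DataType::Float32);

    // Placeholder weight/bias data; a real caller would load trained values.
    std::vector<float> inputWeights(numUnits * inputSize, 0.1f);
    std::vector<float> recurrentWeights(numUnits * numUnits, 0.1f);
    std::vector<float> bias(numUnits, 0.0f);

    const ConstTensor inputToForgetWeights(inputWeightsInfo, inputWeights.data());
    const ConstTensor inputToCellWeights(inputWeightsInfo, inputWeights.data());
    const ConstTensor inputToOutputWeights(inputWeightsInfo, inputWeights.data());
    const ConstTensor recurrentToForgetWeights(recurrentWeightsInfo, recurrentWeights.data());
    const ConstTensor recurrentToCellWeights(recurrentWeightsInfo, recurrentWeights.data());
    const ConstTensor recurrentToOutputWeights(recurrentWeightsInfo, recurrentWeights.data());
    const ConstTensor forgetGateBias(biasInfo, bias.data());
    const ConstTensor cellBias(biasInfo, bias.data());
    const ConstTensor outputGateBias(biasInfo, bias.data());

    // Only the nine basic parameters are set; the optional pointers in
    // LstmInputParams default to nullptr.
    LstmInputParams params;
    params.m_InputToForgetWeights     = &inputToForgetWeights;
    params.m_InputToCellWeights       = &inputToCellWeights;
    params.m_InputToOutputWeights     = &inputToOutputWeights;
    params.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
    params.m_RecurrentToCellWeights   = &recurrentToCellWeights;
    params.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
    params.m_ForgetGateBias           = &forgetGateBias;
    params.m_CellBias                 = &cellBias;
    params.m_OutputGateBias           = &outputGateBias;

    IConnectableLayer* lstm =
        network->AddUnidirectionalSequenceLstmLayer(descriptor, params, "UnidirectionalSequenceLstm");
    (void)lstm; // a real network would connect inputs and outputs here
    return 0;
}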