Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--  src/backends/backendsCommon/WorkloadFactory.cpp  148
1 file changed, 148 insertions(+), 0 deletions(-)
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index dc70e6a9c2..1c18551679 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -1277,6 +1277,147 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::UnidirectionalSequenceLstm:
+ {
+ auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
+ const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
+
+ // All inputs.
+ const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
+ dataType);
+ const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
+ dataType);
+ const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
+ dataType);
+ // Outputs
+ const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
+
+ // Basic parameters
+ const TensorInfo& inputToForgetWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
+ const TensorInfo& inputToCellWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
+ const TensorInfo& inputToOutputWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
+ const TensorInfo& recurrentToForgetWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
+ const TensorInfo& recurrentToCellWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
+ const TensorInfo& recurrentToOutputWeights
+ = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
+ const TensorInfo& forgetGateBias
+ = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
+ const TensorInfo& cellBias
+ = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
+ const TensorInfo& outputGateBias
+ = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
+
+ LstmInputParamsInfo paramsInfo;
+
+ paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
+ paramsInfo.m_InputToCellWeights = &inputToCellWeights;
+ paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
+ paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+ paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
+ paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+ paramsInfo.m_ForgetGateBias = &forgetGateBias;
+ paramsInfo.m_CellBias = &cellBias;
+ paramsInfo.m_OutputGateBias = &outputGateBias;
+
+ // Optional parameters
+ TensorInfo optInputToInputWeights;
+ TensorInfo optRecurrentToInputWeights;
+ TensorInfo optCellToInputWeights;
+ TensorInfo optInputGateBias;
+ TensorInfo optProjectionWeights;
+ TensorInfo optProjectionBias;
+ TensorInfo optCellToForgetWeights;
+ TensorInfo optCellToOutputWeights;
+ TensorInfo optInputLayerNormWeights;
+ TensorInfo optForgetLayerNormWeights;
+ TensorInfo optCellLayerNormWeights;
+ TensorInfo optOutputLayerNormWeights;
+
+ if(!descriptor.m_CifgEnabled)
+ {
+ optInputToInputWeights =
+ OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
+
+ optRecurrentToInputWeights =
+ OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
+ optInputGateBias =
+ OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
+ paramsInfo.m_InputGateBias = &optInputGateBias;
+ }
+
+ if(descriptor.m_ProjectionEnabled)
+ {
+ optProjectionWeights =
+ OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_ProjectionWeights = &optProjectionWeights;
+ if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
+ {
+ optProjectionBias =
+ OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
+ paramsInfo.m_ProjectionBias = &optProjectionBias;
+ }
+ }
+
+ if(descriptor.m_PeepholeEnabled)
+ {
+ if(!descriptor.m_CifgEnabled)
+ {
+ optCellToInputWeights =
+ OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
+ dataType);
+ paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
+ }
+ optCellToForgetWeights =
+ OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
+ optCellToOutputWeights =
+ OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
+ }
+
+ if(descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ optInputLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
+ }
+
+ optForgetLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
+
+ optCellLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
+
+ optOutputLayerNormWeights = OverrideDataType(
+ cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
+ paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
+ }
+
+ Optional<TensorInfo> hiddenStateOut;
+ Optional<TensorInfo> cellStateOut;
+
+ result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
+ outputStateIn,
+ cellStateIn,
+ output,
+ hiddenStateOut,
+ cellStateOut,
+ descriptor,
+ paramsInfo,
+ reason);
+ break;
+ }
default:
{
ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
@@ -1759,4 +1900,11 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
return std::unique_ptr<IWorkload>();
}

+std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
+ const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
+ const WorkloadInfo& /*info*/) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
} // namespace armnn
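
Note: the IWorkloadFactory::CreateUnidirectionalSequenceLstm added above is a default stub that returns an empty pointer, so a backend only picks up the new layer once its own factory overrides the virtual. Below is a minimal sketch of what such an override could look like; "SampleWorkloadFactory" and "SampleUnidirectionalSequenceLstmWorkload" are assumed names used only for illustration and are not part of this change.

    // Hypothetical backend override of the new virtual; the class names here are
    // illustrative assumptions, not identifiers introduced by this commit.
    std::unique_ptr<IWorkload> SampleWorkloadFactory::CreateUnidirectionalSequenceLstm(
        const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
        const WorkloadInfo& info) const
    {
        // Build the backend-specific workload from the queue descriptor and workload info.
        return std::make_unique<SampleUnidirectionalSequenceLstmWorkload>(descriptor, info);
    }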